Skip to content

Commit 9b2ec7f

Browse files
committed
add mlsmote implementation
1 parent 802caae commit 9b2ec7f

File tree

2 files changed

+204
-0
lines changed

2 files changed

+204
-0
lines changed

imblearn/over_sampling/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from ._smote import KMeansSMOTE
1111
from ._smote import SVMSMOTE
1212
from ._smote import SMOTENC
13+
from ._mlsmote import MLSMOTE
1314

1415
__all__ = [
1516
"ADASYN",
@@ -19,4 +20,5 @@
1920
"BorderlineSMOTE",
2021
"SVMSMOTE",
2122
"SMOTENC",
23+
"MLSMOTE"
2224
]

imblearn/over_sampling/_mlsmote.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
import numpy as np
2+
import itertools
3+
import collections
4+
import random
5+
6+
class MLSMOTE:
7+
"""Over-sampling using MLSMOTE.
8+
9+
Parameters
10+
----------
11+
sampling_strategy: 'ranking','union' or 'intersection' default: 'ranking'
12+
Strategy to generate labelsets
13+
14+
15+
k_neighbors : int or object, default=5
16+
If ``int``, number of nearest neighbours to used to construct synthetic
17+
samples.
18+
19+
categorical_features : ndarray of shape (n_cat_features,) or (n_features,)
20+
Specified which features are categorical. Can either be:
21+
22+
- array of indices specifying the categorical features;
23+
- mask array of shape (n_features, ) and ``bool`` dtype for which
24+
``True`` indicates the categorical features.
25+
26+
Notes
27+
-----
28+
See the original papers: [1]_ for more details.
29+
30+
31+
References
32+
----------
33+
.. [1] Charte, F. & Rivera Rivas, Antonio & Del Jesus, María José & Herrera, Francisco. (2015).
34+
MLSMOTE: Approaching imbalanced multilabel learning through synthetic instance generation.
35+
Knowledge-Based Systems. -. 10.1016/j.knosys.2015.07.019.
36+
37+
"""
38+
def __init__(self,categorical_features,k_neighbors=5 ,sampling_strategy='ranking'):
39+
self.k_neighbors=k_neighbors
40+
self.sampling_strategy_=sampling_strategy
41+
self.categorical_features = categorical_features
42+
self.continuous_features_= None
43+
self.unique_labels = []
44+
self.labels=[]
45+
self.features=[]
46+
47+
def fit_resample(self,X,y):
48+
self.n_features_ = X.shape[1]
49+
self.labels=np.array([np.array(xi) for xi in y])
50+
51+
self._validate_estimator()
52+
53+
X_resampled = X.copy()
54+
y_resampled = y.copy()
55+
56+
self.unique_labels = self._collect_unique_labels(y)
57+
self.features=X
58+
59+
X_synth=[]
60+
y_synth=[]
61+
62+
append_X_synth=X_synth.append
63+
append_y_synth=y_synth.append
64+
mean_ir=self._get_mean_imbalance_ratio()
65+
for label in self.unique_labels:
66+
irlbl=self._get_imbalance_ratio_per_label(label)
67+
if irlbl > mean_ir:
68+
min_bag=self._get_all_instances_of_label(label)
69+
for sample in min_bag:
70+
distances=self._calc_distances(sample,min_bag)
71+
distances=np.sort(distances,order='distance')
72+
neighbours=distances[:self.k_neighbors]
73+
ref_neigh=np.random.choice(neighbours,1)[0]
74+
X_new,y_new=self._create_new_sample(sample,ref_neigh[1],[x[1] for x in neighbours])
75+
append_X_synth(X_new)
76+
append_y_synth(y_new)
77+
78+
return np.concatenate((X_resampled,np.array(X_synth))),np.array(y_resampled.tolist()+y_synth)
79+
80+
def _validate_estimator(self):
81+
categorical_features = np.asarray(self.categorical_features)
82+
if categorical_features.dtype.name == "bool":
83+
self.categorical_features_ = np.flatnonzero(categorical_features)
84+
else:
85+
if any(
86+
[
87+
cat not in np.arange(self.n_features_)
88+
for cat in categorical_features
89+
]
90+
):
91+
raise ValueError(
92+
"Some of the categorical indices are out of range. Indices"
93+
" should be between 0 and {}".format(self.n_features_)
94+
)
95+
self.categorical_features_ = categorical_features
96+
self.continuous_features_ = np.setdiff1d(
97+
np.arange(self.n_features_), self.categorical_features_
98+
)
99+
100+
def _collect_unique_labels(self, y):
101+
"""A support function that flattens the labelsets and return one set of unique labels"""
102+
return np.unique(np.array([a for x in y for a in (x if isinstance(x, list) else [x])]))
103+
104+
def _create_new_sample(self,sample_id,ref_neigh_id,neighbour_ids):
105+
sample=self.features[sample_id]
106+
sample_labels=self.labels[sample_id]
107+
synth_sample=np.copy(sample)
108+
ref_neigh=self.features[ref_neigh_id]
109+
neighbours_labels=[]
110+
for ni in neighbour_ids:
111+
neighbours_labels.append(self.labels[ni].tolist())
112+
for i in range(synth_sample.shape[0]):
113+
if i in self.continuous_features_:
114+
diff=ref_neigh[i]-sample[i]
115+
offset=diff*random.uniform(0,1)
116+
synth_sample[i]=sample[i]+offset
117+
if i in self.categorical_features_:
118+
synth_sample[i]=self._get_most_frequent_value(self.features[neighbour_ids,i])
119+
120+
labels=sample_labels.tolist()
121+
labels+=[a for x in neighbours_labels for a in (x if isinstance(x, list) else [x])]
122+
labels=list(set(labels))
123+
if self.sampling_strategy_=='ranking':
124+
head_index=int((self.k_neighbors+ 1)/2)
125+
y=labels[:head_index]
126+
if self.sampling_strategy_=='union':
127+
y=labels[:]
128+
if self.sampling_strategy_=='intersection':
129+
y=list(set.intersection(*neighbours_labels))
130+
131+
X=synth_sample
132+
return X,y
133+
134+
135+
def _calc_distances(self,sample,min_bag):
136+
distances=[]
137+
append_distances=distances.append
138+
for bag_sample in min_bag:
139+
nominal_distances=np.array([self._get_vdm(self.features[sample,cat],self.features[bag_sample,cat])for cat in self.categorical_features_])
140+
ordinal_distances=np.array([self._get_euclidean_distance(self.features[sample,num],self.features[bag_sample,num])for num in self.continuous_features_])
141+
dists=np.array([nominal_distances.sum(),ordinal_distances.sum()])
142+
append_distances((dists.sum(),bag_sample))
143+
dtype = np.dtype([('distance', float), ('index', int)])
144+
return np.array(distances,dtype=dtype)
145+
146+
147+
def _get_euclidean_distance(self,first,second):
148+
euclidean_distance=np.linalg.norm(first-second)
149+
return euclidean_distance
150+
151+
def _get_vdm(self,first,second):
152+
"""A support function to compute the Value Difference Metric(VDM) discribed in https://arxiv.org/pdf/cs/9701101.pdf"""
153+
def f(c):
154+
N_ax=len(np.where(self.features[:,self.categorical_features_]==first))
155+
N_ay=len(np.where(self.features[:,self.categorical_features_]==second))
156+
c_instances=self._get_all_instances_of_label(c)
157+
N_axc=len(np.where(self.features[np.ix_(c_instances,self.categorical_features_)]==first)[0])
158+
N_ayc=len(np.where(self.features[np.ix_(c_instances,self.categorical_features_)]==second)[0])
159+
return np.square(np.abs((N_axc/N_ax)-(N_ayc/N_ay)))
160+
161+
return np.sum(np.array([f(c)for c in self.unique_labels]))
162+
163+
def _get_all_instances_of_label(self,label):
164+
instance_ids=[]
165+
append_instance_id=instance_ids.append
166+
for i,label_set in enumerate(self.labels):
167+
if label in label_set:
168+
append_instance_id(i)
169+
return np.array(instance_ids)
170+
171+
def _get_mean_imbalance_ratio(self):
172+
ratio_sum=np.sum(np.array(list(map(self._get_imbalance_ratio_per_label,self.unique_labels))))
173+
return ratio_sum/self.unique_labels.shape[0]
174+
175+
def _get_imbalance_ratio_per_label(self,label):
176+
sum_array=list(map(self._sum_h,self.unique_labels))
177+
sum_array=np.array(sum_array)
178+
return sum_array.max()/self._sum_h(label)
179+
180+
def _sum_h(self,label):
181+
h_sum=0
182+
def h(l,Y):
183+
if l in Y:
184+
return 1
185+
else:
186+
return 0
187+
188+
for label_set in self.labels:
189+
h_sum+=h(label,label_set)
190+
return h_sum
191+
192+
193+
def _get_label_frequencies(self,labels):
194+
""""A support function to get the frequencies of labels"""
195+
frequency_map=np.array(np.unique(labels, return_counts=True)).T
196+
frequencies=np.array([x[1] for x in count_map])
197+
return frequencies
198+
199+
def _get_most_frequent_value(self, values):
200+
""""A support function to get most frequent value if a list of values"""
201+
uniques, indices = np.unique(values, return_inverse=True)
202+
return uniques[np.argmax(np.bincount(indices))]

0 commit comments

Comments
 (0)