This post includes a cool animation.
Salman Naqvi
Wednesday, 21 June 2023
This notebook follows the fastai style guide.
Meanshift clustering is a technique for unsupervised learning. Give this algorithm a bunch of data and it will figure out what groups the data can be sorted into. It does this by iteratively moving each data point toward nearby points until the points in each cluster converge to a single location.
The steps of the algorithm can be summarized as follows:
1. For each data point \(x\) in the dataset, calculate the distance between \(x\) and every other data point.
2. Calculate weights for each point by passing the calculated distances through the normal distribution.
3. Move \(x\) to the weighted average of all points, using those weights.
Repeat these steps until the points converge.
This is the data we will work with to illustrate meanshift clustering. The data points are generated in clearly separated clusters for the sake of clarity.
In the end, all clusters will converge at their respective centers (marked by an X).
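As a rough sketch of how clustered data like this could be generated (the cluster count, points per cluster, and spread are assumptions for illustration; only the total of 1,500 points is implied by the tensor shapes shown later):

import torch

torch.manual_seed(42)
n_clusters, n_samples = 6, 250                    # assumed: 6 clusters of 250 points = 1,500
centroids = torch.rand(n_clusters, 2) * 70 - 35   # assumed: random centres in [-35, 35)
# sample n_samples normally distributed points around each centroid
data = torch.cat([torch.randn(n_samples, 2) * 2.5 + c for c in centroids])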
Let’s start off simple and apply the algorithm to a single point.
For each data point \(x\) in the dataset, calculate the distance between \(x\) and every other data point in the dataset.
tensor([[ 0.611, -20.199],
[ 4.455, -24.188],
[ 2.071, -20.446],
...,
[ 25.927, 6.597],
[ 18.549, 3.411],
[ 24.617, 8.485]])
Each point has an \(x\) coordinate and a \(y\) coordinate.
tensor([[ 0.000, 0.000],
[ -3.844, 3.989],
[ -1.460, 0.247],
...,
[-25.316, -26.796],
[-17.938, -23.610],
[-24.006, -28.684]])
The distance metric we’ll use is Euclidean distance, which is just Pythagoras’ theorem applied to the two points’ coordinates.
\[ \sqrt{(x_2 - x_1)^2 + (y_2 - y_1)^2} \]
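As a sketch (assuming data is the tensor of points shown above), the differences and distances from a single point to every other point can be computed like this:

x = data[0]                               # the single point we move first
diffs = x - data                          # (N, 2) differences, like the tensor above
dists = diffs.square().sum(dim=1).sqrt()  # (N,) Euclidean distances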
Calculate weights for each point in the dataset by passing the calculated distances through the normal distribution.
The normal distribution is also known as the Gaussian distribution. A distribution is simply a way to describe how data is spread out, though that interpretation isn’t what we’re after here. What we are after is the shape of this distribution, which we will use to calculate the weights.
\[ f(x) = \frac{1}{\sigma \sqrt{2\pi} } e^{-\frac{1}{2}\left(\frac{x-\mu}{\sigma}\right)^2} \]
This is what it looks like.
From the shape of this graph, we can see that larger values of \(x\) give smaller values of \(y\), which is what we want — longer distances should have smaller weights meaning they have a smaller effect on the new position of the point.
We can control the rate at which the weights go to zero by varying what’s known as the bandwidth, or the standard deviation. The graph above is generated with a bandwidth of 2.5.
The graph below is generated with a bandwidth of 1.
Let’s get our weights now.
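The kernel definition itself isn’t reproduced above, so here is a minimal sketch of what a gauss_kernel with the signature used later could look like (treat the exact implementation as an assumption):

import math

def gauss_kernel(x, mean, std):
    # normal distribution PDF evaluated at x; only its shape matters here
    return torch.exp(-0.5 * ((x - mean) / std)**2) / (std * math.sqrt(2 * math.pi))

bw = 2.5                                    # bandwidth, i.e. the standard deviation
ws = gauss_kernel(x=dists, mean=0, std=bw)  # one weight per distance

With a bandwidth of 2.5, a point’s weight on itself works out to about 0.16, which lines up with the diagonal of the Gaussian weights shown further below.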
Calculate the weighted average for all points in the dataset. This weighted average is the new location for \(x\).
Below is the formula for the weighted average.
\[ \frac{\sum wx}{\sum w} \]
In words, multiply each data point in the set by its corresponding weight and sum all the products. Then divide that sum by the sum of all the weights. For example, the first data point (0.611, -20.199) with a weight of about 0.16 contributes roughly (0.097, -3.22) to the numerator, which matches the first row of the products shown below.
(tensor([[ 0.097, -3.223],
[ 0.061, -0.331],
[ 0.277, -2.738],
...,
[ 0.000, 0.000],
[ 0.000, 0.000],
[ 0.000, 0.000]]),
tensor([ 0.097, -3.223]))
Let’s calculate the weighted average and assign it as the new location for our point \(x\).
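A sketch of that step, using the same expression that appears in the update function later on:

# weight every point, sum, and divide by the sum of the weights
data[0] = (ws[:, None] * data).sum(dim=0) / ws.sum()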
And there you have it! We just moved a single data point.
Let’s do this for all data points and for a single iteration.
Let’s encapsulate the algorithm so we can run it for multiple iterations.
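A sketch of one way to write this, mirroring the GPU version shown a little further below but without the device transfer:

def update(X):
    for i, x in enumerate(X):
        dist = (x - X).square().sum(dim=1).sqrt()       # step 1: distances
        ws = gauss_kernel(x=dist, mean=0, std=bw)       # step 2: weights
        X[i] = (ws[:, None] * X).sum(dim=0) / ws.sum()  # step 3: weighted average

def meanshift(data):
    X = data.clone()
    for _ in range(5): update(X)
    return X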
All points have converged.
The algorithm took roughly 1.5 seconds to run 5 iterations. We’ll optimize the algorithm further in Optimized Implementation.
As we can see below, simply moving the algorithm to the GPU won’t help; in fact, it became a bit slower.
def update(X):
    for i, x in enumerate(X):
        # Step 1: distance from x to every point in the dataset
        dist = (x - X).square().sum(dim=1).sqrt()
        # Step 2: weights from the Gaussian kernel
        ws = gauss_kernel(x=dist, mean=0, std=bw)
        # Step 3: move x to the weighted average of all points
        X[i] = (ws[:, None] * X).sum(dim=0) / ws.sum()

def meanshift(data):
    X = data.clone().to('cuda')   # move the data onto the GPU
    for _ in range(5): update(X)
    return X.detach().cpu()
%timeit -n 10 meanshift(data)
1.67 s ± 49.7 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Let’s see meanshift clustering happen in real time.
X = data.clone()
fig = plot_data(centroids+2, X, n_samples, display=False)
fig.update_layout(xaxis_range=[-40, 40], yaxis_range=[-40, 40], updatemenus=[dict(type='buttons', buttons=[
    dict(label='Play', method='animate', args=[None]),
    dict(label='Pause', method='animate', args=[[None], dict(frame_duration=0, frame_redraw='False', mode='immediate', transition_duration=0)])
])])

frames = [go.Frame(data=fig.data)]
for _ in range(5):
    update(X)
    frames.append(go.Frame(data=plot_data(centroids+2, X, n_samples, display=False).data))
fig.frames = frames
fig.show()
The implementation above takes roughly 1.5 seconds, which is slow. Let’s perform the algorithm on multiple data points simultaneously. We’ll then move the operations onto the GPU.
For each data point \(x\) in the dataset, calculate the distance between \(x\) and every other data point in the dataset.
We’ll begin with a batch size of 8.
tensor([[ 0.611, -20.199],
[ 4.455, -24.188],
[ 2.071, -20.446],
[ 1.011, -23.082],
[ 4.516, -22.281],
[ -0.149, -22.113],
[ 4.029, -18.819],
[ 2.960, -18.646]])
tensor([[[ 0.000, 0.000],
[ -3.844, 3.989],
[ -1.460, 0.247],
...,
[-25.316, -26.796],
[-17.938, -23.610],
[-24.006, -28.684]],
[[ 3.844, -3.989],
[ 0.000, 0.000],
[ 2.383, -3.742],
...,
[-21.472, -30.786],
[-14.094, -27.599],
[-20.162, -32.673]],
[[ 1.460, -0.247],
[ -2.383, 3.742],
[ 0.000, 0.000],
...,
[-23.856, -27.043],
[-16.477, -23.857],
[-22.546, -28.931]],
...,
[[ -0.759, -1.914],
[ -4.603, 2.076],
[ -2.220, -1.667],
...,
[-26.076, -28.710],
[-18.697, -25.523],
[-24.766, -30.598]],
[[ 3.418, 1.380],
[ -0.426, 5.369],
[ 1.958, 1.627],
...,
[-21.898, -25.417],
[-14.520, -22.230],
[-20.588, -27.304]],
[[ 2.349, 1.553],
[ -1.495, 5.542],
[ 0.889, 1.800],
...,
[-22.967, -25.243],
[-15.589, -22.057],
[-21.657, -27.131]]])
(tensor([[ 0.000, 5.540, 1.481, ..., 36.864, 29.651, 37.404],
[ 5.540, 0.000, 4.437, ..., 37.534, 30.989, 38.394],
[ 1.481, 4.437, 0.000, ..., 36.062, 28.994, 36.679],
...,
[ 2.059, 5.050, 2.776, ..., 38.784, 31.639, 39.364],
[ 3.686, 5.386, 2.546, ..., 33.549, 26.552, 34.196],
[ 2.816, 5.740, 2.007, ..., 34.128, 27.009, 34.715]]),
torch.Size([8, 1500]))
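A sketch of how these batched differences and distances can be computed with broadcasting; the same expression appears in the batched update function further below:

bs = 8                                       # batch size
X = data.clone()
# (bs, 1, 2) minus (1, N, 2) broadcasts to (bs, N, 2) differences
diffs = X[:bs][:, None, :] - X[None, ...]
dists = diffs.square().sum(dim=-1).sqrt()    # (bs, N) Euclidean distances
dists.shape                                  # torch.Size([8, 1500])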
Calculate weights for each point in the dataset by passing the calculated distances through the normal distribution.
We can simplify the Gaussian kernel to a triangular kernel and still achieve the same results, with less computation.
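The kernel used in the notebook isn’t reproduced here, but a common triangular kernel is a linear ramp that reaches zero at the bandwidth. A sketch (the bandwidth of 8 is inferred from the second tensor of weights below, not stated explicitly):

def tri_kernel(dist, bw=8.):
    # weight falls linearly from 1 at distance 0 to 0 at distance bw, and stays 0 beyond
    return (1 - dist / bw).clamp(min=0)

For example, a distance of 5.54 gives a weight of roughly 1 - 5.54/8 ≈ 0.31, which matches the 0.308 in the first row of the second tensor below.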
311 µs ± 8.06 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
25 µs ± 594 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
(tensor([[ 0.160, 0.014, 0.134, ..., 0.000, 0.000, 0.000],
[ 0.014, 0.160, 0.033, ..., 0.000, 0.000, 0.000],
[ 0.134, 0.033, 0.160, ..., 0.000, 0.000, 0.000],
...,
[ 0.114, 0.021, 0.086, ..., 0.000, 0.000, 0.000],
[ 0.054, 0.016, 0.095, ..., 0.000, 0.000, 0.000],
[ 0.085, 0.011, 0.116, ..., 0.000, 0.000, 0.000]]),
tensor([[1.000, 0.308, 0.815, ..., 0.000, 0.000, 0.000],
[0.308, 1.000, 0.445, ..., 0.000, 0.000, 0.000],
[0.815, 0.445, 1.000, ..., 0.000, 0.000, 0.000],
...,
[0.743, 0.369, 0.653, ..., 0.000, 0.000, 0.000],
[0.539, 0.327, 0.682, ..., 0.000, 0.000, 0.000],
[0.648, 0.282, 0.749, ..., 0.000, 0.000, 0.000]]))
Calculate the weighted average for all points in the dataset. This weighted average is the new location for \(x\).
144 µs ± 31.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Let’s have another look at the formula for the weighted average.
\[ \frac{\sum wx}{\sum w} \]
The numerator is actually the definition of a matrix multiplication! Therefore, we can speed up the operation above by using the @ operator!
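Roughly, the two formulations compute the same new locations; the matrix multiplication simply replaces the broadcasted multiply-and-sum:

# ws: (bs, N) weights for a batch of points, X: (N, 2) all points
slow = (ws[..., None] * X[None]).sum(dim=1) / ws.sum(dim=1, keepdim=True)
fast = (ws @ X) / ws.sum(dim=1, keepdim=True)   # same result, less work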
A roughly 40% speed up!
tensor([[ 2.049, -20.954],
[ 3.108, -21.923],
[ 2.441, -21.021],
[ 2.176, -21.616],
[ 3.082, -21.466],
[ 1.842, -21.393],
[ 2.946, -20.632],
[ 2.669, -20.594]])
And there you have it! We performed this algorithm on 8 data points simultaneously!
Let’s encapsulate the code so we can perform it over all data points and time it.
def update(X):
    for i in range(0, n, bs):
        s = slice(i, min(i+bs, n))   # indices of the current batch
        # distances from every point in the batch to every point in the dataset
        dists = (X[s][:, None, :] - X[None, ...]).square().sum(dim=-1).sqrt()
        ws = egauss_kernel(dists, mean=0, std=2.5)
        # weighted averages for the whole batch via matrix multiplication
        X[s] = (ws @ X) / ws.sum(dim=1, keepdim=True)

def meanshift(data):
    X = data.clone()
    for _ in range(5): update(X)
    return X
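Timing it the same way as before (the roughly 0.5 s figure quoted below presumably comes from a run like this; the output is not reproduced here):

%timeit -n 10 meanshift(data)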
From 1.5 seconds to 0.5 seconds! A 3x speed increase — very nice!
Let’s now move the operations onto the GPU and see what improvements we get.
def meanshift(data):
    X = data.clone().to('cuda')
    for _ in range(5): update(X)
    return X.detach().cpu()
%timeit -n 10 meanshift(data)
263 ms ± 27.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
0.5s to 0.25s — a 2x speed increase!
Meanshift clustering simply involves iteratively moving each point toward the weighted average of its surrounding points until the points converge.
If you have any comments, questions, suggestions, feedback, criticisms, or corrections, please do post them down in the comment section below!