This guide assumes a basic understanding of derivatives and matrices.
Backpropagation sounds and looks daunting. It doesn’t need to be. In fact, backpropagation is really just a fancy word for the chain rule. Implementing a backpropagation algorithm is simply implementing one big fat chain rule equation.
Let’s remind ourselves of the chain rule. The chain rule lets us figure out how much a given variable indirectly changes with respect to another variable. Take the example below.
We want to figure out how much changes with each increment in . The problem is that doesn’t directly change . Rather, changes which in turn changes .
The chain rule allows us to solve this problem. In this case, the chain rule tells us that we can figure out how much indirectly changes by multiplying the derivative of with respect to by the derivative of with respect to .
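As a quick numerical check of the chain rule, here is a small sketch. It uses two hypothetical functions that do not come from this guide: u = 3x + 1 and y = u². By the chain rule, dy/dx = dy/du · du/dx = 2u · 3. The sketch compares that product against a finite-difference estimate of the composed function.

```python
# Hypothetical example functions, chosen only to illustrate the chain rule.
def u_of_x(x):
    return 3 * x + 1          # u = 3x + 1, so du/dx = 3

def y_of_u(u):
    return u ** 2             # y = u^2, so dy/du = 2u

def dy_dx_chain(x):
    """Chain rule: multiply dy/du by du/dx."""
    u = u_of_x(x)
    dy_du = 2 * u
    du_dx = 3
    return dy_du * du_dx

def dy_dx_numeric(x, h=1e-6):
    """Central finite difference of the composed function y(u(x))."""
    f = lambda t: y_of_u(u_of_x(t))
    return (f(x + h) - f(x - h)) / (2 * h)

print(dy_dx_chain(2.0))    # 42.0, since u = 7 and 2 * 7 * 3 = 42
print(dy_dx_numeric(2.0))  # approximately 42.0
```

x never changes y directly, yet its effect on y is fully captured by multiplying the two local derivatives.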
Aaand I’ve just described backpropagation in a nutshell. That’s all there really is to it. The only difference is that in a neural network there are many more intermediate variables and functions, and that we want to find out how the weights indirectly change the loss.
Let’s see this tangibly in action.
We have the following neural network made up of two layers: the first layer contains the affine function^{1} followed by the ReLU, while the second layer contains only the affine function. The loss, which is MSE (Mean Squared Error), is then calculated from the output of the second layer.
^{1} An affine function is a linear function plus a constant offset (the bias), i.e. something of the form $xw + b$.
flowchart LR subgraph A [Layer 1] direction LR id1[Affine Function] --> id2[ReLU] end subgraph B [Layer 2] direction LR id2 --> id3[Affine Function] end subgraph C [Loss Function] direction LR id3 --> id4[MSE] end
Mathematically speaking, the first layer with a single sample looks like this.
The second layer looks like this.
And the loss function looks like this.
MSE in its most basic form looks like this.
If we have multiple data points, then it looks like this.
However, when working with multiple samples, the mean squared error comes out looking like this, where represents the total number of samples.
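For reference, here is one common way to write the MSE over $N$ samples, together with its derivative with respect to a single prediction. The notation is ours and may differ from the guide's own symbols: $y_i$ is the target and $\hat{y}_i$ is the network's output for sample $i$.

$$\text{MSE} = \frac{1}{N}\sum_{i=1}^{N}\left(y_i - \hat{y}_i\right)^2, \qquad \frac{\partial\,\text{MSE}}{\partial \hat{y}_i} = -\frac{2}{N}\left(y_i - \hat{y}_i\right)$$

The factor $2/N$ and the minus sign correspond to the two Python lines `loss.g = (2/trn_x.shape[0]) * diff` and `diff.g = loss.g * -1`, where `diff` plays the role of $y_i - \hat{y}_i$.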
Or more simply…^{2}
^{2} $\sum$ is known as the summation or sigma operator. If we have the equation , it means: evaluate the expression for every value of from 1 to 4, and add the results together. Find out more here.
…or even more simply.
Our goal for the rest of this guide is to derive the gradients of .
The equation above is quite a mouthful, though. One might even say scary. How would you even apply the chain rule here? How would you use it to derive the gradients of the weights and biases?
Let’s simplify things by introducing a bunch of intermediate variables. We’ll begin by substituting the innermost pieces of the equation, and then gradually make our way out.
The menacing equation above now gradually simplifies into the cute equation below.
Very cute, hey?
In this cuter version of the equation, it is visible that incrementing does not directly change the MSE. Rather, incrementing changes , which changes , which changes , which changes , which in turn changes .
See? Just a big, fat, and simple chain rule problem.
$\partial$ is a curly “d” and can be read as “curly d”, “partial”, or simply as “d”. $\partial$ notation will be used below due to a concept known as partial derivatives. We will not go into this concept here; however, this is a great brief rundown on partial derivatives.
Now we can tackle finding the gradients for . To do so, let’s find the gradients of each intermediate variable.^{3} ^{4}
^{3} If needed, get a refresher of the derivative rules here.
^{4} The large curly brace denotes a piecewise function. The simplest piecewise function returns one calculation if a condition is met, and another calculation if the condition is not met. It can be thought of as an if-else statement in programming. Find out more here.
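The derivative of the ReLU used in layer 1 is a concrete example of such a piecewise function. The derivative is undefined at $z = 0$. This sketch follows the common convention of using 0 there, which is also what `(l1 > 0).float()` does in the Python code.

$$\operatorname{ReLU}(z) = \max(0, z), \qquad \frac{d\,\operatorname{ReLU}(z)}{dz} = \begin{cases} 1 & \text{if } z > 0 \\ 0 & \text{otherwise} \end{cases}$$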
Now we multiply everything together.
And it all eventually expands out to the following.
We can simplify further by factoring out and .
We can simplify even further, by letting . The stands for “error”.
And there you go! We’ve derived the formula that will allow us to calculate the gradients of .
When implementing backpropagation in a program, it is often better to implement the equation in pieces, storing the result of each intermediate gradient, rather than writing it as a single line of code. These intermediate gradients can then be reused to calculate the gradients of another variable, such as the bias .
Instead of implementing the following in a single line of code.
We can instead first calculate the gradients of .
Then calculate the gradients of and multiply them with the gradients of .
Then multiply the product above with the gradients of .
Then multiply the product above with the gradients of .
And finally multiply the product above with the gradients of .
Let’s see this using Python instead.
The following is our neural network.
l1 is the first layer, .
l2 is the second layer, .
loss is the MSE, .
First we need to calculate the gradients of .
diff = (trn_y - l2)
loss.g = (2/trn_x.shape[0]) * diff
trn_x.shape[0], in this case, returns the total number of samples.
Next are the gradients of
diff.g = loss.g * -1
Then the gradients of
l2.g = diff.g @ w2.T
Then the gradients of
l1.g = l2.g * (l1 > 0).float()
And finally the gradients of .
w1.g = trn_x.T @ l1.g
The equation for the gradient of is almost the same as the equation for the gradients of , except for the last step, where we do not have to matrix multiply with . Therefore, we can reuse all previous gradient calculations to find the gradient of , summing over the samples.
b1.g = l1.g.sum(0)
When multiplying various tensors together, make sure their shapes are compatible. Shape manipulations have been omitted above for simplicity.
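To check the whole recipe end to end, here is a minimal sketch in NumPy. NumPy stands in for PyTorch so the sketch runs on a CPU without extra setup. The sizes, the random data, and names such as `g_z1` and `numeric_grad` are hypothetical.

The sketch keeps every shape explicit, so the weight gradient is the matrix product `trn_x.T @ g_z1` and the bias gradient sums over the batch. It then compares the hand-derived gradients against central finite differences.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes and random data, chosen only for illustration.
n, n_in, n_hidden = 6, 3, 4
trn_x = rng.normal(size=(n, n_in))
trn_y = rng.normal(size=(n, 1))
w1, b1 = rng.normal(size=(n_in, n_hidden)), np.zeros(n_hidden)
w2, b2 = rng.normal(size=(n_hidden, 1)), np.zeros(1)

def forward():
    """Forward pass using the current (global) parameters."""
    z1 = trn_x @ w1 + b1       # layer 1: affine function
    l1 = np.maximum(z1, 0)     # layer 1: ReLU
    l2 = l1 @ w2 + b2          # layer 2: affine function
    diff = trn_y - l2
    loss = (diff ** 2).mean()  # MSE
    return l1, diff, loss

l1, diff, loss = forward()

# Backward pass, storing every intermediate gradient so it can be reused.
g_diff = (2 / n) * diff        # dMSE/d(diff)
g_l2 = -g_diff                 # dMSE/dl2, because diff = trn_y - l2
g_l1 = g_l2 @ w2.T             # dMSE/dl1 (gradient w.r.t. the ReLU output)
g_z1 = g_l1 * (l1 > 0)         # through the ReLU: zero where the unit was inactive
w1_g = trn_x.T @ g_z1          # dMSE/dw1, shape (n_in, n_hidden)
b1_g = g_z1.sum(axis=0)        # dMSE/db1 reuses g_z1, summed over the batch

def numeric_grad(param, eps=1e-6):
    """Central finite differences of the loss w.r.t. `param` (perturbed in place, then restored)."""
    g = np.zeros_like(param)
    for idx in np.ndindex(param.shape):
        orig = param[idx]
        param[idx] = orig + eps
        loss_plus = forward()[2]
        param[idx] = orig - eps
        loss_minus = forward()[2]
        param[idx] = orig
        g[idx] = (loss_plus - loss_minus) / (2 * eps)
    return g

print(np.allclose(w1_g, numeric_grad(w1), atol=1e-5))  # expect True
print(np.allclose(b1_g, numeric_grad(b1), atol=1e-5))  # expect True
```

Storing `g_z1` once is what lets both `w1_g` and `b1_g` be computed without repeating the earlier steps of the chain.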
And that’s all there really is to backpropagation: think of it as one big chain rule problem.
To make sure you’ve got it hammered down, get out a pen and paper and derive the equations that would compute the gradients of , , , and respectively with respect to the MSE.
And if you really want to hammer down your understanding on what’s happening, then I highly recommend reading The Matrix Calculus You Need For Deep Learning. I’ve also compiled backpropagation practice questions from this paper!
If you have any comments, questions, suggestions, feedback, criticisms, or corrections, please do post them down in the comment section below!
This notebook follows the fastai style guide.
Meanshift clustering is a technique for unsupervised learning. Give this algorithm a bunch of data and it will figure out what groups the data can be sorted into. It does this by iteratively moving all data points until the points in each group converge to a single point, the center of that group.
The steps of the algorithm can be summarized as follows:
This is the data we will work with to illustrate meanshift clustering. The data points are placed into clearly separate clusters for the sake of clarity.
In the end, the points in each cluster will converge at that cluster’s center (marked by X).
Let’s start off simple and apply the algorithm to a single point.
For each data point $x$ in the dataset, calculate the distance between $x$ and every other data point in the dataset.
data
tensor([[ 0.611, -20.199],
[ 4.455, -24.188],
[ 2.071, -20.446],
...,
[ 25.927, 6.597],
[ 18.549, 3.411],
[ 24.617, 8.485]])
X = data.clone(); X.shape
torch.Size([1500, 2])
Each point has an $x$ coordinate and a $y$ coordinate.
x = X[0, :]; x - X
tensor([[ 0.000, 0.000],
[ -3.844, 3.989],
[ -1.460, 0.247],
...,
[-25.316, -26.796],
[-17.938, -23.610],
[-24.006, -28.684]])
The distance metric we’ll use is Euclidean distance, which is computed with Pythagoras’ theorem: the square root of the sum of the squared coordinate differences.
dists = (x - X).square().sum(dim=1).sqrt(); dists
tensor([ 0.000, 5.540, 1.481, ..., 36.864, 29.651, 37.404])
Calculate weights for each point in the dataset by passing the calculated distances through the normal distribution.
The normal distribution is also known as the Gaussian distribution. A distribution is simply a way to describe how data is spread out; that interpretation isn’t relevant in our case. What is relevant is the shape of this distribution, which we will use to calculate the weights.
def gauss_kernel(x, mean, std):
    return torch.exp(-(x - mean) ** 2 / (2 * std ** 2)) / (std * torch.sqrt(2 * tensor(torch.pi)))
This is what it looks like.
From the shape of this graph, we can see that larger values of $x$ give smaller values of $y$, which is what we want: longer distances should have smaller weights, meaning those points have a smaller effect on the new position of the point.
We can control the rate at which the weights go to zero by varying what’s known as the bandwidth, which here is the standard deviation of the kernel. The graph above is generated with a bandwidth of 2.5.
The graph below is generated with a bandwidth of 1.
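To see numerically how the bandwidth controls the fall-off, here is a small pure-Python sketch. It uses `math` instead of `torch` so it runs anywhere. For each of the two bandwidths above, it prints weights relative to a point at distance 0. The distances 0, 1, 2 and 5 are arbitrary choices for illustration.

```python
import math

def gauss_kernel(x, mean=0.0, std=1.0):
    """Normal (Gaussian) density, used as a distance-to-weight function."""
    return math.exp(-(x - mean) ** 2 / (2 * std ** 2)) / (std * math.sqrt(2 * math.pi))

for bw in (2.5, 1.0):
    w0 = gauss_kernel(0, std=bw)
    # Weight of a point at distance d relative to a point at distance 0.
    print(bw, [round(gauss_kernel(d, std=bw) / w0, 3) for d in (0, 1, 2, 5)])
# 2.5 [1.0, 0.923, 0.726, 0.135]
# 1.0 [1.0, 0.607, 0.135, 0.0]
```

With a bandwidth of 1, a point only 2 units away already counts for about 14% of a point at distance 0. With 2.5, it still counts for about 73%. A smaller bandwidth therefore makes each point listen mostly to its close neighbours.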
Let’s get our weights now.
gauss_kernel(dists, mean=0, std=2.5)
tensor([ 0.160, 0.014, 0.134, ..., 0.000, 0.000, 0.000])
bw = 2.5
ws = gauss_kernel(x=dists, mean=0, std=bw)
Calculate the weighted average for all points in the dataset. This weighted average is the new location for $x$.
ws.shape, X.shape
(torch.Size([1500]), torch.Size([1500, 2]))
ws[:, None].shape, X.shape
(torch.Size([1500, 1]), torch.Size([1500, 2]))
Below is the formula for a weighted average.
In words, multiply each data point in the set by its corresponding weight and sum all the products. Then divide that by the sum of all weights.
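As a tiny worked example of the weighted average $\sum_i w_i x_i / \sum_i w_i$, here is a sketch with three hypothetical 2-D points and hand-picked weights. The far-away point gets zero weight, so it has no effect on the result.

```python
# Three hypothetical 2-D points and their (hypothetical) weights.
points = [(0.0, 0.0), (2.0, 0.0), (10.0, 10.0)]
weights = [0.5, 0.5, 0.0]   # the far-away point gets zero weight

total_w = sum(weights)
new_x = tuple(
    sum(w * p[d] for w, p in zip(weights, points)) / total_w
    for d in range(2)       # one weighted average per coordinate
)
print(new_x)  # (1.0, 0.0)
```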
ws[:, None] * X, ws[0] * X[0, :]
(tensor([[ 0.097, -3.223],
[ 0.061, -0.331],
[ 0.277, -2.738],
...,
[ 0.000, 0.000],
[ 0.000, 0.000],
[ 0.000, 0.000]]),
tensor([ 0.097, -3.223]))
Let’s calculate the weighted average and assign it as the new location for our point $x$.
x = (ws[:, None] * X).sum(dim=0) / ws.sum(); x
tensor([ 1.695, -20.786])
And there you have it! We just moved a single data point.
Let’s do this for all data points and for a single iteration.
for i, x in enumerate(X):
    dist = (x - X).square().sum(dim=1).sqrt()
    ws = gauss_kernel(x=dist, mean=0, std=bw)
    X[i] = (ws[:, None] * X).sum(dim=0) / ws.sum()
plot_data(centroids+2, X, n_samples)
Let’s encapsulate the algorithm so we can run it for multiple iterations.
def update(X):
    for i, x in enumerate(X):
        dist = (x - X).square().sum(dim=1).sqrt()
        ws = gauss_kernel(x=dist, mean=0, std=bw)
        X[i] = (ws[:, None] * X).sum(dim=0) / ws.sum()
def meanshift(data):
    X = data.clone()
    for _ in range(5): update(X)
    return X
plot_data(centroids+2, meanshift(data), n_samples)
All points have converged.
%timeit -n 10 meanshift(data)
1.7 s ± 282 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
The algorithm took roughly 1.7 seconds to run 5 iterations. We’ll optimize it in Optimized Implementation.
As we can see below, simply moving the algorithm to the GPU doesn’t help: it runs at roughly the same speed, since the loop still processes one point at a time and gives the GPU very little work per operation.
def update(X):
    for i, x in enumerate(X):
        dist = (x - X).square().sum(dim=1).sqrt()
        ws = gauss_kernel(x=dist, mean=0, std=bw)
        X[i] = (ws[:, None] * X).sum(dim=0) / ws.sum()
def meanshift(data):
    X = data.clone().to('cuda')
    for _ in range(5): update(X)
    return X.detach().cpu()
%timeit -n 10 meanshift(data)
1.67 s ± 49.7 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Let’s see meanshift clustering happen in real time.
import plotly.graph_objects as go
X = data.clone()
fig = plot_data(centroids+2, X, n_samples, display=False)
fig.update_layout(xaxis_range=[-40, 40], yaxis_range=[-40, 40], updatemenus=[dict(type='buttons', buttons=[
    dict(label='Play', method='animate', args=[None]),
    dict(label='Pause', method='animate', args=[[None], dict(frame=dict(duration=0, redraw=False), mode='immediate', transition=dict(duration=0))])
])])
frames = [go.Frame(data=fig.data)]
for _ in range(5):
    update(X)
    frames.append(go.Frame(data=plot_data(centroids+2, X, n_samples, display=False).data))
fig.frames = frames
fig.show()
The implementation above takes roughly 1.7 s, which is slow. Let’s perform the algorithm on multiple data points simultaneously, and then move the operations onto the GPU.
For each data point $x$ in the dataset, calculate the distance between $x$ and every other data point in the dataset.
X = data.clone(); X.shape
torch.Size([1500, 2])
We’ll begin with a batch size of 8.
bs = 8
x = X[:bs, :]; x
tensor([[ 0.611, -20.199],
[ 4.455, -24.188],
[ 2.071, -20.446],
[ 1.011, -23.082],
[ 4.516, -22.281],
[ -0.149, -22.113],
[ 4.029, -18.819],
[ 2.960, -18.646]])
x.shape, X.shape
(torch.Size([8, 2]), torch.Size([1500, 2]))
x[:, None, :].shape, X[None, ...].shape
(torch.Size([8, 1, 2]), torch.Size([1, 1500, 2]))
x[:, None, :] - X[None, ...]
tensor([[[ 0.000, 0.000],
[ -3.844, 3.989],
[ -1.460, 0.247],
...,
[-25.316, -26.796],
[-17.938, -23.610],
[-24.006, -28.684]],
[[ 3.844, -3.989],
[ 0.000, 0.000],
[ 2.383, -3.742],
...,
[-21.472, -30.786],
[-14.094, -27.599],
[-20.162, -32.673]],
[[ 1.460, -0.247],
[ -2.383, 3.742],
[ 0.000, 0.000],
...,
[-23.856, -27.043],
[-16.477, -23.857],
[-22.546, -28.931]],
...,
[[ -0.759, -1.914],
[ -4.603, 2.076],
[ -2.220, -1.667],
...,
[-26.076, -28.710],
[-18.697, -25.523],
[-24.766, -30.598]],
[[ 3.418, 1.380],
[ -0.426, 5.369],
[ 1.958, 1.627],
...,
[-21.898, -25.417],
[-14.520, -22.230],
[-20.588, -27.304]],
[[ 2.349, 1.553],
[ -1.495, 5.542],
[ 0.889, 1.800],
...,
[-22.967, -25.243],
[-15.589, -22.057],
[-21.657, -27.131]]])
(x[:, None, :] - X[None, ...]).shape
torch.Size([8, 1500, 2])
dists = (x[:, None, :] - X[None, ...]).square().sum(dim=-1).sqrt(); dists, dists.shape
(tensor([[ 0.000, 5.540, 1.481, ..., 36.864, 29.651, 37.404],
[ 5.540, 0.000, 4.437, ..., 37.534, 30.989, 38.394],
[ 1.481, 4.437, 0.000, ..., 36.062, 28.994, 36.679],
...,
[ 2.059, 5.050, 2.776, ..., 38.784, 31.639, 39.364],
[ 3.686, 5.386, 2.546, ..., 33.549, 26.552, 34.196],
[ 2.816, 5.740, 2.007, ..., 34.128, 27.009, 34.715]]),
torch.Size([8, 1500]))
Calculate weights for each point in the dataset by passing the calculated distances through the normal distribution.
We can simplify the Gaussian kernel to a triangular kernel and still achieve essentially the same clustering results, with less computation.
plot_func(partial(gauss_kernel, mean=0, std=2.5))
def tri_kernel(x, bw): return (-x+bw).clamp_min(0)/bw
plot_func(partial(tri_kernel, bw=8))
%timeit gauss_kernel(dists, mean=0, std=2.5)
311 µs ± 8.06 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit tri_kernel(dists, bw=8)
25 µs ± 594 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
gauss_kernel(dists, mean=0, std=2.5), tri_kernel(dists, bw=8)
(tensor([[ 0.160, 0.014, 0.134, ..., 0.000, 0.000, 0.000],
[ 0.014, 0.160, 0.033, ..., 0.000, 0.000, 0.000],
[ 0.134, 0.033, 0.160, ..., 0.000, 0.000, 0.000],
...,
[ 0.114, 0.021, 0.086, ..., 0.000, 0.000, 0.000],
[ 0.054, 0.016, 0.095, ..., 0.000, 0.000, 0.000],
[ 0.085, 0.011, 0.116, ..., 0.000, 0.000, 0.000]]),
tensor([[1.000, 0.308, 0.815, ..., 0.000, 0.000, 0.000],
[0.308, 1.000, 0.445, ..., 0.000, 0.000, 0.000],
[0.815, 0.445, 1.000, ..., 0.000, 0.000, 0.000],
...,
[0.743, 0.369, 0.653, ..., 0.000, 0.000, 0.000],
[0.539, 0.327, 0.682, ..., 0.000, 0.000, 0.000],
[0.648, 0.282, 0.749, ..., 0.000, 0.000, 0.000]]))
ws = tri_kernel(dists, bw=8); ws.shape
torch.Size([8, 1500])
Calculate the weighted average for all points in the dataset. This weighted average is the new location for $x$.
ws.shape, X.shape
(torch.Size([8, 1500]), torch.Size([1500, 2]))
ws[..., None].shape, X[None, ...].shape
(torch.Size([8, 1500, 1]), torch.Size([1, 1500, 2]))
(ws[..., None] * X[None, ...]).shape
torch.Size([8, 1500, 2])
(ws[..., None] * X[None, ...]).sum(1).shape
torch.Size([8, 2])
%timeit (ws[..., None] * X[None, ...]).sum(1)
144 µs ± 31.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Let’s have another look at the formula for the weighted average.
The numerator is actually the definition of matrix multiplication! Therefore we can speed up the operation above by using the @ operator!
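Entry (b, d) of `ws @ X` is the sum over i of `ws[b, i] * X[i, d]`. That is exactly what the broadcast-and-sum computes. Here is a quick NumPy check, using random stand-in data with the same shapes as above (8 points in the batch, 1500 points in the dataset):

```python
import numpy as np

rng = np.random.default_rng(0)
ws = rng.random((8, 1500))      # hypothetical weights: batch of 8 points vs 1500 points
X = rng.normal(size=(1500, 2))  # hypothetical data

broadcast = (ws[..., None] * X[None, ...]).sum(1)  # shape (8, 2)
matmul = ws @ X                                     # shape (8, 2)
print(np.allclose(broadcast, matmul))  # True
```

The matrix product avoids materializing the intermediate (8, 1500, 2) tensor, which is part of why it is faster.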
%timeit ws @ X
7.64 µs ± 184 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
A roughly 19x speed-up!
x = (ws @ X) / ws.sum(dim=1, keepdim=True); x
tensor([[ 2.049, -20.954],
[ 3.108, -21.923],
[ 2.441, -21.021],
[ 2.176, -21.616],
[ 3.082, -21.466],
[ 1.842, -21.393],
[ 2.946, -20.632],
[ 2.669, -20.594]])
And there you have it! We performed this algorithm on 8 data points simultaneously!
Let’s encapsulate the code so we can perform it over all data points and time it.
?slice
bs
8
min(1508, 1500)
1500
X = data.clone()
n = len(data)
bs = 8
for i in range(0, n, bs):
    s = slice(i, min(i+bs, n))
    dists = (X[s][:, None, :] - X[None, ...]).square().sum(dim=-1).sqrt()
    ws = gauss_kernel(dists, mean=0, std=2.5)
    X[s] = (ws @ X) / ws.sum(dim=1, keepdim=True)
plot_data(centroids+2, X, n_samples)
def update(X):
    for i in range(0, n, bs):
        s = slice(i, min(i+bs, n))
        dists = (X[s][:, None, :] - X[None, ...]).square().sum(dim=-1).sqrt()
        ws = gauss_kernel(dists, mean=0, std=2.5)
        X[s] = (ws @ X) / ws.sum(dim=1, keepdim=True)
def meanshift(data):
    X = data.clone()
    for _ in range(5): update(X)
    return X
plot_data(centroids+2, meanshift(data), n_samples)
%timeit -n 10 meanshift(data)
700 ms ± 43.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
From about 1.7 seconds to 0.7 seconds! A roughly 2.4x speed increase, very nice!
Let’s now move onto the GPU and see what improvements we get.
def meanshift(data):
    X = data.clone().to('cuda')
    for _ in range(5): update(X)
    return X.detach().cpu()
%timeit -n 10 meanshift(data)
263 ms ± 27.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
From 0.7 s to roughly 0.26 s, about a 2.7x speed increase!
Meanshift clustering simply involves iteratively moving each point to a weighted average of the points around it until the points converge.
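Putting the pieces together, here is a compact NumPy sketch of the batched version with the triangular kernel. It runs on two hypothetical, well-separated blobs of synthetic data, not on the dataset used above. Like the `update` function above, it overwrites `X` in place, so later batches already see points that have moved during the current iteration.

```python
import numpy as np

def tri_kernel(d, bw):
    """Triangular kernel: weight 1 at distance 0, falling linearly to 0 at distance bw."""
    return np.clip(bw - d, 0, None) / bw

def meanshift(data, bw=8.0, n_iter=5, bs=8):
    """Batched meanshift; returns a shifted copy of `data` (shape (n, 2))."""
    X = data.copy()
    n = len(X)
    for _ in range(n_iter):
        for i in range(0, n, bs):
            s = slice(i, min(i + bs, n))
            # Distances from each point in the batch to every point: shape (batch, n)
            dists = np.sqrt(((X[s][:, None, :] - X[None, :, :]) ** 2).sum(axis=-1))
            ws = tri_kernel(dists, bw)
            # Weighted average via matrix multiplication; each point's weight for
            # itself is 1, so the denominator is never zero
            X[s] = (ws @ X) / ws.sum(axis=1, keepdims=True)
    return X

# Two hypothetical, well-separated blobs of 50 points each
rng = np.random.default_rng(42)
centers = np.array([[0.0, 0.0], [30.0, 30.0]])
data = np.concatenate([c + rng.normal(scale=1.0, size=(50, 2)) for c in centers])

out = meanshift(data)
print(out[:50].std(axis=0), out[50:].std(axis=0))  # spread within each blob after 5 iterations
```

The blobs are about 42 units apart, far beyond the bandwidth of 8. Each blob therefore collapses toward its own center without being pulled by the other.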
This post covers Einstein summation notation syntax as used in programming languages.
Einstein summation notation (or einsum notation for short) is a handy way to write various matrix operations in a succinct, universal manner. With it, you can probably forget all the various symbols and operators there are and stick to one common syntax that, once understood, can be more intuitive.
For example, matrix multiplication can be written as ik, kj -> ij and the transpose of a matrix can be written as ij -> ji.
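As a concrete illustration, NumPy's `np.einsum` accepts exactly these subscript strings, and PyTorch's `torch.einsum` accepts the same ones. The arrays below are arbitrary examples.

```python
import numpy as np

A = np.arange(6).reshape(2, 3)    # 2x3
B = np.arange(12).reshape(3, 4)   # 3x4

matmul = np.einsum('ik,kj->ij', A, B)   # matrix multiplication
transpose = np.einsum('ij->ji', A)      # transpose

print(np.array_equal(matmul, A @ B))    # True
print(np.array_equal(transpose, A.T))   # True
```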
Let’s figure this out.
The following are two general rules one can use to quickly write einsum notation.
Repeating letters between input arrays means that values along those axes will be multiplied together.
Omitting a letter from the output means that values along that axis will be summed.
However, I don’t find these rules intuitive; I even find them a little confusing. Why?
Matrices have the order of row by column. A 2x3 matrix has 2 rows and 3 columns. When we perform matrix multiplication, we take the dot product of each row in the first matrix with each column in the second matrix.
However, when the einsum rules above, specifically the first rule, are used to denote matrix multiplication (ik, kj -> ij, as depicted below), the order of a matrix appears to change.
In order for the einsum rules and the definition of matrix multiplication above to stay consistent, now denotes the rows in the first matrix and the columns in the second matrix, thereby changing the order of a matrix to column by row.
But even if we let denote the columns in the first matrix, we end up taking dot products between each column in the first matrix and each row in the second matrix.
Not intuitive.
The key to understanding einsum notation is to think not of axes, but of iterators. For example, $i$ is an iterator that returns the rows of a matrix, and $j$ is an iterator that returns the columns of a matrix.
Let’s begin with a simpler example: the Hadamard product (also known as the elementwise product or elementwise multiplication).
We have the following two matrices.
To access the element 8 in matrix , we need to return the second row and first column^{1}. This can be denoted as . The first digit in the subscript refers to the row and the second refers to the column. We can refer to any entry generally as .
^{1} This assumes the matrix is zero-indexed, meaning is the zeroth row of .
Taking the Hadamard product looks like this.
In words, we loop through all the rows of and . For each row, we also loop through each column and multiply the corresponding elements together.
Let’s focus on that last line above.
This line represents elementwise multiplication. For each row in and , we iterate through each column in those rows, and take their product.
In einsum notation, we can more succinctly write this as . This has 4 parts.
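Here is the same idea as a runnable sketch: explicit loops over the row iterator i and the column iterator j, checked against `np.einsum('ij,ij->ij', ...)`. The matrices are hypothetical examples, not the ones pictured above.

```python
import numpy as np

A = np.array([[1, 2], [3, 4]])     # hypothetical example matrices
B = np.array([[5, 6], [7, 8]])

# Explicit loops: i walks the rows, j walks the columns.
out = np.zeros_like(A)
for i in range(A.shape[0]):
    for j in range(A.shape[1]):
        out[i, j] = A[i, j] * B[i, j]

# Same computation with einsum: i and j both appear in the output, so nothing is summed.
print(np.array_equal(out, np.einsum('ij,ij->ij', A, B)))  # True
print(np.array_equal(out, A * B))                          # True
```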
Let’s cover matrix multiplication in the same manner as above.
Matrix multiplication simply involves taking the dot product of each row in the first matrix with each column in the second matrix.
We’ll need to use 3 iterators for this: one iterator to loop through the rows of , another iterator to loop through the columns of , and a third iterator to loop through the elements in a row and column.
Let’s focus in on the last line above.
This can more succinctly be written in einsum notation as . For each row in , and for each column in , iterate through each element , take their product, and sum those products. The location of the output of the dot product in the output matrix is .
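The three iterators map directly onto three nested loops. In the sketch below (arbitrary example matrices), k is the only iterator absent from the output, so the products are accumulated over it.

```python
import numpy as np

A = np.arange(6).reshape(2, 3)
B = np.arange(12).reshape(3, 4)

out = np.zeros((A.shape[0], B.shape[1]), dtype=A.dtype)
for i in range(A.shape[0]):          # rows of A
    for j in range(B.shape[1]):      # columns of B
        for k in range(A.shape[1]):  # elements of row i / column j
            out[i, j] += A[i, k] * B[k, j]   # k is absent from the output, so it is summed

print(np.array_equal(out, np.einsum('ik,kj->ij', A, B)))  # True
```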
For each row , output the row.
Iterate through each row , and sum all rows.
A scalar is output, hence no output iterator.
For each row in and , multiply them together.
For each row in and , multiply them together, and sum the products.
A scalar is output, hence no output iterator.
For each row in , multiply it with each row in .
The outer product involves multiplying each element in with all elements in .
For each row , iterate through each column and output it.
For each row , iterate through each column and output it in at row and column .
For each row , iterate through each column and output it.
For each row , iterate through each column and sum them.
A scalar is output, hence no output iterator.
For each row , iterate through each column and sum them.
For each column , iterate through each row and sum them.
For each row in and , iterate through each column , and take their product.
For each row in , and for each row in , iterate through each column in and each column in , and take their product.
For each row in , and for each column in , iterate through each element , take their product, and then sum those products.
For each row in , and for each row in , iterate through each column and take their product.
A three dimensional tensor is output, hence the three output iterators.
For each row in , iterate through each column and multiply it with each row in by iterating through each column in that row .
A four dimensional tensor is output, hence the four output iterators.
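To tie descriptions like the ones above to actual subscript strings, here is a NumPy sketch that checks several common einsum patterns against their non-einsum equivalents. The pairing of strings to operations is our own illustration, not a reconstruction of the exact notation in each item above.

```python
import numpy as np

v = np.array([1, 2, 3])
w = np.array([4, 5, 6])
M = np.arange(6).reshape(2, 3)
N = np.arange(6, 12).reshape(2, 3)

checks = {
    'i->':         (np.einsum('i->', v), v.sum()),                 # sum of a vector
    'i,i->i':      (np.einsum('i,i->i', v, w), v * w),             # elementwise product
    'i,i->':       (np.einsum('i,i->', v, w), v @ w),              # dot product
    'i,j->ij':     (np.einsum('i,j->ij', v, w), np.outer(v, w)),   # outer product
    'ij->ji':      (np.einsum('ij->ji', M), M.T),                  # transpose
    'ij->':        (np.einsum('ij->', M), M.sum()),                # sum of all entries
    'ij->i':       (np.einsum('ij->i', M), M.sum(axis=1)),         # row sums
    'ij->j':       (np.einsum('ij->j', M), M.sum(axis=0)),         # column sums
    'ij,kj->ik':   (np.einsum('ij,kj->ik', M, N), M @ N.T),        # each row of M dotted with each row of N
    'ij,kj->ikj':  (np.einsum('ij,kj->ikj', M, N),
                    M[:, None, :] * N[None, :, :]),                # 3-D output, nothing summed
    'ij,kl->ijkl': (np.einsum('ij,kl->ijkl', M, N),
                    M[:, :, None, None] * N[None, None, :, :]),    # 4-D output, nothing summed
}

for spec, (got, expected) in checks.items():
    assert np.array_equal(got, expected), spec
print('all match')
```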
And that’s that! The key is to think in terms of iterators that return locations in a matrix.
It may help to implement the operations above yourself with pencil and paper, and in a programming language too.
This notebook follows the fastai style guide.
Well, my implementation was a partial success: I managed to generate a mask, but failed to apply it. If you don’t understand, hold on as I’ll explain DiffEdit.
In this notebook, I try to implement the DiffEdit paper: a diffusion algorithm that allows us to replace the subject of an image with another subject, simply through a text prompt.
This notebook does not closely follow the paper implementation of DiffEdit. However, it does capture the underlying mechanisms.
In a nutshell, this is done by generating a mask from the text prompt. This mask cuts out the subject from the image, which allows a new subject to be added to the image.
While I was successful in generating a mask, I wasn’t successful in applying it to an image. So at the end of this notebook, I’ll use the Hugging Face Stable Diffusion Inpaint Pipeline to see the mask in action.
If you would like a refresher on how Stable Diffusion can be implemented from its various components, you can read my post on this here.
Let’s say we have an image of a horse in front of a forest. We want to replace the horse with a zebra. At a high level, DiffEdit achieves this in the following manner.
^{1} In this case, normalizing means scaling the values to be between 0 and 1.
^{2} Binarizing means restricting values to one of 2 possible values; in this case, either 0 or 1.
! pip install -Uqq fastcore transformers diffusers
import logging; logging.disable(logging.WARNING)
from fastcore.all import *
from fastai.imports import *
from fastai.vision.all import *
from transformers import CLIPTokenizer, CLIPTextModel
tokz = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=torch.float16)
txt_enc = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=torch.float16).to('cuda')
from diffusers import AutoencoderKL, UNet2DConditionModel
vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-ema', torch_dtype=torch.float16).to('cuda')
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float16).to("cuda")
from diffusers import LMSDiscreteScheduler
sched = LMSDiscreteScheduler(
beta_start = 0.00085,
beta_end = 0.012,
beta_schedule = 'scaled_linear',
num_train_timesteps = 1000
)
In this simple loop, I’m making sure I can correctly generate an image based on another image as the starting point.
prompt = ['earth']
neg_prompt = ['']
w, h = 512, 512
n_inf_steps = 50
g_scale = 8
bs = 1
seed = 77
txt_inp = tokz(
prompt,
padding = 'max_length',
max_length = tokz.model_max_length,
truncation = True,
return_tensors = 'pt',
)
txt_emb = txt_enc(txt_inp['input_ids'].to('cuda'))[0].half()
neg_inp = tokz(
[''] * bs,
padding = 'max_length',
max_length = txt_inp['input_ids'].shape[-1],
return_tensors = 'pt'
)
neg_emb = txt_enc(neg_inp['input_ids'].to('cuda'))[0].half()
embs = torch.cat([neg_emb, txt_emb])
!curl --output planet.png 'https://images.unsplash.com/photo-1630839437035-dac17da580d0?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=2515&q=80'
img = Image.open('/content/planet.png').resize((512, 512)); img
import torchvision.transforms as T
with torch.no_grad():
    img = T.ToTensor()(img).unsqueeze(0).half().to('cuda') * 2 - 1
    lat = vae.encode(img)
    lat = 0.18215 * lat.latent_dist.sample(); lat.shape
Below we can see all 4 channels of the compressed image (the latent).
fig, axs = plt.subplots(1, 4, figsize=(16, 4))
for c in range(4):
    axs[c].imshow(lat[0][c].cpu(), cmap='Greys')
sched = LMSDiscreteScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule='scaled_linear',
num_train_timesteps=1000
); sched
LMSDiscreteScheduler {
"_class_name": "LMSDiscreteScheduler",
"_diffusers_version": "0.16.1",
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"num_train_timesteps": 1000,
"prediction_type": "epsilon",
"trained_betas": null
}
sched.set_timesteps(n_inf_steps)
torch.manual_seed(seed)
noise = torch.randn_like(lat)
sched.timesteps = sched.timesteps.to(torch.float32)
start_step = 10
ts = tensor([sched.timesteps[start_step]])
lat = sched.add_noise(lat, noise, timesteps=ts)
from tqdm.auto import tqdm
for i, ts in enumerate(tqdm(sched.timesteps)):
    if i >= start_step:
        inp = torch.cat([lat] * 2)
        inp = sched.scale_model_input(inp, ts)
        with torch.no_grad(): preds = unet(inp, ts, encoder_hidden_states=embs)['sample']
        pred_neg, pred_txt = preds.chunk(2)
        pred = pred_neg + g_scale * (pred_txt - pred_neg)
        lat = sched.step(pred, ts, lat).prev_sample
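The line `pred = pred_neg + g_scale * (pred_txt - pred_neg)` is classifier-free guidance. The symbols below are our own shorthand:

- $\epsilon_\varnothing$ is the noise predicted with the empty (negative) prompt.
- $\epsilon_c$ is the noise predicted with the text prompt.
- $s$ is `g_scale`.

With these, the line computes:

$$\hat{\epsilon} = \epsilon_\varnothing + s\,\left(\epsilon_c - \epsilon_\varnothing\right)$$

With $s = 1$ this reduces to the text-conditioned prediction. Larger values, such as the $s = 8$ used here, push the prediction further in the direction the prompt suggests.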
lat.shape
torch.Size([1, 4, 64, 64])
lat *= (1/0.18215)
with torch.no_grad(): img = vae.decode(lat).sample
img = (img / 2 + 0.5).clamp(0, 1)
img = img[0].detach().cpu().permute(1, 2, 0).numpy()
img = (img * 255).round().astype('uint8')
Image.fromarray(img)
I’ll encapsulate the code above so we can focus on DiffEdit.
def get_embs(prompt, neg_prompt):
    txt_inp = tok_seq(prompt)
    txt_emb = calc_emb(txt_inp['input_ids'])
    neg_inp = tok_seq(neg_prompt)
    neg_emb = calc_emb(neg_inp['input_ids'])
    return torch.cat([neg_emb, txt_emb])
def tok_seq(prompt):
    return tokz(
        prompt,
        padding = 'max_length',
        max_length = tokz.model_max_length,
        truncation = True,
        return_tensors = 'pt',
    )
def calc_emb(inp_ids):
    return txt_enc(inp_ids.to('cuda'))[0].half()
def get_lat(img, start_step=30):
    return noise_lat(compress_img(img), start_step)
def compress_img(img):
    with torch.no_grad():
        img = T.ToTensor()(img).unsqueeze(0).half().to('cuda') * 2 - 1
        lat = vae.encode(img)
        return 0.18215 * lat.latent_dist.sample()
def noise_lat(lat, start_step):
    torch.manual_seed(seed)
    noise = torch.randn_like(lat)
    sched.set_timesteps(n_inf_steps)
    sched.timesteps = sched.timesteps.to(torch.float32)
    ts = tensor([sched.timesteps[start_step]])
    return sched.add_noise(lat, noise, timesteps=ts)
def denoise(lat, ts):
    inp = torch.cat([lat] * 2)
    inp = sched.scale_model_input(inp, ts)
    with torch.no_grad(): preds = unet(inp, ts, encoder_hidden_states=embs)['sample']
    pred_neg, pred_txt = preds.chunk(2)
    pred = pred_neg + g_scale * (pred_txt - pred_neg)
    return sched.step(pred, ts, lat).prev_sample
def decompress(lat):
    with torch.no_grad(): img = vae.decode(lat*(1/0.18215)).sample
    img = (img / 2 + 0.5).clamp(0, 1)
    img = img[0].detach().cpu().permute(1, 2, 0).numpy()
    return (img * 255).round().astype('uint8')
prompt = ['basketball']
neg_prompt = ['']
w, h = 512, 512
n_inf_steps = 70
start_step = 30
g_scale = 7.5
bs = 1
seed = 77
! curl --output img.png 'https://images.unsplash.com/photo-1630839437035-dac17da580d0?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=2515&q=80'
img = Image.open('/content/img.png').resize((512, 512))
embs = get_embs(prompt, neg_prompt)
lat = get_lat(img)
for i, ts in enumerate(tqdm(sched.timesteps)):
    if i >= start_step: lat = denoise(lat, ts)
img = decompress(lat)
Image.fromarray(img)
Let’s review the steps of DiffEdit once more.
^{3} In this case, normalizing means scaling the values to be between 0 and 1.
^{4} Binarizing means restricting values to one of 2 possible values; in this case, either 0 or 1.
First, we need to obtain an image of a horse and an image of a zebra.
We’ll use this as our original image.
! curl --output img.png 'https://images.unsplash.com/photo-1553284965-fa61e9ad4795?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1742&q=80'
Image.open('/content/img.png').resize((512, 512))
This is the generated image of the horse.
prompt = ['horse']
img = Image.open('/content/img.png').resize((512, 512))
embs = get_embs(prompt, neg_prompt)
lat1 = get_lat(img)
for i, ts in enumerate(tqdm(sched.timesteps)):
    if i >= start_step: lat1 = denoise(lat1, ts)
Image.fromarray(decompress(lat1))
And this is the generated image of the zebra.
prompt = ['zebra']
img = Image.open('/content/img.png').resize((512, 512))
embs = get_embs(prompt, neg_prompt)
lat2 = get_lat(img)
for i, ts in enumerate(tqdm(sched.timesteps)):
    if i >= start_step: lat2 = denoise(lat2, ts)
Image.fromarray(decompress(lat2))
lat1[:].shape
torch.Size([1, 4, 64, 64])
We’ll first convert the generated images to grayscale and then take their difference.
import torchvision.transforms.functional as F
img1 = F.to_tensor(F.to_grayscale(Image.fromarray(decompress(lat1[:]))))
img2 = F.to_tensor(F.to_grayscale(Image.fromarray(decompress(lat2[:]))))
diff = torch.abs(img1 - img2)
Then we’ll normalize the difference to have values between 0 and 1.
norm = diff / torch.max(diff)
Image.fromarray((norm*255).squeeze().numpy().round().astype(np.uint8))
And then finally binarize the values so they are either 0 or 1.
thresh = 0.5
bin_mask = (norm > thresh).float()  # named bin_mask to avoid shadowing the built-in bin()
Image.fromarray((bin_mask.squeeze().numpy()*255).astype(np.uint8))
Image.fromarray((bin_mask.squeeze().numpy()*255).astype(np.uint8)).save('mask.png')
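The difference, normalize, and binarize steps can be condensed into a single NumPy function. This is a minimal sketch. It assumes both images are already grayscale arrays of the same shape with values in [0, 1]. The name diff_mask is hypothetical. The guard against an all-zero difference is my own addition, since dividing by a zero maximum would otherwise produce NaNs.

```python
import numpy as np

def diff_mask(gray1, gray2, thresh=0.5):
    """Hypothetical helper: build a binary mask from two grayscale images.

    gray1, gray2: float arrays of the same shape with values in [0, 1].
    Returns an array containing only 0.0 and 1.0.
    """
    diff = np.abs(gray1 - gray2)      # per-pixel absolute difference
    peak = diff.max()
    if peak == 0:                     # identical images: nothing to mask
        return np.zeros_like(diff)
    norm = diff / peak                # normalize to [0, 1]
    return (norm > thresh).astype(diff.dtype)  # binarize to 0 or 1

a = np.array([[0.0, 0.2], [0.9, 1.0]])
b = np.array([[0.0, 0.1], [0.1, 0.2]])
print(diff_mask(a, b))  # only the bottom row differs strongly enough to be masked
```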
Now we need to apply transformations to the binarized mask so it encapsulates the shape of the horbra/zeborse (horse + zebra).
import cv2 as cv
from google.colab.patches import cv2_imshow
mask = cv.imread('mask.png', cv.IMREAD_GRAYSCALE)
kernel = cv.getStructuringElement(cv.MORPH_ELLIPSE, (10, 10))
The kernel is essentially a shape. It is applied across the image in order to perform the transformations.
I’ve chosen to use an ellipse of size 10 by 10 pixels.
Applying an erosion transformation makes our binarized mask look like this. Such transformations can remove small, noisy objects.
cv2_imshow(cv.erode(mask, kernel))
Applying a dilation transformation makes our binarized mask look like this. Such transformations can fill in gaps and smooth edges.
cv2_imshow(cv.dilate(mask, kernel))
To produce the final mask, I’ll apply the closing transform^{5} 7 times consecutively…
^{5} The closing transform is a dilation transform followed immediately by an erosion transform. This allows holes or small black points to be closed.
mask_closed = mask
for _ in range(7):
    mask_closed = cv.morphologyEx(mask_closed, cv.MORPH_CLOSE, kernel)
cv2_imshow(mask_closed)
…and then apply the dilation transform 3 times consecutively.
mask_dilated = mask_closed
for _ in range(3):
    mask_dilated = cv.dilate(mask_dilated, kernel)
cv2_imshow(mask_dilated)
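A more concise way of doing something similar is to pass the iterations argument. Note that for cv.MORPH_CLOSE, OpenCV performs all the dilations first and then all the erosions, rather than repeating a full closing 7 times. The closed mask can therefore differ slightly from the loop above. For cv.dilate, iterations=3 is the same as dilating 3 times in a row.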
mask_closed = cv.morphologyEx(mask, cv.MORPH_CLOSE, kernel, iterations=7)
mask_dilated = cv.dilate(mask_closed, kernel, iterations=3)
cv2_imshow(mask_dilated)
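To see why the closing transform used above fills holes, here is a minimal NumPy sketch of binary dilation and erosion. It makes several simplifying assumptions. It uses a square k x k kernel with odd k (3 here) rather than the 10 x 10 ellipse. Masks contain only 0s and 1s. The border handling (padding with 0 for dilation and 1 for erosion) is my own choice. It illustrates the idea rather than reimplementing cv.morphologyEx.

```python
import numpy as np

def dilate(mask, k=3):
    """Binary dilation: a pixel becomes 1 if ANY pixel under the k x k kernel is 1."""
    r = k // 2
    padded = np.pad(mask, r, mode='constant', constant_values=0)
    h, w = mask.shape
    out = np.zeros_like(mask)
    for dy in range(k):
        for dx in range(k):
            out = np.maximum(out, padded[dy:dy + h, dx:dx + w])
    return out

def erode(mask, k=3):
    """Binary erosion: a pixel stays 1 only if EVERY pixel under the kernel is 1.
    The border is padded with 1s so the image edge itself does not erode the mask."""
    r = k // 2
    padded = np.pad(mask, r, mode='constant', constant_values=1)
    h, w = mask.shape
    out = np.ones_like(mask)
    for dy in range(k):
        for dx in range(k):
            out = np.minimum(out, padded[dy:dy + h, dx:dx + w])
    return out

def close(mask, k=3):
    """Closing = dilation followed by erosion, which fills small holes."""
    return erode(dilate(mask, k), k)

# A 3 x 3 blob with a one-pixel hole in the middle
blob = np.zeros((7, 7), dtype=np.uint8)
blob[2:5, 2:5] = 1
blob[3, 3] = 0
print(close(blob)[2:5, 2:5])  # the hole is filled: all ones
```

Erosion on its own removes specks smaller than the kernel, which is why it cleans up noise.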
Then I’ll stack three copies of the mask together so I have a 3-channel image.
mask = np.stack((mask_dilated, mask_dilated, mask_dilated), axis=-1)/255; mask.shape
(512, 512, 3)
You can read more about the transformations applied above in the OpenCV docs here.
Now for the part I couldn’t figure out how to do.
Applying the mask to the original image gives us this cutout of the horse.
fore = torch.mul(F.to_tensor(img).permute(1, 2, 0), torch.from_numpy(mask))
Image.fromarray((fore*255).numpy().round().astype('uint8'))
You can see that the mask does not follow the outline exactly. This is good, because different subjects will protrude to different extents.
And this is what the background pixels look like.
inv_mask = 1 - mask
back = torch.mul(F.to_tensor(img).permute(1, 2, 0), torch.from_numpy(inv_mask))
Image.fromarray((back*255).numpy().round().astype('uint8'))
Adding both the foreground and the background together…
Image.fromarray(((fore+back)*255).numpy().round().astype(np.uint8))
Note the subtle, yet very important difference in the two cells below, along with their output.
x = tensor([1, 2, 3])
def foo(y):
    y += 1
    return y
foo(x)
x
tensor([2, 3, 4])
x = tensor([1, 2, 3])
def foo(y):
    z = y + 1
    return z
foo(x)
x
tensor([1, 2, 3])
This was the cause of the bug that had me pulling my hair out for hours. When you pass a list, a list-like object, or in fact any object to a function, Python does not pass a copy but a reference to the same object. An in-place operation such as y += 1 therefore modifies the caller’s tensor, whereas z = y + 1 creates a new one.
However, I can’t quite get the mask to apply correctly to the latent during denoising.
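The same behaviour can be reproduced with plain Python lists, which makes it easier to see what is going on. += mutates the object in place, + builds a new object, and passing an explicit copy protects the caller’s object. The function names are hypothetical. For PyTorch tensors, the equivalent of list.copy() is Tensor.clone().

```python
def bump_in_place(y):
    y += [4]      # list += extends the SAME list object the caller passed in
    return y

def bump_copy(y):
    z = y + [4]   # list + list builds a NEW list; y is left untouched
    return z

a = [1, 2, 3]
bump_in_place(a)
print(a)  # the caller's list now ends with 4

b = [1, 2, 3]
bump_copy(b)
print(b)  # unchanged

c = [1, 2, 3]
bump_in_place(c.copy())  # pass a copy to protect the original
print(c)  # unchanged
```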
prompt = ['zebra']
img = Image.open('/content/img.png').resize((512, 512))
embs = get_embs(prompt, neg_prompt)
lat = get_lat(img)
inv_mask = 1 - mask
for i, ts in enumerate(tqdm(sched.timesteps)):
    if i >= start_step:
        back = torch.mul(torch.from_numpy(decompress(get_lat(img, start_step=i)))/255, torch.from_numpy(inv_mask))
        fore = torch.mul(torch.from_numpy(decompress(lat))/255, torch.from_numpy(mask))
        bafo = (back + fore)*255
        lat = compress_img(Image.fromarray(bafo.numpy().round().astype(np.uint8)))
Image.fromarray(decompress(lat))
After asking on the fastai forum and hours of fiddling about, I believe the most likely cause is that I keep decompressing and recompressing the latent. The compression that the VAE performs is lossy, so detail is lost during each compression and decompression.
My mask is also not calculated in the same latent space as my latent. My mask was calculated as a 512x512 pixel, 3-channel image, whereas my latent is a 64x64, 4-channel tensor. I’m decompressing the latent so that I can apply the mask to cut out the zebra and add the background pixels, and then recompressing it.
To fix this, I would need to generate the mask at the latent’s resolution: a 64x64 mask that can be applied across all 4 latent channels.
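One way to build such a mask is to downsample the 512x512 pixel-space mask by the VAE’s factor of 8, then blend latents directly so nothing is decompressed. This is a minimal, untested-in-the-pipeline NumPy sketch, not a confirmed fix. The helper names mask_to_latent and blend_latents are hypothetical. Max-pooling is my own choice: any masked pixel marks its whole 8x8 latent cell. In the loop above, lat_back would presumably be the original image’s latent noised to the current timestep.

```python
import numpy as np

def mask_to_latent(mask_px, factor=8):
    """Hypothetical helper: downsample an (H, W, C) pixel-space mask with values
    in [0, 1] to an (H // factor, W // factor) latent-space mask via max-pooling."""
    m = mask_px[..., 0]                 # channels are identical copies; keep one
    h, w = m.shape
    blocks = m.reshape(h // factor, factor, w // factor, factor)
    return blocks.max(axis=(1, 3))      # a cell is masked if any pixel in it is

def blend_latents(lat_fore, lat_back, lat_mask):
    """Hypothetical helper: keep lat_fore where lat_mask is 1 and lat_back elsewhere.
    lat_fore, lat_back: (batch, 4, 64, 64); lat_mask: (64, 64), broadcast over
    the batch and channel dimensions."""
    return lat_mask * lat_fore + (1 - lat_mask) * lat_back

# Toy check with the same shapes as the notebook
mask_px = np.zeros((512, 512, 3))
mask_px[100, 200, :] = 1.0              # a single masked pixel
lat_mask = mask_to_latent(mask_px)
print(lat_mask.shape)                   # (64, 64)
```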
To at least see the mask in action, let’s use the Hugging Face Stable Diffusion inpainting pipeline.
The pipeline only needs a prompt, the starting image, and a mask; it handles the rest.
from diffusers import StableDiffusionInpaintPipeline
pipe = StableDiffusionInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting",
revision="fp16",
torch_dtype=torch.float16,
).to("cuda")
img
torch.manual_seed(77)
# 35 or 25 steps are good
out = pipe(
    prompt=["zebra"],
    image=img,
    mask_image=Image.fromarray((mask*255).round().astype(np.uint8)),
    num_inference_steps=25
).images
out[0]
Looking back, the actual problem for me was that I let the paper intimidate me: all those symbols, variables, jargon, and notation. I ended up glossing over the paper and missing the smaller details.
To help prevent this the next time, I should
And that’s that.
If you have any comments, questions, suggestions, feedback, criticisms, or corrections, please do post them down in the comment section below!
This notebook follows the fastai style guide.
These bits and bobs are explained in terms of my own explorations and experiments, so the explanations and descriptions below may not be entirely accurate.
iter
iter creates what’s known as an iterator, which is a type of iterable.
An iterable is anything that can be looped through (e.g., a list or a string).
iter essentially allows you to loop through an iterable without using a for loop. It gives you finer-grained control over when you loop, and how much you loop.
Docstring:
iter(iterable) -> iterator
iter(callable, sentinel) -> iterator
Get an iterator from an object. In the first form, the argument must
supply its own iterator, or be a sequence.
In the second form, the callable is called until it returns the sentinel.
Type: builtin_function_or_method
l = list(range(10)); l
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
it = iter(l); it
<list_iterator at 0x11e29a6e0>
next(it)
0
next(it)
1
next(it)
2
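To make the relationship between iter, next, and a for loop concrete, here is a sketch of roughly how a for loop consumes an iterator. It also shows the second form from the docstring above, iter(callable, sentinel). The names manual_for and buf are my own, and reading a file-like object in fixed-size chunks is just one illustrative use of the sentinel form.

```python
import io
from functools import partial

def manual_for(iterable):
    """Roughly what `for x in iterable: out.append(x)` does under the hood."""
    out = []
    it = iter(iterable)
    while True:
        try:
            x = next(it)        # ask the iterator for the next item
        except StopIteration:   # raised once the iterator is exhausted
            break
        out.append(x)
    return out

print(manual_for('abc'))

# Second form: iter(callable, sentinel) calls the callable until it returns the sentinel
buf = io.BytesIO(b'abcdefghij')
chunks = list(iter(partial(buf.read, 4), b''))
print(chunks)

# next() also accepts a default, returned instead of raising StopIteration
print(next(iter([]), 'done'))
```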
islice
islice is a type of iterator that returns a slice of an iterable’s items, such as the next n items at a time.
Init signature: islice(self, /, *args, **kwargs)
Docstring:
islice(iterable, stop) --> islice object
islice(iterable, start, stop[, step]) --> islice object
Return an iterator whose next() method returns selected values from an
iterable. If start is specified, will skip all preceding elements;
otherwise, start defaults to zero. Step defaults to one. If
specified as another value, step determines how many values are
skipped between successive calls. Works like a slice() on a list
but returns an iterator.
Type: type
Subclasses:
from itertools import islice
it = iter(l)
list(islice(it, 5))
[0, 1, 2, 3, 4]
list(islice(it, 5))
[5, 6, 7, 8, 9]
list(islice(it, 5))
[]
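The cells above use the islice(iterable, stop) form. Here is a short sketch of the start/stop/step form. It also shows a chunking idiom that pairs islice with the two-argument form of iter. The idiom is my own illustration rather than something from the docstring above.

```python
from itertools import islice

# islice(iterable, start, stop, step) works like l[start:stop:step], but lazily
evens = list(islice(range(10), 2, 8, 2))
print(evens)

# Chunking idiom: the lambda grabs the next 4 items from the shared iterator,
# and iteration stops when it returns [] (the sentinel)
it = iter(range(10))
chunks = list(iter(lambda: list(islice(it, 4)), []))
print(chunks)
```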
yield
yield is a substitute for return in a function or method. A function that uses yield is known as a generator function; calling it returns a generator.
yield essentially allows you to perform multiple returns, pausing the function after each one, and also allows you to treat a function as an iterator.
To demonstrate multiple returns, let’s create a function that chops a list up into smaller lists.
def chunks(l, step):
    for i in range(0, len(l), step): yield l[i:i+step]
list(chunks(l, 5))
[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
l_iter = chunks(l, 5); l_iter
<generator object chunks at 0x11e2a8cf0>
next(l_iter)
[0, 1, 2, 3, 4]
next(l_iter)
[5, 6, 7, 8, 9]
next(l_iter)
StopIteration:
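Because a generator only runs up to its next yield, it can even describe an infinite sequence, as long as the caller only asks for a finite number of items. This small sketch combines such a generator with islice from above. The name naturals is hypothetical. The sketch also shows the generator-expression shorthand, and how next’s default argument avoids the StopIteration shown above.

```python
from itertools import islice

def naturals():
    """Infinite generator: each next() resumes right after the previous yield."""
    n = 0
    while True:
        yield n
        n += 1

first_five = list(islice(naturals(), 5))
print(first_five)

gen = (x * x for x in range(3))   # generator expression: same laziness, no def needed
print(next(gen), next(gen), next(gen))
print(next(gen, None))            # returns None instead of raising StopIteration
```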
In this notebook, we’ll implement stable diffusion from its various components through the Hugging Face Diffusers library.
At the end, we’ll have our own custom stable diffusion class, from which we can generate images as simply as diffuser.diffuse().
If you would like a refresher, I’ve summarized at a high level how a diffuser is trained in this post. Though this notebook focuses on inference and not the training aspect, the linked summary may be helpful.
Let’s begin.
Before we get hands-on with the code, let’s refresh how inference works for a diffuser.
The main components in use are:
! pip install -Uqq fastcore transformers diffusers
import logging; logging.disable(logging.WARNING)
from fastcore.all import *
from fastai.imports import *
from fastai.vision.all import *
To process the prompt, we need to download a tokenizer and a text encoder. The tokenizer will split the prompt into tokens while the text encoder will convert the tokens into a numerical representation (an embedding).
from transformers import CLIPTokenizer, CLIPTextModel
tokz = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
txt_enc = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=torch.float16).to('cuda')
float16 is used for faster performance.
The U-Net will predict the noise in the image, while the VAE will decompress the generated image.
from diffusers import AutoencoderKL, UNet2DConditionModel
vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-ema', torch_dtype=torch.float16).to('cuda')
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float16).to("cuda")
The scheduler will control how much noise is initially added to the image, and will also control how much of the noise predicted by the U-Net will be subtracted from the image.
from diffusers import LMSDiscreteScheduler
sched = LMSDiscreteScheduler(
    beta_start = 0.00085,
    beta_end = 0.012,
    beta_schedule = 'scaled_linear',
    num_train_timesteps = 1000
); sched
LMSDiscreteScheduler {
"_class_name": "LMSDiscreteScheduler",
"_diffusers_version": "0.16.1",
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"num_train_timesteps": 1000,
"prediction_type": "epsilon",
"trained_betas": null
}
The six main parameters needed for generation are:
prompt = ['a photograph of an astronaut riding a horse']
w, h = 512, 512
n_inf_steps = 70
g_scale = 7.5
bs = 1
seed = 77
Now we need to parse the prompt. To do so, we’ll first tokenize it, and then encode the tokens to produce an embedding.
First, let’s tokenize.
txt_inp = tokz(
    prompt,
    padding = 'max_length',
    max_length = tokz.model_max_length,
    truncation = True,
    return_tensors = 'pt'
); txt_inp
{'input_ids': tensor([[49406, 320, 8853, 539, 550, 18376, 6765, 320, 4558, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0]])}
The token 49407 represents '<|endoftext|>'. It marks the end of the prompt, and the tokenizer also uses it as the padding token. The padding tokens have been given an attention mask of 0.
tokz.decode(49407)
'<|endoftext|>'
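To make the padding and attention mask concrete, here is a minimal pure-Python sketch that reproduces the structure of the output above. It is not how CLIPTokenizer is implemented. pad_and_mask is a hypothetical helper, and it assumes the prompt has already been split into token ids. The BOS/EOS ids are the ones visible in the output.

```python
BOS, EOS = 49406, 49407   # start and end token ids seen in the output above
PAD = EOS                 # the tokenizer pads with '<|endoftext|>' here

def pad_and_mask(token_ids, max_length=77):
    """Hypothetical helper: wrap token ids in BOS/EOS, truncate, then pad to max_length.
    Returns (input_ids, attention_mask) as plain lists."""
    ids = [BOS] + list(token_ids)[: max_length - 2] + [EOS]
    mask = [1] * len(ids)                            # real tokens are attended to
    n_pad = max_length - len(ids)
    return ids + [PAD] * n_pad, mask + [0] * n_pad   # padding gets a mask of 0

# Token ids for 'a photograph of an astronaut riding a horse', copied from the output above
ids, mask = pad_and_mask([320, 8853, 539, 550, 18376, 6765, 320, 4558])
print(len(ids), sum(mask))
```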
Now using the text encoder, we’ll create an embedding out of these tokens.
txt_emb = txt_enc(txt_inp['input_ids'].to('cuda'))[0].half(); txt_emb
tensor([[[-0.3884, 0.0229, -0.0523, ..., -0.4902, -0.3066, 0.0674],
[ 0.0292, -1.3242, 0.3076, ..., -0.5254, 0.9766, 0.6655],
[ 0.4609, 0.5610, 1.6689, ..., -1.9502, -1.2266, 0.0093],
...,
[-3.0410, -0.0674, -0.1777, ..., 0.3950, -0.0174, 0.7671],
[-3.0566, -0.1058, -0.1936, ..., 0.4258, -0.0184, 0.7588],
[-2.9844, -0.0850, -0.1726, ..., 0.4373, 0.0092, 0.7490]]],
device='cuda:0', dtype=torch.float16, grad_fn=<NativeLayerNormBackward0>)
txt_emb.shape
torch.Size([1, 77, 768])
We also need to create an embedding for an empty prompt, also known as the unconditional prompt. This embedding is what is used to control the guidance.
txt_inp['input_ids'].shape
torch.Size([1, 77])
max_len = txt_inp['input_ids'].shape[-1]
uncond_inp = tokz(
    [''] * bs,
    padding = 'max_length',
    max_length = max_len,
    return_tensors = 'pt',
); uncond_inp
{'input_ids': tensor([[49406, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407]]), 'attention_mask': tensor([[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0]])}
uncond_inp['input_ids'].shape
torch.Size([1, 77])
uncond_emb = txt_enc(uncond_inp['input_ids'].to('cuda'))[0].half()
uncond_emb.shape
torch.Size([1, 77, 768])
We can then concatenate the unconditional embedding and the text embedding together. This lets the U-Net make predictions for both prompts in a single batched pass, rather than running it once per prompt.
embs = torch.cat([uncond_emb, txt_emb])
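Concatenating the embeddings pays off in the denoising loop. The latent is duplicated, both halves go through the U-Net as one batch, and the two predictions are split apart and combined using the guidance scale. Below is a minimal NumPy sketch of that pattern. toy_unet is a hypothetical stand-in for the real U-Net. The combination pred_uncond + g_scale * (pred_txt - pred_uncond) is the standard classifier-free guidance formula, which the denoising loop begins to compute.

```python
import numpy as np

def toy_unet(x, emb):
    """Hypothetical stand-in for the U-Net: its 'noise prediction' just scales x by the embedding."""
    return x * emb

def guided_pred(lat, uncond_emb, txt_emb, g_scale=7.5):
    """Classifier-free guidance using one batched model call instead of two."""
    inp = np.concatenate([lat, lat])               # one copy of the latent per prompt
    embs = np.concatenate([uncond_emb, txt_emb])   # same order as torch.cat([uncond_emb, txt_emb])
    preds = toy_unet(inp, embs)                    # a single pass over a batch of 2
    pred_uncond, pred_txt = np.split(preds, 2)     # undo the concatenation, like preds.chunk(2)
    # Start from the unconditional prediction and push towards the text-conditioned one
    return pred_uncond + g_scale * (pred_txt - pred_uncond)

lat = np.ones((1, 2))
print(guided_pred(lat, np.array([[1.0, 1.0]]), np.array([[3.0, 3.0]])))
```

With g_scale=1 the result is just the text-conditioned prediction; larger values exaggerate the difference between the two.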
It’s now time to create our noisy image, which will be the starting point for generation.
We’ll create a single latent that is 64 by 64 pixels, and that also has 4 channels. After the latent is denoised, we’ll decompress it to a 512 by 512 pixel image with 3 channels.
bs, unet.config.in_channels, h//8, w//8
(1, 4, 64, 64)
print(torch.randn((2, 3, 4)))
print(torch.randn((2, 3, 4)).shape)
tensor([[[ 0.0800, -1.3597, -0.2033, -0.5647],
[-1.6066, 0.8178, 1.0832, 0.0638],
[ 0.3133, 1.8516, 0.4320, -0.9295]],
[[-1.0798, 3.2928, 0.7443, 1.2190],
[-0.4984, 0.3551, -0.6012, -0.5856],
[-0.3988, -1.2950, -1.6061, -0.0207]]])
torch.Size([2, 3, 4])
torch.manual_seed(seed)
lats = torch.randn((bs, unet.config.in_channels, h//8, w//8)); lats.shape
torch.Size([1, 4, 64, 64])
The latent is a rank 4 tensor. 1 refers to the batch size, which is the number of images being generated. 4 is the number of channels, and 64 is the number of pixels along both the height and the width.
lats = lats.to('cuda').half(); lats
tensor([[[[-0.5044, -0.4163, -0.1365, ..., -1.6104, 0.1381, 1.7676],
[ 0.7017, 1.5947, -1.4434, ..., -1.5859, -0.4089, -2.8164],
[ 1.0664, -0.0923, 0.3462, ..., -0.2390, -1.0947, 0.7554],
...,
[-1.0283, 0.2433, 0.3337, ..., 0.6641, 0.4219, 0.7065],
[ 0.4280, -1.5439, 0.1409, ..., 0.8989, -1.0049, 0.0482],
[-1.8682, 0.4988, 0.4668, ..., -0.5874, -0.4019, -0.2856]],
[[ 0.5688, -1.2715, -1.4980, ..., 0.2230, 1.4785, -0.6821],
[ 1.8418, -0.5117, 1.1934, ..., -0.7222, -0.7417, 1.0479],
[-0.6558, 0.1201, 1.4971, ..., 0.1454, 0.4714, 0.2441],
...,
[ 0.9492, 0.1953, -2.4141, ..., -0.5176, 1.1191, 0.5879],
[ 0.2129, 1.8643, -1.8506, ..., 0.8096, -1.5264, 0.3191],
[-0.3640, -0.9189, 0.8931, ..., -0.4944, 0.3916, -0.1406]],
[[-0.5259, 1.5059, -0.3413, ..., 1.2539, 0.3669, -0.1593],
[-0.2957, -0.1169, -2.0078, ..., 1.9268, 0.3833, -0.0992],
[ 0.5020, 1.0068, -0.9907, ..., -0.3008, 0.7324, -1.1963],
...,
[-0.7437, -1.1250, 0.1349, ..., -0.6714, -0.6753, -0.7920],
[ 0.5415, -0.5269, -1.0166, ..., 1.1270, -1.7637, -1.5156],
[-0.2319, 0.9165, 1.6318, ..., 0.6602, -1.2871, 1.7568]],
[[ 0.7100, 0.4133, 0.5513, ..., 0.0326, 0.9175, 1.4922],
[ 0.8862, 1.3760, 0.8599, ..., -2.1172, -1.6533, 0.8955],
[-0.7783, -0.0246, 1.4717, ..., 0.0328, 0.4316, -0.6416],
...,
[ 0.0855, -0.1279, -0.0319, ..., -0.2817, 1.2744, -0.5854],
[ 0.2402, 1.3945, -2.4062, ..., 0.3435, -0.5254, 1.2441],
[ 1.6377, 1.2539, 0.6099, ..., 1.5391, -0.6304, 0.9092]]]],
device='cuda:0', dtype=torch.float16)
Our latent has random values which represent noise. This noise needs to be scaled so it can work with the scheduler.
sched.set_timesteps(n_inf_steps); sched
LMSDiscreteScheduler {
"_class_name": "LMSDiscreteScheduler",
"_diffusers_version": "0.16.1",
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"num_train_timesteps": 1000,
"prediction_type": "epsilon",
"trained_betas": null
}
lats *= sched.init_noise_sigma; sched.init_noise_sigma
tensor(14.6146)
sched.sigmas
tensor([14.6146, 13.3974, 12.3033, 11.3184, 10.4301, 9.6279, 8.9020, 8.2443,
7.6472, 7.1044, 6.6102, 6.1594, 5.7477, 5.3709, 5.0258, 4.7090,
4.4178, 4.1497, 3.9026, 3.6744, 3.4634, 3.2680, 3.0867, 2.9183,
2.7616, 2.6157, 2.4794, 2.3521, 2.2330, 2.1213, 2.0165, 1.9180,
1.8252, 1.7378, 1.6552, 1.5771, 1.5031, 1.4330, 1.3664, 1.3030,
1.2427, 1.1852, 1.1302, 1.0776, 1.0272, 0.9788, 0.9324, 0.8876,
0.8445, 0.8029, 0.7626, 0.7236, 0.6858, 0.6490, 0.6131, 0.5781,
0.5438, 0.5102, 0.4770, 0.4443, 0.4118, 0.3795, 0.3470, 0.3141,
0.2805, 0.2455, 0.2084, 0.1672, 0.1174, 0.0292, 0.0000])
sched.timesteps
tensor([999.0000, 984.5217, 970.0435, 955.5652, 941.0870, 926.6087, 912.1304,
897.6522, 883.1739, 868.6957, 854.2174, 839.7391, 825.2609, 810.7826,
796.3043, 781.8261, 767.3478, 752.8696, 738.3913, 723.9130, 709.4348,
694.9565, 680.4783, 666.0000, 651.5217, 637.0435, 622.5652, 608.0870,
593.6087, 579.1304, 564.6522, 550.1739, 535.6957, 521.2174, 506.7391,
492.2609, 477.7826, 463.3043, 448.8261, 434.3478, 419.8696, 405.3913,
390.9130, 376.4348, 361.9565, 347.4783, 333.0000, 318.5217, 304.0435,
289.5652, 275.0870, 260.6087, 246.1304, 231.6522, 217.1739, 202.6957,
188.2174, 173.7391, 159.2609, 144.7826, 130.3043, 115.8261, 101.3478,
86.8696, 72.3913, 57.9130, 43.4348, 28.9565, 14.4783, 0.0000],
dtype=torch.float64)
plt.plot(sched.timesteps, sched.sigmas[:-1])
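The values in sched.timesteps and sched.sigmas can be reproduced by hand. This NumPy sketch shows how I understand LMSDiscreteScheduler (diffusers 0.16) derives them from beta_start, beta_end and the 'scaled_linear' schedule. It is an assumption about the library’s internals, checked only against the printed values above. The first sigma also matches the init_noise_sigma the latent was multiplied by.

```python
import numpy as np

# Scheduler settings taken from the LMSDiscreteScheduler call above
beta_start, beta_end, n_train = 0.00085, 0.012, 1000
n_inf_steps = 70

# 'scaled_linear': betas are evenly spaced in sqrt-space, then squared
betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, n_train) ** 2
# Cumulative product of (1 - beta): how much of the original signal survives to step t
alphas_cumprod = np.cumprod(1.0 - betas)
train_sigmas = np.sqrt((1 - alphas_cumprod) / alphas_cumprod)  # noise-to-signal ratio per step

# Inference timesteps: evenly spaced from 999 down to 0
timesteps = np.linspace(0, n_train - 1, n_inf_steps)[::-1]

# Interpolate the training sigmas at those (fractional) timesteps and append a final 0
sigmas = np.interp(timesteps, np.arange(n_train), train_sigmas)
sigmas = np.concatenate([sigmas, [0.0]])

print(timesteps[:2])          # compare with sched.timesteps
print(sigmas[0], sigmas[-2])  # compare with sched.sigmas and sched.init_noise_sigma
```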
The denoising process can now begin!
from tqdm.auto import tqdm
for i, ts in enumerate(tqdm(sched.timesteps)):
    inp = torch.cat([lats] * 2)
    inp = sched.scale_model_input(inp, ts)
    with torch.no_grad(): preds = unet(inp, ts, encoder_hidden_states=embs).sample
    pred_uncond, pred_txt = preds.chunk(2)
pred = pred_uncond + g_scale * (pred_txt -</