! pip install -Uqq fastcore transformers diffusers
(Un)successfully Implementing DiffEdit
The (Un)expected Difficulties of Editing
This notebook follows the fastai style guide.
In this notebook, I try to implement the DiffEdit paper: a diffusion algorithm that allows us to replace the subject of an image with another subject, simply through a text prompt.
This notebook does not closely follow the paper's implementation of DiffEdit. However, it does capture the underlying mechanisms.
In a nutshell, this is done by generating a mask from the text prompt. The mask cuts the subject out of the image, which allows a new subject to be added in its place.
My implementation was a partial success: I managed to generate a mask, but failed to apply it to an image. So at the end of this notebook, I'll use the Hugging Face Stable Diffusion Inpaint Pipeline to see the mask in action.
If you would like a refresher on how Stable Diffusion can be implemented from its various components, you can read my post on this here.
Basic Workings
Let’s say we have an image of a horse in front of a forest. We want to replace the horse with a zebra. At a high level, DiffEdit achieves this in the following manner.
- Starting from our image, we generate a new image with the prompt ‘horse’.
- We similarly generate another new image with the prompt ‘zebra’.
- The difference between the two generated images is then taken.
- The difference is normalized1 and binarized2 to obtain the mask.
- We again generate an image with the prompt ‘zebra’.
- However, this time, after each denoising step, we apply the mask to the latent to obtain a cutout of the zebra.
- We then add the noised background pixels of the original image to the cutout.
1 In this case, normalizing means scaling the values to lie between 0 and 1.
2 Binarizing means restricting the values to one of 2 possibilities, in this case either 0 or 1.
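As a tiny example of those two footnote operations (using made-up numbers), this is all that normalizing and binarizing amount to:

import torch

diff = torch.tensor([[0.2, 1.6], [0.0, 0.8]])  # a made-up 'difference' map
norm = diff / diff.max()                       # normalize: values now lie in [0, 1]
mask = (norm > 0.5).float()                    # binarize: values are now exactly 0 or 1
print(norm); print(mask)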
Setup
import logging; logging.disable(logging.WARNING)  # Hugging Face can be verbose.
from fastcore.all import *
from fastai.imports import *
from fastai.vision.all import *
Get Components
from transformers import CLIPTokenizer, CLIPTextModel

tokz = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=torch.float16)
txt_enc = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=torch.float16).to('cuda')

from diffusers import AutoencoderKL, UNet2DConditionModel

vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-ema', torch_dtype=torch.float16).to('cuda')
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float16).to("cuda")
from diffusers import LMSDiscreteScheduler
= LMSDiscreteScheduler(
sched = 0.00085,
beta_start = 0.012,
beta_end = 'scaled_linear',
beta_schedule = 1000
num_train_timesteps )
Simple Loop
In this simple loop, I’m making sure I can correctly generate an image based on another image as the starting point.
Hyperparameters
prompt = ['earth']
neg_prompt = ['']
w, h = 512, 512
n_inf_steps = 50
g_scale = 8
bs = 1
seed = 77
Encode Prompt
txt_inp = tokz(
    prompt,
    padding = 'max_length',
    max_length = tokz.model_max_length,
    truncation = True,
    return_tensors = 'pt',
)
txt_emb = txt_enc(txt_inp['input_ids'].to('cuda'))[0].half()
neg_inp = tokz(
    [''] * bs,
    padding = 'max_length',
    max_length = txt_inp['input_ids'].shape[-1],
    return_tensors = 'pt'
)
neg_emb = txt_enc(neg_inp['input_ids'].to('cuda'))[0].half()
embs = torch.cat([neg_emb, txt_emb])
Compress Image
!curl --output planet.png 'https://images.unsplash.com/photo-1630839437035-dac17da580d0?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=2515&q=80'
img = Image.open('/content/planet.png').resize((512, 512)); img
import torchvision.transforms as T
with torch.no_grad():
    img = T.ToTensor()(img).unsqueeze(0).half().to('cuda') * 2 - 1
    lat = vae.encode(img)
lat = 0.18215 * lat.latent_dist.sample(); lat.shape
Below we can see all 4 channels of the compressed image.
fig, axs = plt.subplots(1, 4, figsize=(16, 4))
for c in range(4):
    axs[c].imshow(lat[0][c].cpu(), cmap='Greys')
Noise Image
sched = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule='scaled_linear',
    num_train_timesteps=1000
); sched
LMSDiscreteScheduler {
"_class_name": "LMSDiscreteScheduler",
"_diffusers_version": "0.16.1",
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"num_train_timesteps": 1000,
"prediction_type": "epsilon",
"trained_betas": null
}
sched.set_timesteps(n_inf_steps)

torch.manual_seed(seed)
noise = torch.randn_like(lat)
sched.timesteps = sched.timesteps.to(torch.float32)
start_step = 10
ts = tensor([sched.timesteps[start_step]])
lat = sched.add_noise(lat, noise, timesteps=ts)
Denoise
from tqdm.auto import tqdm
for i, ts in enumerate(tqdm(sched.timesteps)):
    if i >= start_step:
        inp = torch.cat([lat] * 2)
        inp = sched.scale_model_input(inp, ts)

        with torch.no_grad(): preds = unet(inp, ts, encoder_hidden_states=embs)['sample']

        pred_neg, pred_txt = preds.chunk(2)
        pred = pred_neg + g_scale * (pred_txt - pred_neg)

        lat = sched.step(pred, ts, lat).prev_sample
Uncompress
lat.shape
torch.Size([1, 4, 64, 64])
lat *= (1/0.18215)
with torch.no_grad(): img = vae.decode(lat).sample

img = (img / 2 + 0.5).clamp(0, 1)
img = img[0].detach().cpu().permute(1, 2, 0).numpy()
img = (img * 255).round().astype('uint8')
Image.fromarray(img)
Encapsulate
I’ll encapsulate the code above so we can focus on DiffEdit.
def get_embs(prompt, neg_prompt):
    txt_inp = tok_seq(prompt)
    txt_emb = calc_emb(txt_inp['input_ids'])

    neg_inp = tok_seq(neg_prompt)
    neg_emb = calc_emb(neg_inp['input_ids'])

    return torch.cat([neg_emb, txt_emb])

def tok_seq(prompt):
    return tokz(
        prompt,
        padding = 'max_length',
        max_length = tokz.model_max_length,
        truncation = True,
        return_tensors = 'pt',
    )

def calc_emb(inp_ids):
    return txt_enc(inp_ids.to('cuda'))[0].half()
def get_lat(img, start_step=30):
    return noise_lat(compress_img(img), start_step)

def compress_img(img):
    with torch.no_grad():
        img = T.ToTensor()(img).unsqueeze(0).half().to('cuda') * 2 - 1
        lat = vae.encode(img)
    return 0.18215 * lat.latent_dist.sample()

def noise_lat(lat, start_step):
    torch.manual_seed(seed)
    noise = torch.randn_like(lat)

    sched.set_timesteps(n_inf_steps)
    sched.timesteps = sched.timesteps.to(torch.float32)
    ts = tensor([sched.timesteps[start_step]])

    return sched.add_noise(lat, noise, timesteps=ts)
def denoise(lat, ts):
    inp = torch.cat([lat] * 2)
    inp = sched.scale_model_input(inp, ts)

    with torch.no_grad(): preds = unet(inp, ts, encoder_hidden_states=embs)['sample']

    pred_neg, pred_txt = preds.chunk(2)
    pred = pred_neg + g_scale * (pred_txt - pred_neg)

    return sched.step(pred, ts, lat).prev_sample

def decompress(lat):
    with torch.no_grad(): img = vae.decode(lat*(1/0.18215)).sample
    img = (img / 2 + 0.5).clamp(0, 1)
    img = img[0].detach().cpu().permute(1, 2, 0).numpy()
    return (img * 255).round().astype('uint8')
prompt = ['basketball']
neg_prompt = ['']
w, h = 512, 512
n_inf_steps = 70
start_step = 30
g_scale = 7.5
bs = 1
seed = 77
! curl --output img.png 'https://images.unsplash.com/photo-1630839437035-dac17da580d0?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=2515&q=80'
img = Image.open('/content/img.png').resize((512, 512))

embs = get_embs(prompt, neg_prompt)
lat = get_lat(img)
for i, ts in enumerate(tqdm(sched.timesteps)):
    if i >= start_step: lat = denoise(lat, ts)

img = decompress(lat)
Image.fromarray(img)
DiffEdit
Let’s review the steps of DiffEdit once more.
- Starting from our image, we generate a new image with the prompt ‘horse’.
- We similarly generate another new image with the prompt ‘zebra’.
- The difference between the two generated images is then taken.
- The difference is normalized3 and binarized4 to obtain the mask.
- We then again generate an image with the prompt ‘zebra’.
- However, this time, after each denoising step, we apply the mask to the latent to obtain a cutout of the zebra.
- We then add the noised background pixels of the original image to the cutout.
3 In this case, normalizing means scaling the values to lie between 0 and 1.
4 Binarizing means restricting the values to one of 2 possibilities, in this case either 0 or 1.
Obtain two latents
First, we need to obtain an image of a horse and an image of a zebra.
We’ll use this as our original image.
! curl --output img.png 'https://images.unsplash.com/photo-1553284965-fa61e9ad4795?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1742&q=80'
Image.open('/content/img.png').resize((512, 512))
This is the generated image of the horse.
prompt = ['horse']
img = Image.open('/content/img.png').resize((512, 512))
embs = get_embs(prompt, neg_prompt)
lat1 = get_lat(img)
for i, ts in enumerate(tqdm(sched.timesteps)):
    if i >= start_step: lat1 = denoise(lat1, ts)

Image.fromarray(decompress(lat1))
And this is the generated image of the zebra.
prompt = ['zebra']
img = Image.open('/content/img.png').resize((512, 512))
embs = get_embs(prompt, neg_prompt)
lat2 = get_lat(img)
for i, ts in enumerate(tqdm(sched.timesteps)):
    if i >= start_step: lat2 = denoise(lat2, ts)

Image.fromarray(decompress(lat2))
lat1[:].shape
torch.Size([1, 4, 64, 64])
Create Mask
We’ll first convert the generated images to grayscale and then take their difference.
import torchvision.transforms.functional as F

img1 = F.to_tensor(F.to_grayscale(Image.fromarray(decompress(lat1[:]))))
img2 = F.to_tensor(F.to_grayscale(Image.fromarray(decompress(lat2[:]))))
diff = torch.abs(img1 - img2)
Then we’ll normalize the difference to have values between 0 and 1.
norm = diff / torch.max(diff)
Image.fromarray((norm*255).squeeze().numpy().round().astype(np.uint8))
And then finally binarize the values so they are either 0 or 1.
thresh = 0.5
bin = (norm > thresh).float()
Image.fromarray((bin.squeeze().numpy()*255).astype(np.uint8))
Image.fromarray((bin.squeeze().numpy()*255).astype(np.uint8)).save('mask.png')
Now we need to apply transformations to the binarized mask so it encapsulates the shape of the horbra/zeborse (horse + zebra 🫤).
import cv2 as cv
from google.colab.patches import cv2_imshow

mask = cv.imread('mask.png', cv.IMREAD_GRAYSCALE)
kernel = cv.getStructuringElement(cv.MORPH_ELLIPSE, (10, 10))
The kernel is essentially a shape. It is applied across the image in order to perform the morphological transformations.
I’ve chosen to use an ellipse of size 10 by 10 pixels.
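To get a feel for the kernel, we can print it: it is just a small array of 0s and 1s whose 1s trace out the ellipse.

print(kernel.shape)  # (10, 10)
print(kernel)        # 0/1 values tracing an ellipse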
Applying an erosion transformation makes our binarized mask look like this. Such transformations can remove small, noisy objects.
cv2_imshow(cv.erode(mask, kernel))
Applying a dilation transformation makes our binarized mask look like this. Such transformations can fill in gaps and smooth edges.
cv2_imshow(cv.dilate(mask, kernel))
To produce the final mask, I’ll apply the closing transform5 7 times consecutively…
5 The closing transform is a dilation transform followed immediately by an erosion transform. This allows holes or small black points to be closed.
mask_closed = mask
for _ in range(7):
    mask_closed = cv.morphologyEx(mask_closed, cv.MORPH_CLOSE, kernel)
cv2_imshow(mask_closed)
…and then apply the dilation transform 3 times consecutively.
mask_dilated = mask_closed
for _ in range(3):
    mask_dilated = cv.dilate(mask_dilated, kernel)
cv2_imshow(mask_dilated)
A more concise way of doing the above.
mask_closed = cv.morphologyEx(mask, cv.MORPH_CLOSE, kernel, iterations=7)
mask_dilated = cv.dilate(mask_closed, kernel, iterations=3)

cv2_imshow(mask_dilated)
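As a quick sanity check of footnote 5, a single closing should give exactly the same result as a dilation followed immediately by an erosion:

# Should print True: closing == dilate then erode with the same kernel.
print(np.array_equal(cv.morphologyEx(mask, cv.MORPH_CLOSE, kernel),
                     cv.erode(cv.dilate(mask, kernel), kernel)))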
Then I’ll stack the mask together so I have a 3 channel image.
mask = np.stack((mask_dilated, mask_dilated, mask_dilated), axis=-1)/255; mask.shape
(512, 512, 3)
To read more about the transformations applied above, see the OpenCV docs here.
Apply Mask
Now for the part I couldn’t figure out how to do.
Applying the mask to the original image, this is what the cutout of the horse looks like.
fore = torch.mul(F.to_tensor(img).permute(1, 2, 0), torch.from_numpy(mask))
Image.fromarray((fore*255).numpy().round().astype('uint8'))
You can see that the mask does not hug the horse’s outline exactly: this is good, because different subjects will protrude by different amounts.
And this is what the background pixels look like.
inv_mask = 1 - mask
back = torch.mul(F.to_tensor(img).permute(1, 2, 0), torch.from_numpy(inv_mask))
Image.fromarray((back*255).numpy().round().astype('uint8'))
Adding both the foreground and the background together…
Image.fromarray(((fore+back)*255).numpy().round().astype(np.uint8))
Detour
Note the subtle, yet very important difference in the two cells below, along with their output.
x = tensor([1, 2, 3])

def foo(y):
    y += 1
    return y

foo(x)
x
tensor([2, 3, 4])
x = tensor([1, 2, 3])

def foo(y):
    z = y + 1
    return z

foo(x)
x
tensor([1, 2, 3])
This was the reason for the bug that had me pulling my hair out for hours: Python passes objects such as tensors by reference rather than by copy, so an in-place operation like += inside a function modifies the caller's object too.
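If that behaviour is not what you want, one way to guard against it is to copy the argument before mutating it. A minimal sketch using a tensor's clone method:

x = tensor([1, 2, 3])

def foo(y):
    y = y.clone()  # work on a copy so the caller's tensor is left untouched
    y += 1
    return y

foo(x)
x  # still tensor([1, 2, 3])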
However, I can’t quite correctly apply the mask to the latent when denoising.
prompt = ['zebra']
img = Image.open('/content/img.png').resize((512, 512))
embs = get_embs(prompt, neg_prompt)
lat = get_lat(img)
inv_mask = 1 - mask

for i, ts in enumerate(tqdm(sched.timesteps)):
    if i >= start_step:
        back = torch.mul(torch.from_numpy(decompress(get_lat(img, start_step=i)))/255, torch.from_numpy(inv_mask))
        fore = torch.mul(torch.from_numpy(decompress(lat))/255, torch.from_numpy(mask))
        bafo = (back + fore)*255
        lat = compress_img(Image.fromarray(bafo.numpy().round().astype(np.uint8)))
Image.fromarray(decompress(lat))
After asking on the fastai forum and hours of fiddling about, I believe the reason this happens is that I keep uncompressing and recompressing the latent. The compression that the VAE performs is lossy, so detail is lost with every round trip through it.
My mask is not calculated in the same space as my latent: the mask is a 512x512, 3-channel image, whereas the latent is a 64x64, 4-channel tensor. To apply the mask to cut out the zebra and add back the background pixels, I have to uncompress the latent and then recompress it afterwards.
To fix this, I would need to produce the mask at 64x64 resolution, in the latent's own space, so it can be applied directly to the latent.
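For what it's worth, here is a rough, untested sketch of how that could look, reusing the helpers defined above (get_embs, get_lat, denoise, compress_img, decompress) and the mask_dilated array; the downsampling and the blending line are my own guess at a fix, not something from the paper or verified here:

import torch.nn.functional as nnF  # aliased to avoid clashing with torchvision's F above

# Downsample the 512x512 mask to the 64x64 latent grid; it broadcasts over the 4 latent channels.
mask_lat = nnF.interpolate(
    torch.from_numpy(mask_dilated/255.).float()[None, None],  # (1, 1, 512, 512)
    size=(64, 64), mode='nearest'
).to('cuda').half()

embs = get_embs(['zebra'], neg_prompt)
lat = get_lat(img)
orig_lat = compress_img(img)
for i, ts in enumerate(tqdm(sched.timesteps)):
    if i >= start_step:
        lat = denoise(lat, ts)
        noised_orig = sched.add_noise(orig_lat, torch.randn_like(orig_lat), timesteps=tensor([ts]))
        lat = mask_lat * lat + (1 - mask_lat) * noised_orig  # generated subject inside the mask, original background outside

Image.fromarray(decompress(lat))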
To at least see the mask in action, let's use the Hugging Face Stable Diffusion Inpaint Pipeline.
Pipeline
With the Hugging Face Stable Diffusion Inpaint pipeline, we simply provide the starting image and a mask; the pipeline handles the rest.
from diffusers import StableDiffusionInpaintPipeline
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")
img
torch.manual_seed(77)
# 35 or 25 steps are good
out = pipe(
    prompt=["zebra"],
    image=img,
    mask_image=Image.fromarray((mask*255).round().astype(np.uint8)),
    num_inference_steps = 25
).images
out[0]
Takeaways
Looking back, the actual problem for me was that I let the paper feel intimidating; all those symbols, variables, jargon, and notation. I ended up skimming the paper and missing the smaller details.
To help prevent this the next time, I should
- list out the variables and what they represent
- write out the steps in simpler terms
- and take a deep breath before reading, so I take things slowly.
And that’s that.
If you have any comments, questions, suggestions, feedback, criticisms, or corrections, please do post them down in the comment section below!