(Un)successfully Implementing DiffEdit

The (Un)expected Difficulties of Editing

Diffusion
Creating Models
Papers
An attempt at implementing the DiffEdit paper.
Author

Salman Naqvi

Published

Monday, 29 May 2023

This notebook follows the fastai style guide.

Well, my implementation was a partial success: I managed to generate a mask, but failed to apply it. If you don’t understand, hold on as I’ll explain DiffEdit.

In this notebook, I try to implement the DiffEdit paper: a diffusion algorithm that allows us to replace the subject of an image with another subject, simply through a text prompt.

This notebook does not closely follow the paper implementation of DiffEdit. However, it does capture the underlying mechanisms.

In a nutshell, this is done by generating a mask from the text prompt. This mask cuts out the subject from the image, which allows a new subject to be added to the image.

While I was successful in generating a mask, I wasn’t successful in applying it to an image. So at the end of this notebook, I’ll use the Hugging Face Stable Diffusion Inpaint Pipeline to see the mask in action.

If you would like a refresher on how Stable Diffusion can be implemented from its various components, you can read my post on this here.

Basic Workings

Let’s say we have an image of a horse in front of a forest. We want to replace the horse with a zebra. At a high level, DiffEdit achieves this in the following manner.

  1. Using our image, we generate a further image with the prompt ‘horse’.
  2. We similarly generate another further image with the prompt ‘zebra’.
  3. The difference between both generated images is then taken.
  4. The difference is normalized [1] and binarized [2] to obtain the mask.
  5. We again generate an image with the prompt ‘zebra’.
    • However, this time, after each denoising step, we apply the mask to the latent to obtain a cutout of the zebra.
    • We then add the noised background pixels of the original image to the cutout.

[1] In this case, normalizing means scaling the values to be between 0 and 1.

[2] Binarizing means restricting each value to one of two possible values. In this case, either 0 or 1.
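The steps above can be sketched on toy data. This is a minimal NumPy illustration, not the notebook's code: the arrays horse, zebra and original are made-up 1-D stand-ins for images, and the 0.5 threshold is an assumption.

```python
import numpy as np

# Toy 1-D "images": stand-ins for the two generations (assumed values).
horse = np.array([0.2, 0.2, 0.9, 0.8, 0.2])
zebra = np.array([0.2, 0.3, 0.1, 0.2, 0.2])

diff = np.abs(horse - zebra)        # step 3: absolute difference
norm = diff / diff.max()            # step 4: scale values into [0, 1]
mask = (norm > 0.5).astype(float)   # step 4: binarize (assumed threshold of 0.5)

# Step 5, simplified: keep the new subject inside the mask and the
# original (background) values everywhere else.
original = np.array([0.5, 0.5, 0.5, 0.5, 0.5])
edited = mask * zebra + (1 - mask) * original

print(mask)
print(edited)
```

Only the two positions where the generations differ strongly end up in the mask, so only those positions take their values from the zebra array.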

Setup

! pip install -Uqq fastcore transformers diffusers
import logging; logging.disable(logging.WARNING)  # Hugging Face can be verbose.
from fastcore.all import *
from fastai.imports import *
from fastai.vision.all import *

Get Components

from transformers import CLIPTokenizer, CLIPTextModel

tokz = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')  # a tokenizer has no weights, so no dtype is needed
txt_enc = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=torch.float16).to('cuda')
from diffusers import AutoencoderKL, UNet2DConditionModel

vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-ema', torch_dtype=torch.float16).to('cuda')
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float16).to("cuda")
from diffusers import LMSDiscreteScheduler

sched = LMSDiscreteScheduler(
    beta_start = 0.00085,
    beta_end = 0.012,
    beta_schedule = 'scaled_linear',
    num_train_timesteps = 1000
)

Simple Loop

In this simple loop, I’m making sure I can correctly generate an image based on another image as the starting point.

Hyperparameters

prompt = ['earth']
neg_prompt = ['']
w, h = 512, 512
n_inf_steps = 50
g_scale = 8
bs = 1
seed = 77

Encode Prompt

txt_inp = tokz(
    prompt,
    padding = 'max_length',
    max_length = tokz.model_max_length,
    truncation = True,
    return_tensors = 'pt',
)
txt_emb = txt_enc(txt_inp['input_ids'].to('cuda'))[0].half()
neg_inp = tokz(
    [''] * bs,
    padding = 'max_length',
    max_length = txt_inp['input_ids'].shape[-1],
    return_tensors = 'pt'
)
neg_emb = txt_enc(neg_inp['input_ids'].to('cuda'))[0].half()
embs = torch.cat([neg_emb, txt_emb])

Compress Image

!curl --output planet.png 'https://images.unsplash.com/photo-1630839437035-dac17da580d0?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=2515&q=80'
img = Image.open('/content/planet.png').resize((512, 512)); img

import torchvision.transforms as T
with torch.no_grad():
  img = T.ToTensor()(img).unsqueeze(0).half().to('cuda') * 2 - 1
  lat = vae.encode(img)
  lat = 0.18215 * lat.latent_dist.sample(); lat.shape

Below we can see all 4 channels of the compressed image.

fig, axs = plt.subplots(1, 4, figsize=(16, 4))
for c in range(4):
  axs[c].imshow(lat[0][c].cpu(), cmap='Greys')

Noise Image

sched = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule='scaled_linear',
    num_train_timesteps=1000
); sched
LMSDiscreteScheduler {
  "_class_name": "LMSDiscreteScheduler",
  "_diffusers_version": "0.16.1",
  "beta_end": 0.012,
  "beta_schedule": "scaled_linear",
  "beta_start": 0.00085,
  "num_train_timesteps": 1000,
  "prediction_type": "epsilon",
  "trained_betas": null
}
sched.set_timesteps(n_inf_steps)
torch.manual_seed(seed)
noise = torch.randn_like(lat)
sched.timesteps = sched.timesteps.to(torch.float32)
start_step = 10
ts = tensor([sched.timesteps[start_step]])
lat = sched.add_noise(lat, noise, timesteps=ts)

Denoise

from tqdm.auto import tqdm

for i, ts in enumerate(tqdm(sched.timesteps)):
  if i >= start_step:
    inp = torch.cat([lat] * 2)
    inp = sched.scale_model_input(inp, ts)

    with torch.no_grad(): preds = unet(inp, ts, encoder_hidden_states=embs)['sample']

    pred_neg, pred_txt = preds.chunk(2)
    pred = pred_neg + g_scale * (pred_txt - pred_neg)

    lat = sched.step(pred, ts, lat).prev_sample
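The line that combines pred_neg and pred_txt in the loop above is the usual classifier-free guidance formula. Here is a small sketch of what it does, using made-up two-element arrays in place of the UNet's noise predictions; the function name guide is hypothetical.

```python
import numpy as np

def guide(pred_neg, pred_txt, g_scale):
    # Start from the unconditional prediction and move toward (and, for
    # g_scale > 1, past) the text-conditioned prediction.
    return pred_neg + g_scale * (pred_txt - pred_neg)

# Made-up stand-ins for the two halves of the UNet output.
pred_neg = np.array([0.0, 1.0])
pred_txt = np.array([1.0, 3.0])

same = guide(pred_neg, pred_txt, 1.0)    # g_scale = 1 returns pred_txt unchanged
pushed = guide(pred_neg, pred_txt, 8.0)  # g_scale = 8, as in the hyperparameters above

print(same, pushed)
```

A larger g_scale exaggerates the direction in which the prompt changes the prediction, which is why it controls how strongly the output follows the prompt.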

Uncompress

lat.shape
torch.Size([1, 4, 64, 64])
lat *= (1/0.18215)
with torch.no_grad(): img = vae.decode(lat).sample
img = (img / 2 + 0.5).clamp(0, 1)
img = img[0].detach().cpu().permute(1, 2, 0).numpy()
img = (img * 255).round().astype('uint8')
Image.fromarray(img)

Encapsulate

I’ll encapsulate the code above so we can focus on DiffEdit.

def get_embs(prompt, neg_prompt):
  txt_inp = tok_seq(prompt)
  txt_emb = calc_emb(txt_inp['input_ids'])

  neg_inp = tok_seq(neg_prompt)
  neg_emb = calc_emb(neg_inp['input_ids'])

  return torch.cat([neg_emb, txt_emb])

def tok_seq(prompt):
  return tokz(
      prompt,
      padding = 'max_length',
      max_length = tokz.model_max_length,
      truncation = True,
      return_tensors = 'pt',
  )

def calc_emb(inp_ids):
  return txt_enc(inp_ids.to('cuda'))[0].half()
def get_lat(img, start_step=30):
  return noise_lat(compress_img(img), start_step)

def compress_img(img):
  with torch.no_grad():
    img = T.ToTensor()(img).unsqueeze(0).half().to('cuda') * 2 - 1
    lat = vae.encode(img)
    return 0.18215 * lat.latent_dist.sample()

def noise_lat(lat, start_step):
  torch.manual_seed(seed)
  noise = torch.randn_like(lat)

  sched.set_timesteps(n_inf_steps)
  sched.timesteps = sched.timesteps.to(torch.float32)
  ts = tensor([sched.timesteps[start_step]])

  return sched.add_noise(lat, noise, timesteps=ts)
def denoise(lat, ts):
  inp = torch.cat([lat] * 2)
  inp = sched.scale_model_input(inp, ts)

  with torch.no_grad(): preds = unet(inp, ts, encoder_hidden_states=embs)['sample']

  pred_neg, pred_txt = preds.chunk(2)
  pred = pred_neg + g_scale * (pred_txt - pred_neg)

  return sched.step(pred, ts, lat).prev_sample
def decompress(lat):
  with torch.no_grad(): img = vae.decode(lat*(1/0.18215)).sample
  img = (img / 2 + 0.5).clamp(0, 1)
  img = img[0].detach().cpu().permute(1, 2, 0).numpy()
  return (img * 255).round().astype('uint8')
prompt = ['basketball']
neg_prompt = ['']
w, h = 512, 512
n_inf_steps = 70
start_step = 30
g_scale = 7.5
bs = 1
seed = 77

! curl --output img.png 'https://images.unsplash.com/photo-1630839437035-dac17da580d0?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=2515&q=80'
img = Image.open('/content/img.png').resize((512, 512))

embs = get_embs(prompt, neg_prompt)
lat = get_lat(img)
for i, ts in enumerate(tqdm(sched.timesteps)):
  if i >= start_step: lat = denoise(lat, ts)
img = decompress(lat)
Image.fromarray(img)

DiffEdit

Let’s review the steps of DiffEdit once more.

  1. Using our image, we generate a further image with the prompt ‘horse’.
  2. We similarly generate another further image with the prompt ‘zebra’.
  3. The difference between both generated images is then taken.
  4. The difference is normalized [3] and binarized [4] to obtain the mask.
  5. We then again generate an image with the prompt ‘zebra’.
    • However, this time, after each denoising step, we apply the mask to the latent to obtain a cutout of the zebra.
    • We then add the noised background pixels of the original image to the cutout.

[3] In this case, normalizing means scaling the values to be between 0 and 1.

[4] Binarizing means restricting each value to one of two possible values. In this case, either 0 or 1.

Obtain two latents

First, we need to obtain an image of a horse and an image of a zebra.

We’ll use this as our original image.

! curl --output img.png 'https://images.unsplash.com/photo-1553284965-fa61e9ad4795?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1742&q=80'
Image.open('/content/img.png').resize((512, 512))

This is the generated image of the horse.

prompt = ['horse']
img = Image.open('/content/img.png').resize((512, 512))
embs = get_embs(prompt, neg_prompt)
lat1 = get_lat(img)
for i, ts in enumerate(tqdm(sched.timesteps)):
  if i >= start_step: lat1 = denoise(lat1, ts)
Image.fromarray(decompress(lat1))

And this is the generated image of the zebra.

prompt = ['zebra']
img = Image.open('/content/img.png').resize((512, 512))
embs = get_embs(prompt, neg_prompt)
lat2 = get_lat(img)
for i, ts in enumerate(tqdm(sched.timesteps)):
  if i >= start_step: lat2 = denoise(lat2, ts)
Image.fromarray(decompress(lat2))

lat1[:].shape
torch.Size([1, 4, 64, 64])

Create Mask

We’ll first convert the generated images to grayscale and then take their difference.

import torchvision.transforms.functional as F

img1 = F.to_tensor(F.to_grayscale(Image.fromarray(decompress(lat1[:]))))
img2 = F.to_tensor(F.to_grayscale(Image.fromarray(decompress(lat2[:]))))
diff = torch.abs(img1 - img2)

Then we’ll normalize the difference to have values between 0 and 1.

norm = diff / torch.max(diff)
Image.fromarray((norm*255).squeeze().numpy().round().astype(np.uint8))

And then finally binarize the values so they are either 0 or 1.

thresh = 0.5
bin = (norm > thresh).float()
Image.fromarray((bin.squeeze().numpy()*255).astype(np.uint8))

Image.fromarray((bin.squeeze().numpy()*255).astype(np.uint8)).save('mask.png')

Now we need to apply transformations to the binarized mask so it encapsulates the shape of the horbra/zeborse (horse + zebra 🫤).

import cv2 as cv
from google.colab.patches import cv2_imshow

mask = cv.imread('mask.png', cv.IMREAD_GRAYSCALE)
kernel = cv.getStructuringElement(cv.MORPH_ELLIPSE, (10, 10))

The kernel is essentially a shape. It is slid across the image, and the transformation is computed from the pixels it covers at each position.

I’ve chosen to use an ellipse of size 10 by 10 units.

Applying an erosion transformation makes our binarized mask look like this. Such transformations can remove small, noisy objects.

cv2_imshow(cv.erode(mask, kernel))

Applying a dilation transformation makes our binarized mask look like this. Such transformations can fill in gaps and smooth edges.

cv2_imshow(cv.dilate(mask, kernel))
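To make erosion and dilation concrete, here is a minimal NumPy sketch on a tiny 5 by 5 mask. It is not how OpenCV implements these operations: the functions dilate and erode are hypothetical helpers, and they assume a 3 by 3 square kernel rather than the 10 by 10 ellipse used in this notebook.

```python
import numpy as np

def dilate(img):
    # A pixel becomes 1 if any pixel in its 3x3 neighbourhood is 1.
    h, w = img.shape
    p = np.pad(img, 1, constant_values=0)
    windows = [p[dy:dy + h, dx:dx + w] for dy in range(3) for dx in range(3)]
    return np.max(windows, axis=0)

def erode(img):
    # A pixel stays 1 only if every pixel in its 3x3 neighbourhood is 1.
    # Padding with 1 means the image border itself does not erode anything.
    h, w = img.shape
    p = np.pad(img, 1, constant_values=1)
    windows = [p[dy:dy + h, dx:dx + w] for dy in range(3) for dx in range(3)]
    return np.min(windows, axis=0)

# A white mask with a one-pixel hole in the middle.
holed = np.ones((5, 5), dtype=np.uint8)
holed[2, 2] = 0
closed = erode(dilate(holed))      # dilation followed by erosion fills the hole

# A black mask with a single white speck of noise.
speck = np.zeros((5, 5), dtype=np.uint8)
speck[2, 2] = 1
cleaned = erode(speck)             # erosion removes the speck

print(closed)
print(cleaned)
```

Dilating first and then eroding is the closing transform used for the final mask: it fills holes without permanently growing the mask.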

To produce the final mask, I’ll apply the closing transform [5] 7 times consecutively…

[5] The closing transform is a dilation transform followed immediately by an erosion transform. This allows holes or small black points to be closed.

mask_closed = mask
for _ in range(7):
  mask_closed = cv.morphologyEx(mask_closed, cv.MORPH_CLOSE, kernel)
  cv2_imshow(mask_closed)

…and then apply the dilation transform 3 times consecutively.

mask_dilated = mask_closed
for _ in range(3):
  mask_dilated = cv.dilate(mask_dilated, kernel)
  cv2_imshow(mask_dilated)

A more concise way of doing the above.

mask_closed = cv.morphologyEx(mask, cv.MORPH_CLOSE, kernel, iterations=7)
mask_dilated = cv.dilate(mask_closed, kernel, iterations=3)

cv2_imshow(mask_dilated)

Then I’ll stack the mask together so I have a 3 channel image.

mask = np.stack((mask_dilated, mask_dilated, mask_dilated), axis=-1)/255; mask.shape
(512, 512, 3)

To read more about the transformations applied above, see the OpenCV docs here.

Apply Mask

Now for the part I couldn’t figure out how to do.

Applying the mask to the original image gives the cutout of the horse, which looks like this.

fore = torch.mul(F.to_tensor(img).permute(1, 2, 0), torch.from_numpy(mask))
Image.fromarray((fore*255).numpy().round().astype('uint8'))

You can see that it does not exactly cut out the outline: this is good because different subjects will have different levels of protrusion.

And this is what the background pixels look like.

inv_mask = 1 - mask
back = torch.mul(F.to_tensor(img).permute(1, 2, 0), torch.from_numpy(inv_mask))
Image.fromarray((back*255).numpy().round().astype('uint8'))

Adding both the foreground and the background together…

Image.fromarray(((fore+back)*255).numpy().round().astype(np.uint8))


Detour

Note the subtle, yet very important difference in the two cells below, along with their output.

x = tensor([1, 2, 3])
def foo(y):
  y += 1
  return y
foo(x)
x
tensor([2, 3, 4])
x = tensor([1, 2, 3])
def foo(y):
  z = y + 1
  return z
foo(x)
x
tensor([1, 2, 3])

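This was the reason for the bug that had me pulling my hair out for hours: when you pass a tensor, a list, or any other mutable object to a function, Python passes a reference to the same object rather than a copy. An in-place operation such as += inside the function therefore changes the caller’s object too.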
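One way to guard against this is to copy before mutating. The sketch below uses NumPy arrays so it runs without a GPU; the function names are hypothetical. For PyTorch tensors, the equivalent of copy() is clone().

```python
import numpy as np

def add_one_inplace(y):
    y += 1          # mutates the array the caller passed in
    return y

def add_one_copy(y):
    y = y + 1       # builds a new array; the caller's array is untouched
    return y

x = np.array([1, 2, 3])
out = add_one_inplace(x)
shared = out is x               # the very same object came back

x2 = np.array([1, 2, 3])
out2 = add_one_copy(x2)

x3 = np.array([1, 2, 3])
add_one_inplace(x3.copy())      # a defensive copy keeps x3 intact

print(x, x2, x3, shared)
```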


However, I can’t quite correctly apply the mask to the latent when denoising.

prompt = ['zebra']
img = Image.open('/content/img.png').resize((512, 512))
embs = get_embs(prompt, neg_prompt)
lat = get_lat(img)
inv_mask = 1 - mask


for i, ts in enumerate(tqdm(sched.timesteps)):
  if i >= start_step: 
    back = torch.mul(torch.from_numpy(decompress(get_lat(img, start_step=i)))/255, torch.from_numpy(inv_mask))
    fore = torch.mul(torch.from_numpy(decompress(lat))/255, torch.from_numpy(mask))
    bafo = (back + fore)*255
    lat = compress_img(Image.fromarray(bafo.numpy().round().astype(np.uint8)))
Image.fromarray(decompress(lat))

After asking on the fastai forum, and hours of fiddling about, I think the most likely cause is that I keep uncompressing and recompressing the latent. The compression that the VAE performs is lossy, so detail is lost during each compression and decompression.

My mask is not calculated in the same latent space as my latent. In other words, my mask was calculated as a 512x512 pixel, 3 channel image, whereas my latent is a 64x64 pixel, 4 channel image. I’m uncompressing the latent so that I can apply the mask to cut out the zebra and add the background pixels, and then recompressing.

To fix this, I would need to generate the mask at the latent’s resolution, 64x64, so it can be applied to the 4 channel latent directly, without leaving latent space.
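A rough sketch of what keeping everything in latent space could look like is below. It is untested against the real models: the latents are random NumPy arrays standing in for the zebra latent and the noised original latent, the rectangular pixel mask is made up, and downsampling by taking the maximum over 8 by 8 blocks is just one assumed choice.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins: a 512x512 binary pixel mask and two 4x64x64 latents.
pixel_mask = np.zeros((512, 512), dtype=np.float32)
pixel_mask[128:384, 192:320] = 1.0
lat_zebra = rng.standard_normal((4, 64, 64)).astype(np.float32)
lat_orig_noised = rng.standard_normal((4, 64, 64)).astype(np.float32)

# Downsample 512 -> 64 by taking the max over each 8x8 block, so any
# masked pixel in a block marks the whole latent cell as masked.
lat_mask = pixel_mask.reshape(64, 8, 64, 8).max(axis=(1, 3))

# The single-channel mask broadcasts across all 4 latent channels.
blended = lat_mask[None] * lat_zebra + (1 - lat_mask[None]) * lat_orig_noised

print(lat_mask.shape, blended.shape)
```

In the denoising loop, a blend like this would replace the decompress and compress_img round trip, so the VAE would only be used once at the start and once at the end.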

To at least see the mask in action, let’s use the Hugging Face Stable Diffusion Inpaint Pipeline.

Pipeline

The Hugging Face Stable Diffusion Inpaint Pipeline only needs a prompt, the starting image, and a mask. The pipeline handles the rest.

from diffusers import StableDiffusionInpaintPipeline
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")
img

torch.manual_seed(77)
# 35 or 25 steps are good
out = pipe(
    prompt=["zebra"], 
    image=img, 
    mask_image=Image.fromarray((mask*255).round().astype(np.uint8)), 
    num_inference_steps = 25
).images
out[0]

Takeaways

Looking back, the actual problem for me was that I let the paper feel intimidating: all those symbols, variables, jargon, and notation. I ended up skimming the paper and missing the smaller details.

To help prevent this the next time, I should

  • list out the variables and what they represent
  • write out the steps in simpler terms
  • and take a deep breath before reading, so I take things slowly.

And that’s that.

If you have any comments, questions, suggestions, feedback, criticisms, or corrections, please do post them down in the comment section below!
