Heusel, M., Ramsauer, H., Unterthiner, T., Nessler, B., & Hochreiter, S. (2017). GANs trained by a two time-scale update rule converge to a local Nash equilibrium. In Advances in Neural Information Processing Systems (pp. 6626–6637).

The data

path = download_file_from_google_drive(URLs.CELEBA, 'celebA.zip', folder_name='img_align_celeba')
files = get_image_files(path.parent / 'img_align_celeba')
PILImage.create(files[0])

Fréchet Inception Distance (FID)

The original implementation of the FID used the TensorFlow weights of Inception-v3 from 2015-12-05, as noted by hukkelas.

In this implementation, we use the pytorch/vision:v0.6.0 weights by default, which produce different outputs. The absolute FID values therefore differ, but the trends and the order of magnitude are the same. If you want, you can set weights='old' to use the original weights translated from TensorFlow to PyTorch by mseitzer.
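For reference, the FID compares Gaussian fits of Inception activations (Heusel et al., 2017): with means mu_1, mu_2 and covariances Sigma_1, Sigma_2 of the two feature sets, FID = ||mu_1 - mu_2||^2 + Tr(Sigma_1 + Sigma_2 - 2 (Sigma_1 Sigma_2)^(1/2)). The NumPy sketch below computes that formula, assuming the features have already been extracted as (n_samples, n_features) arrays. It is not this library's implementation, and the names frechet_distance and fid_from_features are hypothetical. Instead of a general matrix square root, it takes the trace term from the eigenvalues of Sigma_1^(1/2) Sigma_2 Sigma_1^(1/2), which has the same eigenvalues as Sigma_1 Sigma_2.

```python
import numpy as np


def _psd_sqrt(mat):
    """Symmetric square root of a positive semi-definite matrix via eigh."""
    w, v = np.linalg.eigh(mat)
    w = np.clip(w, 0.0, None)  # drop tiny negative eigenvalues caused by round-off
    return (v * np.sqrt(w)) @ v.T


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    s1 = _psd_sqrt(sigma1)
    # s1 @ sigma2 @ s1 shares its eigenvalues with sigma1 @ sigma2, so the sum of
    # their square roots equals Tr((sigma1 @ sigma2)^(1/2)).
    middle = s1 @ sigma2 @ s1
    middle = (middle + middle.T) / 2  # symmetrize against round-off
    tr_covmean = np.sqrt(np.clip(np.linalg.eigvalsh(middle), 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


def fid_from_features(feats1, feats2):
    """feats1, feats2: arrays of shape (n_samples, n_features)."""
    mu1, mu2 = feats1.mean(axis=0), feats2.mean(axis=0)
    sigma1 = np.cov(feats1, rowvar=False)
    sigma2 = np.cov(feats2, rowvar=False)
    return frechet_distance(mu1, sigma1, mu2, sigma2)


rng = np.random.default_rng(0)
a = rng.normal(size=(2000, 8))
print(fid_from_features(a, a))        # close to 0.0: identical feature sets
print(fid_from_features(a, a + 1.0))  # close to 8.0: mean shifted by 1 in each of 8 dims
```

When only the means differ, the trace term cancels and the FID reduces to the squared distance between the means, which is why the shifted copy scores the number of feature dimensions.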

class Inception

Inception(weights='new', renormalize=False)

class FIDMetric

FIDMetric(model, dl, get_prediction=noop) :: Metric

Blueprint for defining a metric
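In the experiments on this page, the metric is reset, features are concatenated batch by batch into fid.total, and fid.value is read at the end. As an alternative design for the same reset/accumulate/value pattern, the sketch below keeps running sums instead of storing every feature vector, so memory stays at O(d^2) no matter how many images are seen. RunningGaussianStats is a hypothetical helper, not part of fastai or of this library; its mean and cov could be passed to a Fréchet distance function.

```python
import numpy as np


class RunningGaussianStats:
    """Hypothetical helper: streaming mean and covariance of feature batches."""

    def __init__(self, dim):
        self.dim = dim
        self.reset()

    def reset(self):
        self.n = 0
        self.sum = np.zeros(self.dim)
        self.outer = np.zeros((self.dim, self.dim))

    def accumulate(self, feats):
        """feats: array of shape (batch_size, dim)."""
        feats = np.asarray(feats, dtype=np.float64)
        self.n += feats.shape[0]
        self.sum += feats.sum(axis=0)
        self.outer += feats.T @ feats  # running sum of x x^T

    @property
    def mean(self):
        return self.sum / self.n

    @property
    def cov(self):
        # sum(x x^T) - n mu mu^T = sum((x - mu)(x - mu)^T); dividing by n - 1
        # gives the unbiased estimate, matching np.cov(..., rowvar=False).
        mu = self.mean
        return (self.outer - self.n * np.outer(mu, mu)) / (self.n - 1)


rng = np.random.default_rng(1)
feats = rng.normal(size=(500, 4))
stats = RunningGaussianStats(dim=4)
for batch in np.array_split(feats, 10):  # e.g. one chunk per dataloader batch
    stats.accumulate(batch)
```

The sum-of-outer-products form is simple, but it can lose precision when feature means are large relative to their spread; Welford-style updates are the more numerically stable alternative.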

tfdl = TfmdLists(files[:1000], [PILImage.create, ToTensor])
dl = tfdl.dataloaders(after_batch=[Resize(64, method=ResizeMethod.Squish),
                                   IntToFloatTensor(),
                                   Normalize.from_stats(*imagenet_stats)])

inception = Inception(weights='old')
fid = FIDMetric(inception, dl.train)
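class GaussianNoise(DisplayedTransform):
    "Blend each image with Gaussian noise rescaled to [0, 1]; `alpha` is the weight of the noise."
    order = 11  # runs after IntToFloatTensor and before Normalize
    def __init__(self, alpha):
        self.alpha = alpha
    def encodes(self, x:(Image.Image, TensorImage)):
        N = torch.randn(x.shape, device=x.device)
        N -= N.min()  # min-max rescale the noise (over the whole batch) to [0, 1]
        N /= N.max()
        return (1-self.alpha)*x + self.alpha*N
    def decodes(self, x): return x  # the noise is kept when decoding for display
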
imgs = []
fids = []
noise_levels = [0, 0.25, 0.5, 0.75]
mb = master_bar(noise_levels)
for noise in mb:
    tfdl = TfmdLists(files[:10000], [PILImage.create, ToTensor])
    dl = tfdl.dataloaders(after_batch=[Resize(64, method=ResizeMethod.Squish),
                                       IntToFloatTensor(),
                                       GaussianNoise(noise),
                                       Normalize.from_stats(*imagenet_stats)],
                          shuffle_train=False)
    # Reset the metric and accumulate features for every batch of noisy images
    fid.reset()
    for b in progress_bar(dl.train, parent=mb):
        fid.total = torch.cat([fid.total, fid.func(b).cpu()])
    fid.count = len(dl.train)
    fids.append(fid.value)
    imgs.append(TensorImage(dl.decode(dl.one_batch())[0]))

plt.plot(noise_levels, fids)
plt.xlabel('Noise level (alpha)')
plt.ylabel('FID')
plt.grid(True)
ImageNTuple(imgs).show();
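The GaussianNoise transform computes (1 - alpha) * x + alpha * N, where N is Gaussian noise min-max rescaled to [0, 1]. The NumPy sketch below reproduces that arithmetic outside fastai so the endpoints are easy to check: alpha = 0 leaves the images untouched, alpha = 1 returns pure noise, and intermediate values stay in [0, 1] because the result is a convex combination of two [0, 1] arrays. The function name blend_noise is hypothetical; like the transform, it takes the min and max over the whole batch.

```python
import numpy as np


def blend_noise(x, alpha, rng=None):
    """Mix images x (values in [0, 1]) with Gaussian noise min-max rescaled to [0, 1]."""
    rng = np.random.default_rng() if rng is None else rng
    noise = rng.standard_normal(x.shape)
    noise -= noise.min()  # shift so the minimum is 0
    noise /= noise.max()  # scale so the maximum is 1 (min/max over the whole batch)
    return (1 - alpha) * x + alpha * noise


rng = np.random.default_rng(0)
batch = rng.random((2, 3, 8, 8))  # (batch, channels, height, width) in [0, 1)
noisy = blend_noise(batch, 0.5, rng)
```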
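class Blur(DisplayedTransform):
    "Box (mean) blur with an `alpha` x `alpha` kernel; sizes below 1 are clamped to 1 (no blur)."
    order = 11
    def __init__(self, alpha):
        self.alpha = max(alpha, 1)
        padding = (self.alpha-1)//2  # keeps the spatial size for odd kernel sizes
        # Depthwise conv (groups=3): each channel is averaged independently
        self.conv = torch.nn.Conv2d(3, 3, self.alpha, groups=3, bias=False, padding=padding)
        torch.nn.init.constant_(self.conv.weight, 1/self.alpha**2)
        self.conv.weight.requires_grad_(False)  # fixed filter, no autograd graph needed

    def encodes(self, x:(Image.Image, TensorImage)):
        # Move the kernel to the batch's device instead of hard-coding .cuda()
        return self.conv.to(x.device)(x)
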
imgs = []
fids = []
kernel_sizes = [0, 3, 5, 9]  # 0 is clamped to 1 by Blur, i.e. no blur
mb = master_bar(kernel_sizes)
for k in mb:
    tfdl = TfmdLists(files[:10000], [PILImage.create, ToTensor])
    dl = tfdl.dataloaders(after_batch=[Resize(299, method=ResizeMethod.Squish),
                                       IntToFloatTensor(),
                                       Blur(k),
                                       Normalize.from_stats(*imagenet_stats)],
                          shuffle_train=False)
    # Reset the metric and accumulate features for every batch of blurred images
    fid.reset()
    for b in progress_bar(dl.train, parent=mb):
        fid.total = torch.cat([fid.total, fid.func(b).cpu()])
    fid.count = len(dl.train)
    fids.append(fid.value)
    imgs.append(TensorImage(dl.decode(dl.one_batch())[0]))

plt.plot(kernel_sizes, fids)
plt.xlabel('Blur kernel size')
plt.ylabel('FID')
plt.grid(True)
ImageNTuple(imgs).show();
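The Blur transform is a box (mean) filter rather than a Gaussian blur: a depthwise Conv2d whose k x k weights all equal 1/k^2, with zero padding of (k - 1) // 2 so that odd kernel sizes keep the spatial size. The NumPy sketch below performs the same operation on a single channels-first (C, H, W) image; box_blur is a hypothetical name. Because the padding is zeros, pixels near the border are darkened, just as with the PyTorch layer.

```python
import numpy as np


def box_blur(img, k):
    """Depthwise k x k mean filter with zero padding for an image of shape (C, H, W)."""
    k = max(k, 1)  # same clamp as Blur: k = 0 behaves like k = 1 (identity)
    pad = (k - 1) // 2
    c, h, w = img.shape
    padded = np.pad(img, ((0, 0), (pad, pad), (pad, pad)))  # zero padding, like Conv2d
    out_h = h + 2 * pad - k + 1
    out_w = w + 2 * pad - k + 1
    out = np.zeros((c, out_h, out_w))
    # Sum the k * k shifted windows, then divide by k^2 to get the mean
    for dy in range(k):
        for dx in range(k):
            out += padded[:, dy:dy + out_h, dx:dx + out_w]
    return out / k**2


ones = np.ones((3, 9, 9))
blurred = box_blur(ones, 3)  # interior stays 1.0; the corner pixel becomes 4/9
```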