def load_patched_inception_v3():
    """Return an InceptionV3 feature extractor (block index 3, normalize_input disabled)."""
    # inception = inception_v3(pretrained=True)
    # inception_feat = Inception3Feature()
    # inception_feat.load_state_dict(inception.state_dict())
    inception_feat = InceptionV3([3], normalize_input=False)

    return inception_feat
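A minimal usage sketch, assuming this is the pytorch-fid style InceptionV3 wrapper that returns a list with one feature map per requested block; the batch and shapes below are placeholders:

import torch

# Hypothetical usage: extract pooled features for an already-normalized batch.
inception = load_patched_inception_v3().eval()
images = torch.randn(8, 3, 299, 299)              # placeholder batch in [-1, 1]
with torch.no_grad():
    feat = inception(images)[0]                   # output of the single requested block
feat = feat.reshape(feat.shape[0], -1)            # flatten spatial dims, e.g. (8, 2048)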
Example #2
def calculate_fid_given_dataset(dataloader, model_s, batch_size, cuda, dims, device, num_images):
    """Calculates the FID"""

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]

    model = InceptionV3([block_idx])
    if cuda:
        model.to(device)

    m1, s1 = _compute_statistics_of_given_dataset(dataloader, model, batch_size,
                                                  dims, cuda, device, num_images)
    m2, s2 = _compute_statistics_of_generate(model_s, model, batch_size,
                                             dims, cuda, device, num_images)
    fid_value = calculate_frechet_distance(m1, s1, m2, s2)

    return fid_value
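All of these helpers end with a call to calculate_frechet_distance, which is not shown in the examples. Below is a minimal sketch of the usual formulation, FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2(sigma1 sigma2)^(1/2)), assuming the statistics are numpy mean vectors and covariance matrices as produced by the _compute_statistics_* helpers:

import numpy as np
from scipy import linalg

def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    # Frechet distance between N(mu1, sigma1) and N(mu2, sigma2).
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # Retry with a small diagonal offset if the product is near-singular.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)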
Example #3
def calculate_fid_given_paths(paths, batch_size, cuda, dims):
    """Calculates the FID of two paths"""
    for p in paths:
        if not os.path.exists(p):
            raise RuntimeError('Invalid path: %s' % p)

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]

    model = InceptionV3([block_idx])
    if cuda:
        model.cuda()
    m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, dims,
                                         cuda)
    m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, dims,
                                         cuda)
    fid_value = calculate_frechet_distance(m1, s1, m2, s2)

    return fid_value
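A possible invocation, assuming two folders of images and the 2048-dimensional pooling features commonly used for FID; the paths and batch size below are placeholders:

# Hypothetical call: compare a folder of real images with a folder of samples.
fid = calculate_fid_given_paths(['/path/to/real', '/path/to/fake'],
                                batch_size=50, cuda=True, dims=2048)
print('FID:', fid)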
Example #4
def calculate_fid_given_paths(paths, batch_size, device, dims, lenet, num_workers=8):
    """Calculates the FID of two paths"""
    for p in paths:
        if not os.path.exists(p):
            raise RuntimeError('Invalid path: %s' % p)

    if lenet is None:
        # TODO: Inception Net
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
        model = InceptionV3([block_idx]).to(device)
    else:
        block_idx = LeNet5.BLOCK_INDEX_BY_DIM[dims]
        model = LeNet5(lenet, [block_idx]).to(device)

    m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
                                        dims, device, num_workers)
    m2, s2 = compute_statistics_of_path(paths[1], model, batch_size,
                                        dims, device, num_workers)
    fid_value = calculate_frechet_distance(m1, s1, m2, s2)

    return fid_value
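The same call shape works for this variant; with lenet=None it falls back to the Inception backbone, otherwise the codebase's LeNet5 wrapper is used and dims must match that wrapper's BLOCK_INDEX_BY_DIM. A hedged example with placeholder paths:

# Hypothetical call: Inception backbone on GPU, default of eight dataloader workers.
fid = calculate_fid_given_paths(['/path/to/real', '/path/to/fake'],
                                batch_size=50, device='cuda:0',
                                dims=2048, lenet=None)
print('FID:', fid)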
Example #5
def calculate_fid_given_dataset_sanity(cfg, dataset, model_s, batch_size, cuda,
                                       dims, device, num_images):
    """Calculates the FID of actual images and images from dataset, should be 0"""
    print("fid sanity check")

    path = str(Path.home()) + '/../../mnt/data/tal/celebhq_256_train'

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]

    model = InceptionV3([block_idx])
    if cuda:
        model.to(device)

    m1, s1 = _compute_statistics_of_given_dataset(cfg, dataset, model,
                                                  batch_size, dims, cuda,
                                                  device, num_images)
    m2, s2 = _compute_statistics_of_generate(cfg, model_s, model, batch_size,
                                             dims, cuda, device, num_images)
    # m1, s1 = _compute_statistics_of_path(path, model, batch_size,
    #                                      dims, cuda, device)
    fid_value = calculate_frechet_distance(m1, s1, m2, s2)

    return fid_value
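As a quick check of the docstring's "should be 0" claim: when both sets of statistics coincide, ||mu1 - mu2||^2 vanishes and Tr(sigma + sigma - 2(sigma sigma)^(1/2)) = 0 for a positive semi-definite covariance, so the distance is zero up to numerical error. A tiny illustration using calculate_frechet_distance (the codebase's helper, or the sketch above):

import numpy as np

# Hypothetical self-check: identical statistics should yield FID ~ 0.
rng = np.random.default_rng(0)
feats = rng.standard_normal((100, 8))
mu, sigma = feats.mean(axis=0), np.cov(feats, rowvar=False)
print(calculate_frechet_distance(mu, sigma, mu, sigma))  # ~0 up to numerical error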
Example #6
def inception_score(imgs,
                    cuda=True,
                    batch_size=33,
                    resize=False,
                    splits=1,
                    classifier=None,
                    log_logit=False,
                    true_dist=None,
                    requires_grad=True):
    """Computes the inception score of the generated images imgs

    imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1]
    cuda -- whether or not to run on GPU
    batch_size -- batch size for feeding into Classifier
    splits -- number of splits
    """

    N = len(imgs)

    assert true_dist is not None, "true_dist argument should not be None"
    assert batch_size > 0
    assert N > batch_size

    if cuda:
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    # Set up dtype
    if cuda:
        dtype = torch.cuda.FloatTensor
    else:
        if torch.cuda.is_available():
            print(
                "WARNING: You have a CUDA device, so you should probably set cuda=True"
            )
        dtype = torch.FloatTensor

    # Set up dataloader
    dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)

    ### Load pretrained classifier
    if classifier is None:
        # Load inception model
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM['prob']
        model = InceptionV3([block_idx],
                            requires_grad=requires_grad).to(device)
        model.eval()

        upsample = torch.nn.Upsample((299, 299),
                                     mode='bilinear',
                                     align_corners=False)

        def get_pred(x):
            if resize:
                x = upsample(x)
            out = model(x)
            if not requires_grad:
                out = out.data
            return out
    else:
        classifier.eval()

        def get_pred(x):
            x = classifier(x)
            if log_logit:
                out = x.exp()
            else:
                out = x
            if not requires_grad:
                out = out.data
            return out

    # Get predictions
    output_sample = next(iter(dataloader))
    if cuda:
        output_sample = output_sample.cuda()
    output_shape = get_pred(output_sample).shape
    preds = torch.zeros((N, output_shape[-1]))
    for i, batch in enumerate(dataloader, 0):
        batch = batch.type(dtype)
        batch_size_i = batch.size()[0]
        if cuda:
            batch = batch.cuda()

        preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(batch)

    if true_dist is not None and cuda:
        true_dist = true_dist.cpu()

    # Now compute the mean kl-div
    split_scores = torch.zeros(splits)
    split_reg_term = torch.zeros(splits)
    split_mis = torch.zeros(splits)

    kl_d = torch.nn.KLDivLoss(reduction='sum')

    for k in range(splits):
        part = preds[k * (N // splits):(k + 1) * (N // splits), :]

        py = torch.mean(part, axis=0)

        scores = torch.zeros(part.shape[0])
        reg_term = torch.zeros(part.shape[0])
        for i in range(part.shape[0]):
            pyx = part[i, :]
            scores[i] = kl_d(py.log(), pyx)
            reg_term[i] = 0.5 * (kl_d(py.log(), 0.5 * (true_dist + py)) +
                                 kl_d(true_dist.log(), 0.5 * (py + true_dist)))

        temp_scores2 = torch.zeros(10)
        for n in range(10):
            part_n = part[torch.argmax(part, dim=1) == n]
            #part_n = part
            #            print('part_n', part_n)
            if part_n.shape[0] == 0:
                continue
            px_js = torch.mean(part_n, axis=0)
            for m in range(part_n.shape[0]):
                px_i = part_n[m, :]
                temp_scores2[n] += kl_d(px_js.log(), px_i)
            temp_scores2[n] /= part_n.shape[0]

        scores2 = torch.mean(temp_scores2[temp_scores2 != 0.])

        split_scores[k] = torch.exp(torch.mean(scores))
        split_reg_term[k] = torch.exp(torch.mean(reg_term))
        split_mis[k] = torch.exp(torch.mean(scores2))

    return torch.mean(split_scores), torch.mean(split_reg_term), torch.mean(
        split_mis)
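Note the argument order in the loops above: torch.nn.KLDivLoss(reduction='sum') takes log-probabilities as its first argument (input) and probabilities as its second (target), so kl_d(py.log(), pyx) evaluates KL(pyx || py), the per-sample divergence from the marginal. A small sanity check of that reading:

import torch

# KLDivLoss(input=log q, target=p) computes sum(p * (log p - log q)) = KL(p || q),
# so kl_d(py.log(), pyx) above is KL(pyx || py).
kl_d = torch.nn.KLDivLoss(reduction='sum')
py = torch.tensor([0.2, 0.3, 0.5])
pyx = torch.tensor([0.6, 0.3, 0.1])
assert torch.allclose(kl_d(py.log(), pyx),
                      (pyx * (pyx.log() - py.log())).sum())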
Example #7
def inception_score(imgs,
                    cuda=True,
                    batch_size=32,
                    resize=False,
                    splits=1,
                    classifier=None,
                    log_logit=False,
                    requires_grad=False):
    """Computes the inception score of the generated images imgs

    imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1]
    cuda -- whether or not to run on GPU
    batch_size -- batch size for feeding into Classifier
    splits -- number of splits
    """

    if cuda:
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    N = len(imgs)

    assert batch_size > 0
    assert N > batch_size

    # Set up dtype
    if cuda:
        dtype = torch.cuda.FloatTensor
    else:
        if torch.cuda.is_available():
            print(
                "WARNING: You have a CUDA device, so you should probably set cuda=True"
            )
        dtype = torch.FloatTensor

    # Set up dataloader
    dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)

    ### Load pretrained classifier
    if classifier is None:
        # Load inception model
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM['prob']
        model = InceptionV3([block_idx],
                            requires_grad=requires_grad).to(device)
        model.eval()

        upsample = torch.nn.Upsample((299, 299),
                                     mode='bilinear',
                                     align_corners=False)

        def get_pred(x):
            if resize:
                x = upsample(x)
            return model(x)
    else:
        classifier.eval()

        def get_pred(x):
            x = classifier(x)
            if log_logit:
                out = x.exp()
            else:
                out = x
            if not requires_grad:
                out = out.data
            return out

    # Get predictions
    output_sample = next(iter(dataloader))
    if cuda:
        output_sample = output_sample.cuda()
    output_shape = get_pred(output_sample)[0].shape
    preds = torch.zeros((N, output_shape[-1]))

    for i, batch in enumerate(tqdm(dataloader), 0):
        batch = batch.type(dtype)
        batch_size_i = batch.size()[0]
        if cuda:
            batch = batch.cuda()

        preds[i * batch_size:i * batch_size +
              batch_size_i] = get_pred(batch)[0].cpu()

    # Now compute the mean kl-div
    split_scores = torch.zeros(splits)

    kl_d = torch.nn.KLDivLoss(reduction='sum')

    for k in range(splits):
        part = preds[k * (N // splits):(k + 1) * (N // splits), :]
        py = torch.mean(part, axis=0)
        scores = torch.zeros(part.shape[0])
        for i in range(part.shape[0]):
            pyx = part[i, :]
            scores[i] = kl_d(py.log(), pyx)
        split_scores[k] = torch.exp(torch.mean(scores))

    return torch.mean(split_scores), torch.std(split_scores)
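A possible usage, where fake_dataset is a hypothetical torch Dataset yielding images shaped (3, H, W) and normalized to [-1, 1] as the docstring requires:

# Hypothetical usage: score generated samples with the default InceptionV3
# classifier, upsampling inputs to 299x299 and averaging over 10 splits.
mean_is, std_is = inception_score(fake_dataset, cuda=True, batch_size=32,
                                  resize=True, splits=10)
print('Inception Score: %.3f +/- %.3f' % (mean_is, std_is))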