def sample_cgan_concat_given_labels(netG,
                                    given_labels,
                                    batch_size=100,
                                    denorm=True,
                                    to_numpy=True,
                                    verbose=True):
    '''
    Generate one fake image per entry of given_labels with a conditional generator.

    netG: pretrained generator network; moved to the GPU, switched to eval
        mode and called as netG(z, c).
    given_labels: unnormalized labels; they are divided by the module-level
        max_label before being passed to the generator.
        (assumes a 1-D numpy array -- TODO confirm against callers)
    batch_size: number of images generated per forward pass; capped at
        len(given_labels).
    denorm: if True, each batch is asserted to lie in [-1, 1] and is converted
        to uint8 values in [0, 255] right after generation.
    to_numpy: if True, the images are returned as a numpy array, otherwise as
        a CPU torch tensor.
    verbose: if True, progress is reported through SimpleProgressBar.

    Returns (fake_images, given_labels); given_labels is the untouched input.
    '''

    num_samples = len(given_labels)
    batch_size = min(batch_size, num_samples)

    # Rescale the labels, then append one extra batch taken from the front so
    # that every slice taken in the loop below has exactly batch_size entries.
    scaled_labels = given_labels / max_label
    padded_labels = np.concatenate(
        (scaled_labels, scaled_labels[0:batch_size]), axis=0)

    netG = netG.cuda()
    netG.eval()

    image_chunks = []
    with torch.no_grad():
        progress = SimpleProgressBar() if verbose else None
        done = 0
        while done < num_samples:
            noise = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
            batch_labels = padded_labels[done:(done + batch_size)]
            cond = torch.from_numpy(batch_labels).type(torch.float).cuda()
            images = netG(noise, cond)
            if denorm:
                # Convert to bytes immediately so the accumulated list stays small.
                assert images.max().item() <= 1.0 and images.min().item() >= -1.0
                images = ((images * 0.5 + 0.5) * 255.0).type(torch.uint8)
            image_chunks.append(images.detach().cpu())
            done += batch_size
            if verbose:
                progress.update(min(float(done) / num_samples, 1) * 100)

    # The last batch may overshoot because of the padding; drop the surplus.
    generated = torch.cat(image_chunks, dim=0)[0:num_samples]

    if to_numpy:
        generated = generated.numpy()

    return generated, given_labels
示例#2
0
def cal_labelscore(PreNet, images, labels_assi, min_label_before_shift, max_label_after_shift, batch_size = 200, resize = None, norm_img = False, num_workers=0):
    '''
    Compute the label score: mean and std of |predicted label - assigned label|.

    PreNet: pre-trained CNN; put in eval mode and called as PreNet(batch).
        It must return a 2-tuple whose first element holds the predicted
        labels (the second element is ignored here).
    images: fake images; images.shape[0] is taken as the number of images.
    labels_assi: assigned labels; flattened with reshape(-1).
    min_label_before_shift, max_label_after_shift: used to map both predicted
        and assigned labels back to the original scale via
        label * max_label_after_shift - |min_label_before_shift|.
    batch_size: batch size of the evaluation DataLoader.
    resize: NOTE(review): accepted but never applied in this implementation;
        images are fed to PreNet at their stored resolution.
    norm_img: if True, every batch is passed through normalize_images first.
    num_workers: forwarded to torch.utils.data.DataLoader.

    Returns (ls_mean, ls_std).
    '''

    PreNet.eval()

    n = images.shape[0]
    labels_assi = labels_assi.reshape(-1)

    eval_trainset = IMGs_dataset(images, labels_assi, normalize=False)
    eval_dataloader = torch.utils.data.DataLoader(eval_trainset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # One spare batch of slots, trimmed after the loop (kept from the original).
    labels_pred = np.zeros(n+batch_size)

    nimgs_got = 0
    pb = SimpleProgressBar()
    # Pure inference: disable autograd so no graph is kept alive per batch.
    with torch.no_grad():
        for batch_images, _ in eval_dataloader:
            batch_images = batch_images.type(torch.float).cuda()
            batch_size_curr = len(batch_images)

            if norm_img:
                batch_images = normalize_images(batch_images)

            batch_labels_pred, _ = PreNet(batch_images)
            labels_pred[nimgs_got:(nimgs_got+batch_size_curr)] = batch_labels_pred.detach().cpu().numpy().reshape(-1)

            nimgs_got += batch_size_curr
            pb.update((float(nimgs_got)/n)*100)

            del batch_images; gc.collect()
            torch.cuda.empty_cache()
    #end for

    labels_pred = labels_pred[0:n]

    # Undo the label normalization before measuring the error.
    labels_pred = (labels_pred*max_label_after_shift)-np.abs(min_label_before_shift)
    labels_assi = (labels_assi*max_label_after_shift)-np.abs(min_label_before_shift)

    abs_err = np.abs(labels_pred-labels_assi)
    ls_mean = np.mean(abs_err)
    ls_std = np.std(abs_err)

    return ls_mean, ls_std
示例#3
0
def sample_cgan_given_labels(netG, given_labels, batch_size=500):
    '''
    Generate one fake image per entry of given_labels and return them as numpy.

    netG: pretrained generator network; moved to the GPU, switched to eval
        mode and called as netG(z, c).
    given_labels: unnormalized labels; they are divided by the module-level
        max_label before being passed to the generator.
        (assumes a 1-D numpy array -- TODO confirm against callers)
    batch_size: number of images generated per forward pass; capped at
        len(given_labels).

    Returns (fake_images, given_labels). If the largest generated value is
    <= 1.0 the images are rescaled with x*0.5+0.5, multiplied by 255 and cast
    to uint8; otherwise they are returned as produced by the generator.
    '''

    num_samples = len(given_labels)
    batch_size = min(batch_size, num_samples)
    scaled_labels = given_labels / max_label

    netG = netG.cuda()
    netG.eval()

    # Append one extra batch taken from the front so every slice in the loop
    # below has exactly batch_size entries.
    padded_labels = np.concatenate(
        (scaled_labels, scaled_labels[0:batch_size]), axis=0)

    image_chunks = []
    with torch.no_grad():
        progress = SimpleProgressBar()
        done = 0
        while done < num_samples:
            noise = torch.randn(batch_size, dim_z, dtype=torch.float).cuda()
            batch_labels = padded_labels[done:(done + batch_size)]
            cond = torch.from_numpy(batch_labels).type(torch.float).cuda()
            images = netG(noise, cond)
            image_chunks.append(images.detach().cpu().numpy())
            done += batch_size
            progress.update(min(float(done) / num_samples, 1) * 100)

    # The last batch may overshoot because of the padding; drop the surplus.
    generated = np.concatenate(image_chunks, axis=0)[0:num_samples]

    # Denormalize only when the output still looks like it is in [-1, 1].
    if generated.max() <= 1.0:
        generated = ((generated * 0.5 + 0.5) * 255.0).astype(np.uint8)

    return generated, given_labels
示例#4
0
def sample_ccgan_given_labels(netG, net_y2h, labels, batch_size = 500, to_numpy=True, denorm=True, verbose=True):
    '''
    Generate one fake image per entry of labels with a CcGAN generator.

    netG: pretrained generator network; called as netG(z, net_y2h(y)).
    net_y2h: network applied to the labels (shaped (-1, 1)) before they are
        passed to netG.
    labels: float, normalized labels.
    batch_size: number of images generated per forward pass; capped at
        len(labels).
    to_numpy: if True, the images are returned as a numpy array, otherwise as
        a CPU torch tensor.
    denorm: if True, each batch is asserted to lie in [-1, 1] and is converted
        to uint8 values in [0, 255] right after generation.
    verbose: if True, progress is reported through SimpleProgressBar.

    Returns (fake_images, fake_labels), where fake_labels holds the first
    len(labels) entries of the label array that was fed to the generator.
    '''

    total = len(labels)
    batch_size = min(batch_size, total)

    # Append one extra batch taken from the front so every slice in the loop
    # below has exactly batch_size entries.
    padded_labels = np.concatenate((labels, labels[0:batch_size]))

    netG = netG.cuda()
    netG.eval()
    net_y2h = net_y2h.cuda()
    net_y2h.eval()

    image_chunks = []
    with torch.no_grad():
        progress = SimpleProgressBar() if verbose else None
        done = 0
        while done < total:
            noise = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
            batch_labels = padded_labels[done:(done+batch_size)]
            y = torch.from_numpy(batch_labels).type(torch.float).view(-1,1).cuda()
            images = netG(noise, net_y2h(y))
            if denorm:
                # Convert to bytes immediately so the accumulated list stays small.
                assert images.max().item()<=1.0 and images.min().item()>=-1.0
                images = ((images*0.5+0.5)*255.0).type(torch.uint8)
            image_chunks.append(images.cpu())
            done += batch_size
            if verbose:
                progress.update(min(float(done)/total, 1)*100)

    # The last batch may overshoot because of the padding; drop the surplus.
    generated = torch.cat(image_chunks, dim=0)[0:total]
    used_labels = padded_labels[0:total]

    if to_numpy:
        generated = generated.numpy()

    return generated, used_labels
示例#5
0
def cal_labelscore(PreNet, images, labels_assi, min_label_before_shift, max_label_after_shift, batch_size = 500, resize = None):
    '''
    Compute the label score: mean and std of |predicted label - assigned label|.

    PreNet: pre-trained CNN; put in eval mode and called as PreNet(batch).
        It must return a 2-tuple whose first element holds the predicted
        labels (the second element is ignored here).
    images: fake images as a numpy array; images.shape[0] is the number of
        images.
    labels_assi: assigned labels; flattened with reshape(-1).
    min_label_before_shift, max_label_after_shift: used to map both predicted
        and assigned labels back to the original scale via
        label * max_label_after_shift - |min_label_before_shift|.
    batch_size: number of images per forward pass.
    resize: if None, do not resize; if resize = (H,W), images are bilinearly
        resized to H x W before being fed to PreNet.

    Returns (ls_mean, ls_std).
    '''
    PreNet.eval()

    n = images.shape[0]
    labels_assi = labels_assi.reshape(-1)

    # predict labels
    labels_pred = np.zeros(n)
    with torch.no_grad():
        pb = SimpleProgressBar()
        # Step over the data in strides of batch_size; the final batch may be
        # shorter, so images beyond the last full batch are scored as well
        # instead of being left at 0.
        for start in range(0, n, batch_size):
            stop = min(start + batch_size, n)
            image_tensor = torch.from_numpy(images[start:stop]).type(torch.float).cuda()
            if resize is not None:
                image_tensor = nn.functional.interpolate(image_tensor, size = resize, scale_factor=None, mode='bilinear', align_corners=False)
            labels_batch, _ = PreNet(image_tensor)
            labels_pred[start:stop] = labels_batch.detach().cpu().numpy().reshape(-1)
            pb.update(float(stop)*100/n)
            # Drop the GPU references inside the loop so that cleanup never
            # touches an unbound name when there were no batches at all.
            del image_tensor, labels_batch
        gc.collect()
        torch.cuda.empty_cache()

    # Undo the label normalization before measuring the error.
    labels_pred = (labels_pred*max_label_after_shift)-np.abs(min_label_before_shift)
    labels_assi = (labels_assi*max_label_after_shift)-np.abs(min_label_before_shift)

    abs_err = np.abs(labels_pred-labels_assi)
    ls_mean = np.mean(abs_err)
    ls_std = np.std(abs_err)

    return ls_mean, ls_std
示例#6
0
def sample_cgan_given_labels(netG,
                             given_labels,
                             class_cutoff_points,
                             batch_size=200,
                             denorm=True,
                             to_numpy=True,
                             verbose=True):
    '''
    Generate one fake image per raw label with a class-conditional generator.

    Each raw label is first binned into a class index using
    class_cutoff_points, and the generator is conditioned on that index.

    netG: pretrained generator network; moved to the GPU, switched to eval
        mode and called as netG(z, class_labels) with long class indices.
    given_labels: a numpy array; raw label without any normalization; not class label
    class_cutoff_points: the cutoff points to determine the membership of a
        given label. Must be sorted in strictly increasing order; K+1 cutoffs
        define K classes. Class k covers [cutoff[k], cutoff[k+1]), and the
        last class also includes its right endpoint.
    batch_size: number of images generated per forward pass; capped at
        len(given_labels).
    denorm: if True, each batch is asserted to lie in [-1, 1] and is converted
        to uint8 values in [0, 255] right after generation.
    to_numpy: if True, the images are returned as a numpy array, otherwise as
        a CPU torch tensor.
    verbose: if True, progress is reported through SimpleProgressBar.

    Labels outside [cutoff[0], cutoff[-1]] are assigned to the nearest
    (first or last) class.

    Returns (fake_images, given_labels); given_labels is the untouched input.
    Raises ValueError if fewer than two cutoff points are given.
    '''

    class_cutoff_points = np.array(class_cutoff_points)
    num_classes = len(class_cutoff_points) - 1
    if num_classes < 1:
        raise ValueError("class_cutoff_points must contain at least two values")

    nfake = len(given_labels)

    # Index of the last cutoff <= label gives the class; clipping folds the
    # right endpoint into the last class and keeps out-of-range labels on a
    # valid class index instead of producing -1 or reusing a stale value.
    raw_labels = np.asarray(given_labels).reshape(-1)
    given_class_labels = np.searchsorted(
        class_cutoff_points, raw_labels, side='right') - 1
    given_class_labels = np.clip(given_class_labels, 0, num_classes - 1)

    if batch_size > nfake:
        batch_size = nfake

    # Append one extra batch taken from the front so every slice in the loop
    # below has exactly batch_size entries.
    given_class_labels = np.concatenate(
        (given_class_labels, given_class_labels[0:batch_size]))

    fake_images = []
    netG = netG.cuda()
    netG.eval()
    with torch.no_grad():
        if verbose:
            pb = SimpleProgressBar()
        tmp = 0
        while tmp < nfake:
            z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
            labels = torch.from_numpy(
                given_class_labels[tmp:(tmp + batch_size)]).type(
                    torch.long).cuda()
            batch_fake_images = netG(z, labels)
            if denorm:  #denorm imgs to save memory
                assert batch_fake_images.max().item(
                ) <= 1.0 and batch_fake_images.min().item() >= -1.0
                batch_fake_images = batch_fake_images * 0.5 + 0.5
                batch_fake_images = batch_fake_images * 255.0
                batch_fake_images = batch_fake_images.type(torch.uint8)
            fake_images.append(batch_fake_images.detach().cpu())
            tmp += batch_size
            if verbose:
                pb.update(min(float(tmp) / nfake, 1) * 100)

    fake_images = torch.cat(fake_images, dim=0)
    #remove extra entries
    fake_images = fake_images[0:nfake]

    if to_numpy:
        fake_images = fake_images.numpy()

    return fake_images, given_labels
示例#7
0
def cal_FID(PreNetFID,
            IMGSr,
            IMGSg,
            batch_size=500,
            resize=None,
            norm_img=False):
    '''
    Compute the FID score between real and generated images.

    PreNetFID: feature extractor; put in eval mode and called as
        PreNetFID(batch). Its output must be 2-D, one feature row per image.
    IMGSr: real images as a numpy array; IMGSr.shape[0] is the image count.
    IMGSg: generated images as a numpy array; IMGSg.shape[0] is the image count.
    batch_size: images per forward pass; capped at min(len(IMGSr), len(IMGSg)).
    resize: if None, do not resize; if resize = (H,W), images are bilinearly
        resized to H x W before feature extraction.
    norm_img: if True, every batch is passed through normalize_images.

    Features are extracted for every image, including a trailing partial
    batch, and the two feature matrices are passed to FID(Xr, Xg, eps=1e-6).

    Returns the value produced by FID.
    Raises ValueError if either image set is empty.
    '''

    PreNetFID.eval()

    nr = IMGSr.shape[0]
    ng = IMGSg.shape[0]
    if nr == 0 or ng == 0:
        raise ValueError(
            "cal_FID needs at least one real and one generated image")

    if batch_size > min(nr, ng):
        batch_size = min(nr, ng)

    #compute the length of extracted features from a single real image
    with torch.no_grad():
        probe = _fid_prepare_batch(IMGSr[0:1], resize, norm_img)
        d = PreNetFID(probe).shape[1]
        del probe

    Xr = _fid_extract_features(PreNetFID, IMGSr, d, batch_size, resize,
                               norm_img)
    Xg = _fid_extract_features(PreNetFID, IMGSg, d, batch_size, resize,
                               norm_img)

    fid_score = FID(Xr, Xg, eps=1e-6)

    return fid_score


def _fid_prepare_batch(images, resize, norm_img):
    '''Move a numpy image batch to the GPU as float, then optionally resize and normalize it.'''
    batch = torch.from_numpy(images).type(torch.float).cuda()
    if resize is not None:
        batch = nn.functional.interpolate(batch,
                                          size=resize,
                                          scale_factor=None,
                                          mode='bilinear',
                                          align_corners=False)
    if norm_img:
        batch = normalize_images(batch)
    return batch


def _fid_extract_features(PreNetFID, images, d, batch_size, resize, norm_img):
    '''Run PreNetFID over all images in batches and return an (n, d) numpy feature matrix.'''
    n = images.shape[0]
    features = np.zeros((n, d))
    pb = SimpleProgressBar()
    with torch.no_grad():
        # Stride over the data; the final batch may be shorter, so no image is
        # left with an all-zero feature row.
        for start in range(0, n, batch_size):
            stop = min(start + batch_size, n)
            batch = _fid_prepare_batch(images[start:stop], resize, norm_img)
            features[start:stop] = PreNetFID(batch).detach().cpu().numpy()
            pb.update(min(float(stop) / n * 100, 100))
            del batch
    gc.collect()
    torch.cuda.empty_cache()
    return features