Example #1
0
def main(args):
    """Display the two strongest Grad-CAM attention maps of a trained FactorVAE.

    Loads a FactorVAE checkpoint and encodes the first ``args.sample_count``
    samples of the dataset returned by ``return_data(args)``. It then computes
    Grad-CAM maps of the latent code ``z`` at ``args.target_layer``. For every
    slice of the (transposed) Grad-CAM output, it keeps the two sub-maps with
    the largest total response.

    The result is an OpenCV window showing input images stacked above the two
    heatmap overlays.

    Args:
        args: argparse-style namespace. Reads ``seed``, ``cuda``, ``z_dim``,
            ``dir``, ``name``, ``target_layer``, ``image_size`` and
            ``sample_count`` (plus whatever ``return_data`` reads).
    """
    np.random.seed(args.seed)
    use_cuda = args.cuda and torch.cuda.is_available()
    device = 'cuda' if use_cuda else 'cpu'

    model = FactorVAE(args.z_dim).to(device)
    if not load_checkpoint(model, args.dir, args.name, device):
        return

    gcam = GradCAM(model.encode, args.target_layer, device, args.image_size)

    _, dataset = return_data(args)

    images = dataset[np.arange(0, args.sample_count)][0].to(device)
    recon, mu, logvar, z = model(images)

    # Replicate the channel dimension 3x so colour heatmaps can be blended on
    # top (presumably the inputs are single-channel).
    images, recon = images.repeat(1, 3, 1, 1), recon.repeat(1, 3, 1, 1)

    maps = gcam.generate(z).transpose(0, 1)

    first_cam, second_cam = [], []
    for cam in maps:
        response = cam.flatten(1).sum(1)
        first_idx = torch.argmax(response).item()
        # Mask the winner rather than slicing it out. Slicing shifts every
        # later index down by one, so the index found in the shortened tensor
        # would point at the wrong sub-map of `cam`.
        masked = response.detach().clone()
        masked[first_idx] = float('-inf')
        second_idx = torch.argmax(masked).item()
        first_cam.append(normalize_tensor(cam[first_idx]))
        second_cam.append(normalize_tensor(cam[second_idx]))

    first_cam = torch.stack(first_cam, axis=1).transpose(0, 1).unsqueeze(1)
    second_cam = torch.stack(second_cam, axis=1).transpose(0, 1).unsqueeze(1)

    images, recon, first_cam, second_cam = process_imgs(
        images.detach(), recon.detach(), first_cam.detach(),
        second_cam.detach(), args.sample_count)

    heatmap = add_heatmap(images, first_cam)
    heatmap2 = add_heatmap(images, second_cam)

    # `np.float` was removed in NumPy 1.24; the builtin `float` is the same
    # float64 dtype it aliased.
    images = np.uint8(np.asarray(images, dtype=float) * 255)
    grid = np.concatenate((images, heatmap, heatmap2))

    cv2.imshow('Attention Maps of ' + args.name, grid)
    cv2.waitKey(0)
Example #2
0
def main():
    """Save Grad-CAM attention maps of a one-class ConvVAE on MNIST test data.

    Parses command-line options and loads the ConvVAE checkpoint at
    ``--model_path``. For every test image of the ``OneMNIST`` test split, it
    writes ``<index>-<digit>-origin.png`` and ``<index>-<digit>-attmap.png``
    into ``--result_dir``.

    Relies on module-level ``cuda`` (truthy when a GPU is used) and ``device``.
    """
    parser = argparse.ArgumentParser(
        description='Explainable VAE MNIST Example')
    parser.add_argument('--result_dir',
                        type=str,
                        default='test_results',
                        metavar='DIR',
                        help='output directory')
    parser.add_argument('--batch_size',
                        type=int,
                        default=128,
                        metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')

    # model options
    parser.add_argument('--latent_size',
                        type=int,
                        default=32,
                        metavar='N',
                        help='latent vector size of encoder')
    parser.add_argument('--model_path',
                        type=str,
                        default='./ckpt/model_best.pth',
                        metavar='DIR',
                        help='pretrained model directory')
    parser.add_argument('--one_class',
                        type=int,
                        default=8,
                        metavar='N',
                        help='outlier digit for one-class VAE testing')

    args = parser.parse_args()

    torch.manual_seed(args.seed)

    kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

    one_class = args.one_class  # the outlier digit (default 8)
    one_mnist_test_dataset = OneClassMnist.OneMNIST(
        './data', one_class, train=False, transform=transforms.ToTensor())

    test_loader = torch.utils.data.DataLoader(one_mnist_test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              **kwargs)

    model = ConvVAE(args.latent_size).to(device)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only box.
    checkpoint = torch.load(args.model_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    mu_avg, logvar_avg = 0, 1
    # Follow the same `cuda` flag as the loader instead of hard-coding True,
    # which breaks CPU-only runs.
    gcam = GradCAM(model, target_layer='encoder.2', cuda=cuda)

    # The output directory is loop-invariant: create it once.
    im_path = args.result_dir
    os.makedirs(im_path, exist_ok=True)

    test_index = 0
    for x, _ in test_loader:
        model.eval()
        x = x.to(device)
        x_rec, mu, logvar = gcam.forward(x)

        model.zero_grad()
        gcam.backward(mu, logvar, mu_avg, logvar_avg)
        gcam_map = gcam.generate()

        ## Visualize and save attention maps  ##
        x = x.repeat(1, 3, 1, 1)
        for i in range(x.size(0)):
            raw_image = x[i] * 255.0
            ndarr = raw_image.permute(1, 2, 0).cpu().byte().numpy()
            im = Image.fromarray(ndarr.astype(np.uint8))
            im.save(
                os.path.join(
                    im_path, "{}-{}-origin.png".format(test_index,
                                                       str(one_class))))

            file_path = os.path.join(
                im_path, "{}-{}-attmap.png".format(test_index, str(one_class)))
            r_im = np.asarray(im)
            save_cam(r_im, file_path, gcam_map[i].squeeze().cpu().data.numpy())
            test_index += 1
def main(args):
    """
    Main Function for testing and saving attention maps.

    Computes a Grad-CAM attention map for every test image and optionally
    saves image/attention-map pairs (module-level ``save_gcam_image``). For
    datasets other than MNIST, it also reports the pixel-level AUROC and,
    when ``args.iou`` is set, the best IoU over 99 evenly spaced thresholds.
    Relies on the module-level ``device`` and ``plot_ROC`` settings.

    Inputs:
        args - Namespace object from the argument parser
    Raises:
        ValueError - if args.dataset or args.model is not a supported choice
    """

    torch.manual_seed(args.seed)

    # Load dataset
    if args.dataset == 'mnist':
        test_dataset = OneClassMnist.OneMNIST('./data',
                                              args.one_class,
                                              train=False,
                                              transform=transforms.ToTensor())
    elif args.dataset == 'ucsd_ped1':
        test_dataset = Ped1_loader.UCSDAnomalyDataset('./data',
                                                      train=False,
                                                      resize=args.image_size)
    elif args.dataset == 'mvtec_ad':
        class_name = mvtec.CLASS_NAMES[args.one_class]
        test_dataset = mvtec.MVTecDataset(class_name=class_name,
                                          is_train=False,
                                          grayscale=False,
                                          root_path=args.data_path)
    else:
        raise ValueError(f"Unsupported dataset: {args.dataset}")

    test_steps = len(test_dataset)
    kwargs = {
        'num_workers': args.num_workers,
        'pin_memory': True
    } if device == "cuda" else {}
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              **kwargs)

    # Select a model architecture
    if args.model == 'vanilla_mnist':
        imshape = [1, 28, 28]
        model = ConvVAE_mnist(args.latent_size).to(device)
    elif args.model == 'vanilla_ped1':
        imshape = [1, args.image_size, args.image_size]
        model = ConvVAE_ped1(args.latent_size, args.image_size,
                             args.batch_norm).to(device)
    elif args.model == 'resnet18_3':
        imshape = [3, 256, 256]
        model = ResNet18VAE_3(args.latent_size,
                              x_dim=imshape[-1],
                              nc=imshape[0]).to(device)
    else:
        raise ValueError(f"Unsupported model: {args.model}")

    print("Layer is:", args.target_layer)

    # Load model (map_location allows GPU checkpoints on CPU-only machines)
    checkpoint = torch.load(args.model_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    mu_avg, logvar_avg = (0, 1)
    gcam = GradCAM(model, target_layer=args.target_layer, device=device)

    # One slot per test image; filled at batch_idx * batch_size + i below.
    prediction_stack = np.zeros((test_steps, imshape[-1], imshape[-1]),
                                dtype=np.float32)
    gt_mask_stack = np.zeros((test_steps, imshape[-1], imshape[-1]),
                             dtype=np.uint8)

    # Generate attention maps
    for batch_idx, (x, y) in enumerate(test_loader):
        model.eval()
        x = x.to(device)
        x_rec, mu, logvar = gcam.forward(x)

        model.zero_grad()
        gcam.backward(mu, logvar, mu_avg, logvar_avg)
        gcam_map = gcam.generate()
        gcam_max = torch.max(gcam_map).item()

        # If image has one channel, make it three channel (needed for heatmap)
        if x.size(1) == 1:
            x = x.repeat(1, 3, 1, 1)

        # Visualize and save attention maps
        for i in range(x.size(0)):
            x_arr = x[i].permute(1, 2, 0).cpu().numpy() * 255
            x_im = Image.fromarray(x_arr.astype(np.uint8))

            # Get the gradcam for this image
            prediction = gcam_map[i].squeeze().cpu().data.numpy()

            # Add prediction and mask to the stacks
            prediction_stack[batch_idx * args.batch_size + i] = prediction
            gt_mask_stack[batch_idx * args.batch_size + i] = y[i]

            if save_gcam_image:
                im_path = args.result_dir
                if not os.path.exists(im_path):
                    os.mkdir(im_path)
                x_im.save(
                    os.path.join(im_path,
                                 "{}-{}-origin.png".format(batch_idx, i)))
                file_path = os.path.join(
                    im_path, "{}-{}-attmap.png".format(batch_idx, i))
                save_gradcam(x_arr, file_path, prediction, gcam_max=gcam_max)

    # Stop here for MNIST: no ground-truth masks are available
    if args.dataset == 'mnist':
        return

    gt_flat = gt_mask_stack.flatten()
    pred_flat = prediction_stack.flatten()

    # Compute area under the ROC score
    auc = roc_auc_score(gt_flat, pred_flat)
    print(f"AUROC score: {auc}")

    fpr, tpr, thresholds = roc_curve(gt_flat, pred_flat)
    if plot_ROC:
        # FPR goes on the x axis and TPR on the y axis, matching the labels.
        plt.plot(fpr, tpr, label="ROC")
        plt.xlabel("FPR")
        plt.ylabel("TPR")
        plt.legend()
        os.makedirs(args.result_dir, exist_ok=True)
        # Join properly: plain concatenation dropped the path separator and
        # wrote the file next to result_dir instead of inside it.
        plt.savefig(
            os.path.join(
                str(args.result_dir), "auroc_" + str(args.model) +
                str(args.target_layer) + str(args.one_class) + ".png"))
        # Close the figure so repeated calls do not draw onto the same axes.
        plt.close()

    # Compute IoU: sweep thresholds as fractions of the maximum prediction.
    if args.iou:
        max_val = np.max(prediction_stack)

        max_steps = 100
        best_thres = 0
        best_iou = 0
        # Get the IoU for 99 different thresholds
        for i in range(1, max_steps):
            thresh = i / max_steps * max_val
            prediction_bin_stack = prediction_stack > thresh
            iou = jaccard_score(gt_flat, prediction_bin_stack.flatten())
            if iou > best_iou:
                best_iou = iou
                best_thres = thresh
        print("Best threshold;", best_thres)
        print("Best IoU score:", best_iou)

    return
def main(args):
    """Explain ResNet-152's top-5 predictions for one image.

    For each of the top-5 classes, writes a Grad-CAM map
    (``results/<class>_gcam.png``), a vanilla-backprop map (``_bp.png``), a
    guided-backprop map (``_gbp.png``) and a guided Grad-CAM map
    (``_ggcam.png``).

    Args:
        args: namespace with ``image`` (path of the input image) and
            ``cuda`` (forwarded to every explainer).
    """
    import os

    # Load the synset words. For each line, drop the first space-separated
    # token, keep the text before the first ', ', and replace spaces with
    # underscores (used as file-name-safe class labels).
    file_name = 'synset_words.txt'
    classes = []
    with open(file_name, encoding='utf-8') as class_file:
        for line in class_file:
            label = line.strip().split(' ', 1)[1]
            classes.append(label.split(', ', 1)[0].replace(' ', '_'))

    # Every output below is written into results/; make sure it exists.
    os.makedirs('results', exist_ok=True)

    print('Loading a model...')
    model = torchvision.models.resnet152(pretrained=True)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    print('\nGrad-CAM')
    gcam = GradCAM(model=model,
                   target_layer='layer4.2',
                   n_class=1000,
                   cuda=args.cuda)
    gcam.load_image(args.image, transform)
    gcam.forward()

    for i in range(0, 5):
        gcam.backward(idx=gcam.idx[i])
        cls_name = classes[gcam.idx[i]]
        output = gcam.generate()
        print('\t{:.5f}\t{}'.format(gcam.prob[i], cls_name))
        gcam.save('results/{}_gcam.png'.format(cls_name), output)

    print('\nBackpropagation')
    bp = BackPropagation(model=model,
                         target_layer='conv1',
                         n_class=1000,
                         cuda=args.cuda)
    bp.load_image(args.image, transform)
    bp.forward()

    for i in range(0, 5):
        bp.backward(idx=bp.idx[i])
        cls_name = classes[bp.idx[i]]
        output = bp.generate()
        print('\t{:.5f}\t{}'.format(bp.prob[i], cls_name))
        bp.save('results/{}_bp.png'.format(cls_name), output)

    print('\nGuided Backpropagation')
    gbp = GuidedBackPropagation(model=model,
                                target_layer='conv1',
                                n_class=1000,
                                cuda=args.cuda)
    gbp.load_image(args.image, transform)
    gbp.forward()

    for i in range(0, 5):
        cls_idx = gcam.idx[i]
        cls_name = classes[cls_idx]

        gcam.backward(idx=cls_idx)
        output_gcam = gcam.generate()

        gbp.backward(idx=cls_idx)
        output_gbp = gbp.generate()

        # Rescale the Grad-CAM map to [0, 1]. Skip the division for a flat
        # (e.g. all-zero) map, which would otherwise yield NaNs.
        output_gcam = output_gcam - output_gcam.min()
        peak = output_gcam.max()
        if peak > 0:
            output_gcam = output_gcam / peak
        output_gcam = cv2.resize(output_gcam, (224, 224))
        output_gcam = cv2.cvtColor(output_gcam, cv2.COLOR_GRAY2BGR)

        # Guided Grad-CAM = guided-backprop map weighted by the Grad-CAM map.
        output = output_gbp * output_gcam

        print('\t{:.5f}\t{}'.format(gbp.prob[i], cls_name))
        gbp.save('results/{}_gbp.png'.format(cls_name), output_gbp)
        gbp.save('results/{}_ggcam.png'.format(cls_name), output)
Example #6
0
def gradCAM():
    """Save Grad-CAM maps of class ``WBC_id`` for every image in ``./samples/``.

    Loads the full model pickled at ``./path/<trainset>/<net>.t7``. The
    network name comes from the module-level ``opts``; paths and
    normalisation come from the module-level ``cf`` config.

    For each sample image, the short side is scaled to 224 and a random
    224x224 crop is taken. The Grad-CAM output is saved with
    ``gcam.save`` to ``./results/<i>.png`` and the raw crop to
    ``./results/map<i>.png``.
    """
    # Chap 2 : Train Network
    print('\n[Chapter 2] : Operate gradCAM with trained network')

    # Phase 1 : Model Upload
    print('\n[Phase 1] : Model Weight Upload')
    use_gpu = torch.cuda.is_available()

    # upload labels
    data_dir = cf.test_dir
    trainset_dir = cf.data_base.split("/")[-1] + os.sep

    # NOTE(review): `dsets` is unused below; kept in case the directory scan
    # performed by ImageFolder is relied on as a sanity check.
    dsets = datasets.ImageFolder(data_dir, None)
    H = datasets.ImageFolder(cf.aug_base + '/train/')
    dset_classes = H.classes

    def softmax(x):
        # Only referenced by the commented-out inference snippet below.
        return np.exp(x) / np.sum(np.exp(x), axis=0)

    def getNetwork(opts):
        """Map the network options to the checkpoint file stem."""
        if opts.net_type == 'alexnet':
            file_name = 'alexnet'
        elif opts.net_type == 'vggnet':
            file_name = 'vgg-%s' % (opts.depth)
        elif opts.net_type == 'resnet':
            file_name = 'resnet-%s' % (opts.depth)
        else:
            print('[Error]: Network should be either [alexnet / vggnet / resnet]')
            sys.exit(1)

        return file_name

    def random_crop(image, dim):
        """Return a random ``dim``-sized crop of ``image`` and its offsets."""
        # Branch on rank: the old `if len(image.shape):` was always true, so
        # 2-D (grayscale) images crashed on the 3-way unpack.
        if len(image.shape) == 3:
            W, H, D = image.shape
            w, h, d = dim
        else:
            W, H = image.shape
            w, h = dim[0], dim[1]

        left, top = np.random.randint(W - w + 1), np.random.randint(H - h + 1)
        return image[left:left + w, top:top + h], left, top

    # uploading the model
    print("| Loading checkpoint model for grad-CAM...")
    assert os.path.isdir('./path'), '[Error]: No checkpoint directory found!'
    assert os.path.isdir('./path/' + trainset_dir), '[Error]: There is no model weight to upload!'
    file_name = getNetwork(opts)
    # Remap GPU-saved tensors to CPU when no GPU is available.
    checkpoint = torch.load('./path/' + trainset_dir + file_name + '.t7',
                            map_location=None if use_gpu else 'cpu')
    model = checkpoint['model']

    if use_gpu:
        model.cuda()
        cudnn.benchmark = True

    model.eval()

    def is_image(f):
        return f.endswith(".png") or f.endswith(".jpg")

    test_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(cf.mean, cf.std)
    ])

    """
    #@ Code for inference test

    img = Image.open(cf.image_path)
    if test_transform is not None:
        img = test_transform(img)
    inputs = img
    inputs = Variable(inputs, volatile=False, requires_grad=True)

    if use_gpu:
        inputs = inputs.cuda()
    inputs = inputs.view(1, inputs.size(0), inputs.size(1), inputs.size(2))

    outputs = model(inputs)
    softmax_res = softmax(outputs.data.cpu().numpy()[0])

    index,score = max(enumerate(softmax_res), key=operator.itemgetter(1))

    print('| Uploading %s' %(cf.image_path.split("/")[-1]))
    print('| prediction = ' + dset_classes[index])
    """

    # @ Code for extracting a grad-CAM region for a given class
    feature_net = list(model._modules.items())[0][1]
    gcam = GradCAM(feature_net, cuda=use_gpu)
    gbp = GuidedBackPropagation(model=feature_net, cuda=use_gpu)

    WBC_id = 5  # BHX class
    print("Checking Activated Regions for " + dset_classes[WBC_id] + "...")

    os.makedirs('./results/', exist_ok=True)

    # Only image files are processed (is_image was previously never applied,
    # so any other file made cv2.imread return None and crash). Sorted for a
    # deterministic file -> result-number mapping.
    file_list = sorted(f for f in os.listdir('./samples/') if is_image(f))
    for i, f in enumerate(file_list, start=1):
        file_name = './samples/' + f
        print("Opening " + file_name + "...")

        original_image = cv2.imread(file_name)
        if original_image is None:
            print("Skipping unreadable image " + file_name)
            continue
        # Expand the image based on the short side, then random-crop 224x224.
        resize_ratio = 224. / min(original_image.shape[0:2])
        resized = cv2.resize(original_image, (0, 0), fx=resize_ratio, fy=resize_ratio)
        cropped, _, _ = random_crop(resized, (224, 224, 3))
        print(cropped.size)

        # cv2 decodes to BGR; reverse the channels so the network sees RGB,
        # like the PIL-based inference path above.
        rgb = np.ascontiguousarray(cropped[:, :, ::-1])
        img = test_transform(Image.fromarray(rgb))

        inputs = Variable(img, requires_grad=True)
        if use_gpu:
            inputs = inputs.cuda()
        inputs = inputs.unsqueeze(0)  # add the batch dimension

        probs, idx = gcam.forward(inputs)

        # Grad-CAM
        gcam.backward(idx=WBC_id)
        if opts.depth == 18:
            output = gcam.generate(target_layer='layer4.1')
        else:
            output = gcam.generate(target_layer='layer4.2')  # a module name to be visualized (required)

        gcam.save('./results/%s.png' % str(i), output, cropped)
        cv2.imwrite('./results/map%s.png' % str(i), cropped)

        for j in range(3):
            print('\t{:.5f}\t{}\n'.format(probs[j], dset_classes[idx[j]]))
    """
Example #7
0
class Solver(object):
    """Train a FactorVAE, optionally with a Grad-CAM attention-disentanglement loss.

    Owns the VAE, the total-correlation discriminator ``D``, and their Adam
    optimizers. Checkpoints live under ``<ckpt_dir>/<name>_<seed>``, and the
    per-print-iteration metric history is kept in ``self.outputs``.
    """

    def __init__(self, args):
        self.args = args

        # Misc
        use_cuda = args.cuda and torch.cuda.is_available()
        self.device = 'cuda' if use_cuda else 'cpu'
        self.name = args.name
        self.max_iter = int(args.max_iter)
        self.print_iter = args.print_iter
        self.global_iter = 0
        self.pbar = tqdm(total=self.max_iter)

        # Data
        assert args.dataset == 'dsprites', 'Only dSprites is implemented'
        self.dset_dir = args.dset_dir
        self.batch_size = args.batch_size
        # self.dataset holds the dataset object returned by return_data (it is
        # what disentanglement_score consumes), not the dataset name string.
        self.data_loader, self.dataset = return_data(args)

        # Networks & Optimizers
        self.z_dim = args.z_dim
        self.gamma = args.gamma

        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE

        self.lr_D = args.lr_D
        self.beta1_D = args.beta1_D
        self.beta2_D = args.beta2_D

        # Disentanglement score
        self.L = args.L
        self.vote_count = args.vote_count
        self.dis_score = args.dis_score
        self.dis_batch_size = args.dis_batch_size

        # Models and optimizers
        self.VAE = FactorVAE(self.z_dim).to(self.device)
        self.nc = 1

        self.optim_VAE = optim.Adam(self.VAE.parameters(),
                                    lr=self.lr_VAE,
                                    betas=(self.beta1_VAE, self.beta2_VAE))

        self.D = Discriminator(self.z_dim).to(self.device)
        self.optim_D = optim.Adam(self.D.parameters(),
                                  lr=self.lr_D,
                                  betas=(self.beta1_D, self.beta2_D))

        self.nets = [self.VAE, self.D]

        # Attention Disentanglement loss
        self.ad_loss = args.ad_loss
        self.lamb = args.lamb
        if self.ad_loss:
            self.gcam = GradCAM(self.VAE.encode, args.target_layer,
                                self.device, args.image_size)
            self.pick2 = True

        # Checkpoint
        self.ckpt_dir = os.path.join(args.ckpt_dir,
                                     args.name + '_' + str(args.seed))
        self.ckpt_save_iter = args.ckpt_save_iter
        if self.max_iter >= args.ckpt_save_iter:
            mkdirs(self.ckpt_dir)
        if args.ckpt_load:
            self.load_checkpoint(args.ckpt_load)

        # Results
        self.results_dir = os.path.join(args.results_dir,
                                        args.name + '_' + str(args.seed))
        self.results_save = args.results_save

        self.outputs = {
            'vae_recon_loss': [],
            'vae_kld': [],
            'vae_tc_loss': [],
            'D_tc_loss': [],
            'ad_loss': [],
            'dis_score': [],
            'iteration': []
        }

    def train(self):
        """Alternate VAE and discriminator updates until ``max_iter`` is reached."""
        self.net_mode(train=True)

        # Discriminator targets: 0 for encoder samples z, 1 for
        # dimension-permuted samples. Sized by batch_size, which assumes the
        # loader only yields full batches (NOTE(review): confirm drop_last).
        ones = torch.ones(self.batch_size,
                          dtype=torch.long,
                          device=self.device)
        zeros = torch.zeros(self.batch_size,
                            dtype=torch.long,
                            device=self.device)

        out = False
        while not out:
            for x_true1, x_true2 in self.data_loader:
                self.global_iter += 1
                self.pbar.update(1)

                x_true1 = x_true1.to(self.device)
                x_recon, mu, logvar, z = self.VAE(x_true1)
                vae_recon_loss = recon_loss(x_true1, x_recon)
                vae_ad_loss = self.get_ad_loss(z)
                vae_kld = kl_divergence(mu, logvar)

                D_z = self.D(z)
                vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                vae_loss = vae_recon_loss + vae_kld + self.gamma * vae_tc_loss + self.lamb * vae_ad_loss

                x_true2 = x_true2.to(self.device)
                z_prime = self.VAE(x_true2, no_dec=True)
                z_pperm = permute_dims(z_prime).detach()
                D_z_pperm = self.D(z_pperm)
                # Score z again with its graph cut. Reusing D_z would send the
                # discriminator loss's gradients into the VAE parameters, and
                # optim_VAE.step() below would apply them.
                D_z_detached = self.D(z.detach())
                D_tc_loss = 0.5 * (F.cross_entropy(D_z_detached, zeros) +
                                   F.cross_entropy(D_z_pperm, ones))

                self.optim_VAE.zero_grad()
                vae_loss.backward()

                # Clears the D gradients that vae_loss produced via vae_tc_loss.
                self.optim_D.zero_grad()
                D_tc_loss.backward()

                # Both steps come after both backward passes so no parameter is
                # modified in place while a pending backward still needs it.
                self.optim_VAE.step()
                self.optim_D.step()

                if self.global_iter % self.print_iter == 0:
                    if self.dis_score:
                        dis_score = disentanglement_score(
                            self.VAE.eval(), self.device, self.dataset,
                            self.z_dim, self.L, self.vote_count,
                            self.dis_batch_size)
                        self.VAE.train()
                    else:
                        dis_score = torch.tensor(0)

                    self.pbar.write(
                        '[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} ad_loss:{:.3f} D_tc_loss:{:.3f} dis_score:{:.3f}'
                        .format(self.global_iter, vae_recon_loss.item(),
                                vae_kld.item(), vae_tc_loss.item(),
                                vae_ad_loss.item(), D_tc_loss.item(),
                                dis_score.item()))

                    if self.results_save:
                        self.outputs['vae_recon_loss'].append(
                            vae_recon_loss.item())
                        self.outputs['vae_kld'].append(vae_kld.item())
                        self.outputs['vae_tc_loss'].append(vae_tc_loss.item())
                        self.outputs['D_tc_loss'].append(D_tc_loss.item())
                        self.outputs['ad_loss'].append(vae_ad_loss.item())
                        self.outputs['dis_score'].append(dis_score.item())
                        self.outputs['iteration'].append(self.global_iter)

                if self.global_iter % self.ckpt_save_iter == 0:
                    self.save_checkpoint(self.global_iter)

                if self.global_iter >= self.max_iter:
                    out = True
                    break

        self.pbar.write("[Training Finished]")
        self.pbar.close()

        if self.results_save:
            save_args_outputs(self.results_dir, self.args, self.outputs)

    def get_ad_loss(self, z):
        """Return the attention-disentanglement loss for two random latent units.

        Returns ``torch.tensor(0)`` when the loss is disabled.
        """
        if not self.ad_loss:
            return torch.tensor(0)

        # NOTE(review): `size=` is NumPy's randint signature (stdlib
        # random.randint has no `size`), so `random` is presumably
        # numpy.random here; the two indices may coincide.
        z_picked = z[:, random.randint(0, self.z_dim, size=2)]
        M = self.gcam.generate(z_picked)

        return ad_loss(M.flatten(1), self.batch_size, self.pick2)

    def net_mode(self, train):
        """Put every network in train (True) or eval (False) mode."""
        if not isinstance(train, bool):
            raise ValueError('Only bool type is supported. True|False')

        for net in self.nets:
            if train:
                net.train()
            else:
                net.eval()

    def save_checkpoint(self, ckptname='last', verbose=True):
        """Write iteration, network and optimizer states to ``ckpt_dir/ckptname``."""
        model_states = {'D': self.D.state_dict(), 'VAE': self.VAE.state_dict()}
        optim_states = {
            'optim_D': self.optim_D.state_dict(),
            'optim_VAE': self.optim_VAE.state_dict()
        }
        states = {
            'iter': self.global_iter,
            'model_states': model_states,
            'optim_states': optim_states
        }

        filepath = os.path.join(self.ckpt_dir, str(ckptname))
        with open(filepath, 'wb+') as f:
            torch.save(states, f)
        if verbose:
            self.pbar.write("=> saved checkpoint '{}' (iter {})".format(
                filepath, self.global_iter))

    def load_checkpoint(self, ckptname='last', verbose=True):
        """Restore iteration counter, networks and optimizers from a checkpoint.

        ``ckptname='last'`` selects the highest numeric checkpoint in
        ``ckpt_dir``. If there is none, a file literally named ``last`` (what
        ``save_checkpoint()`` writes by default) is used when present. A
        missing checkpoint is reported through the progress bar, not raised.
        """
        if ckptname == 'last':
            names = (os.listdir(self.ckpt_dir)
                     if os.path.isdir(self.ckpt_dir) else [])
            # Skip non-numeric entries such as 'last', which int() rejects.
            iters = sorted((int(n) for n in names if n.isdigit()),
                           reverse=True)
            if iters:
                ckptname = str(iters[0])
            elif 'last' not in names:
                if verbose:
                    self.pbar.write("=> no checkpoint found")
                return

        filepath = os.path.join(self.ckpt_dir, str(ckptname))
        if os.path.isfile(filepath):
            with open(filepath, 'rb') as f:
                # Remap tensors so GPU checkpoints load on CPU-only machines.
                checkpoint = torch.load(f, map_location=self.device)

            self.global_iter = checkpoint['iter']
            self.VAE.load_state_dict(checkpoint['model_states']['VAE'])
            self.D.load_state_dict(checkpoint['model_states']['D'])
            self.optim_VAE.load_state_dict(
                checkpoint['optim_states']['optim_VAE'])
            self.optim_D.load_state_dict(checkpoint['optim_states']['optim_D'])
            self.pbar.update(self.global_iter)
            if verbose:
                self.pbar.write("=> loaded checkpoint '{}' (iter {})".format(
                    filepath, self.global_iter))
        else:
            if verbose:
                self.pbar.write(
                    "=> no checkpoint found at '{}'".format(filepath))