def main():
    """Run hand-segmentation inference over a video and show multi-class Grad-CAMs.

    Relies on module-level constants MODEL_PATH, VIDEO_PATH, GPU_DEVICE,
    FRAMES_PER_INFERENCE and helpers skip_frames / show_multi_gcam.
    """
    trained = FCN8s_hand()
    load_npz(MODEL_PATH, trained)

    if GPU_DEVICE >= 0:
        trained.to_gpu()

    # Validate the path BEFORE opening: cv2.VideoCapture on a missing file
    # silently yields an unopened capture (the original checked afterwards).
    if not os.path.exists(VIDEO_PATH):
        print("video file doesn't exist!")
        quit()
    cap = cv2.VideoCapture(VIDEO_PATH)

    skip_frames(cap, 3)
    while cap.isOpened():
        ret, img_np = cap.read()
        if not ret:
            # End of stream or decode failure: read() returns (False, None),
            # and the original would have passed None downstream.
            break
        show_multi_gcam(trained, img_np)
        skip_frames(cap, FRAMES_PER_INFERENCE)
# Ejemplo n.º 2
# 0
def main():
    """Train FCN8s_hand for segmentation with the GAIN multi-loss objective.

    The lambda1/2/3 flags weight the classification, attention-mining and
    segmentation losses respectively:
        total_loss = lambda1 * cl_loss + lambda2 * am_loss + lambda3 * segment_loss
    """
    def _str2bool(value):
        # argparse's type=bool is a trap: bool("False") is True because any
        # non-empty string is truthy. Parse the common spellings explicitly.
        return str(value).lower() in ('1', 'true', 'yes', 'y')

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--device', type=int, default=0, help='gpu id')
    parser.add_argument('--modelfile', help='pretrained model file of FCN8', required=True)
    parser.add_argument('--lr', type=float, default=5e-5, help='init learning rate')
    parser.add_argument('--name', type=str, default='FCN8_SEG', help='name of the experiment')
    parser.add_argument('--resume', type=_str2bool, default=False, help='resume training or not')
    parser.add_argument('--snapshot', type=str, help='snapshot file to resume from')
    parser.add_argument('--lambda1', default=1, type=float, help='lambda1 param')
    parser.add_argument('--lambda2', default=1, type=float, help='lambda2 param')
    parser.add_argument('--lambda3', default=1.5, type=float, help='lambda3 param')

    args = parser.parse_args()

    resume = args.resume
    device = args.device

    if resume:
        if args.snapshot is None:
            # Fail loudly with a CLI error rather than crashing later on
            # load_npz(None, ...).
            parser.error('--snapshot is required when --resume is set')
        load_snapshot_path = args.snapshot
        load_model_path = args.modelfile
        print("Resuming from model {}, snapshot {}".format(load_model_path, load_snapshot_path))
    else:
        pretrained_model_path = args.modelfile

    experiment = args.name
    lr = args.lr
    optim = Adam
    training_interval = (20000, 'iteration')
    snapshot_interval = (1000, 'iteration')
    lambd1 = args.lambda1
    lambd2 = args.lambda2
    lambd3 = args.lambda3
    updtr = VOC_SEG_Updater_v2

    # Persist the hyper-parameters next to the experiment's results.
    os.makedirs('result/' + experiment, exist_ok=True)
    with open('result/' + experiment + '/details.txt', "w+") as f:
        f.write("lr - " + str(lr) + "\n")
        f.write("optimizer - " + str(optim) + "\n")
        f.write("lambd1 - " + str(lambd1) + "\n")
        f.write("lambd2 - " + str(lambd2) + "\n")
        f.write("lambd3 - " + str(lambd3) + "\n")
        f.write("training_interval - " + str(training_interval) + "\n")
        f.write("Updater - " + str(updtr) + "\n")

    # Both branches start from a serialized FCN8s_hand; on resume the trainer
    # snapshot loaded below restores optimizer/iterator state on top of it.
    model = FCN8s_hand()
    chainer.serializers.load_npz(
        load_model_path if resume else pretrained_model_path, model)

    if device >= 0:
        model.to_gpu(device)
    dataset = MyTrainingDataset()
    iterator = SerialIterator(dataset, 1, shuffle=False)

    optimizer = Adam(alpha=lr)
    optimizer.setup(model)

    # NOTE(review): lambd3 is parsed and logged but never handed to the
    # updater -- confirm whether VOC_SEG_Updater_v2 expects a lambd3 kwarg.
    updater = updtr(iterator, optimizer, device=device, lambd1=lambd1, lambd2=lambd2)
    trainer = Trainer(updater, training_interval)
    log_keys = ['epoch', 'iteration', 'main/SG_Loss', 'main/TotalLoss']
    trainer.extend(extensions.LogReport(log_keys, (10, 'iteration'),
                                        log_name='log' + experiment))
    trainer.extend(extensions.PrintReport(log_keys), trigger=(100, 'iteration'))
    trainer.extend(extensions.ProgressBar(training_length=training_interval,
                                          update_interval=100))

    # Trainer snapshots (full resume state) and bare model snapshots.
    trainer.extend(extensions.snapshot(
        filename=experiment + '_snapshot_{.updater.iteration}'),
        trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        trainer.updater._optimizers['main'].target,
        experiment + '_model_{.updater.iteration}'), trigger=snapshot_interval)

    trainer.extend(extensions.PlotReport(
        ['main/SG_Loss'], 'iteration', (20, 'iteration'),
        file_name=experiment + '/sg_loss.png', grid=True, marker=" "))
    trainer.extend(extensions.PlotReport(
        ['main/TotalLoss'], 'iteration', (20, 'iteration'),
        file_name=experiment + '/total_loss.png', grid=True, marker=" "))
    trainer.extend(extensions.PlotReport(
        log_keys[2:], 'iteration', (20, 'iteration'),
        file_name=experiment + '/all_loss.png', grid=True, marker=" "))

    if resume:
        chainer.serializers.load_npz(load_snapshot_path, trainer)
    print("Running - - ", experiment)
    print('initial lr ', lr)
    print('optimizer ', optim)
    print('lambd1 ', lambd1)
    print('lambd2 ', lambd2)
    print('lambd3', lambd3)
    trainer.run()
import os
import fcn
import chainer

from models.fcn8_hand_v2 import FCN8s_hand

if __name__ == '__main__':
    # Transplant the convolutional backbone of the reference FCN8s into the
    # FCN8s_hand variant, leaving the task-specific heads untouched.
    model_own = FCN8s_hand()
    model_original = fcn.models.FCN8s()
    model_file = fcn.models.FCN8s.download()
    chainer.serializers.load_npz(model_file, model_original)

    # Scoring/upsampling layers whose shapes differ from the pretrained net.
    ignored_layers = {
        'score_fr', 'upscore2', 'upscore8', 'score_pool3', 'score_pool4',
        'upscore_pool4'
    }
    print("Copying layers from pretrained fcn8s to fcn8s_hand")
    for layer_name in model_original._children:
        if layer_name in ignored_layers:
            print("Ignored copying attributes from layer {}".format(layer_name))
            continue
        print('Copying {}'.format(layer_name))
        source_layer = getattr(model_original, layer_name)
        target_layer = getattr(model_own, layer_name)
        # Cheap structural compatibility check: a chainer link's repr
        # encodes its configuration, so equal reprs means matching shapes.
        assert str(source_layer) == str(target_layer), \
            'Layer shape mismatch for layer {}!\noriginal: {}\nNew:{}' \
            .format(layer_name, source_layer, target_layer)
        setattr(model_own, layer_name, source_layer)
    print('\nSaving...')
    chainer.serializers.save_npz('fcn8s_hand_gain_pretrained.npz', model_own)
def main():
    """Train the classifier stream of FCN8s_hand (with dropout).

    Starts from a pretrained FCN8s_hand checkpoint (--modelfile) and can
    resume full trainer state from a snapshot (--resume + --snapshot).
    """
    def _str2bool(value):
        # argparse's type=bool treats any non-empty string (even "False")
        # as True; parse the usual spellings explicitly.
        return str(value).lower() in ('1', 'true', 'yes', 'y')

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--device', type=int, default=-1, help='gpu id')
    parser.add_argument('--lr_init',
                        type=float,
                        default=1 * 1e-7,
                        help='init learning rate')
    parser.add_argument('--name',
                        type=str,
                        default='classifier_gain_dropout',
                        help='name of the experiment')
    parser.add_argument(
        '--modelfile',
        type=str,
        help='name of the model to resume from or if starting anew, the '
        'pretrained FCN8s_Hand model with empty final layers',
        required=True)
    parser.add_argument('--resume',
                        type=_str2bool,
                        default=False,
                        help='resume training or not')
    parser.add_argument('--snapshot',
                        type=str,
                        default=None,
                        help='snapshot file of the trainer to resume from')

    args = parser.parse_args()

    if args.resume and args.snapshot is None:
        # assert is stripped under python -O; raise a proper CLI error.
        parser.error('--snapshot is required when --resume is set')

    resume = args.resume
    device = args.device
    if resume:
        load_snapshot_path = args.snapshot

    experiment = args.name
    lr_init = args.lr_init

    # Persist hyper-parameters next to the experiment's results.
    os.makedirs('result/' + experiment, exist_ok=True)
    with open('result/' + experiment + '/details.txt', "w+") as f:
        f.write("lr - " + str(lr_init) + "\n")
        f.write("optimizer - " + str(Adam))

    model_own = FCN8s_hand()
    chainer.serializers.load_npz(args.modelfile, model_own)

    if device >= 0:
        print('sending model to gpu ' + str(device))
        model_own.to_gpu(device)

    dataset = MyTrainingDataset()
    iterator = SerialIterator(dataset, 1)
    optimizer = Adam(alpha=lr_init)
    optimizer.setup(model_own)

    updater = VOC_ClassificationUpdater_v2(iterator,
                                           optimizer,
                                           device=device,
                                           dropout=0.5)
    trainer = Trainer(updater, (100, 'epoch'))
    log_keys = ['epoch', 'iteration', 'main/Loss']
    trainer.extend(
        extensions.LogReport(log_keys, (100, 'iteration'),
                             log_name='log_' + experiment))
    trainer.extend(extensions.PrintReport(log_keys),
                   trigger=(100, 'iteration'))
    # Trainer snapshots (full resume state) and bare model snapshots.
    trainer.extend(extensions.snapshot(filename=experiment +
                                       "_snapshot_{.updater.iteration}"),
                   trigger=(1, 'epoch'))
    trainer.extend(extensions.snapshot_object(
        trainer.updater._optimizers['main'].target,
        experiment + "_model_{.updater.iteration}"),
                   trigger=(1, 'epoch'))
    trainer.extend(
        extensions.PlotReport(['main/Loss'],
                              'iteration', (100, 'iteration'),
                              file_name=experiment + '/loss.png',
                              grid=True,
                              marker=" "))

    if resume:
        chainer.serializers.load_npz(load_snapshot_path, trainer)

    print("Running - - ", experiment)
    print('initial lr ', lr_init)
    trainer.run()
# Ejemplo n.º 5
# 0
def main():
    """Visualize predicted masks from the pretrained vs the GAIN-trained model.

    For each image saves a 3-panel figure (input, pretrained prediction,
    GAIN-trained prediction) under viz/<name>/.
    """
    def _str2bool(value):
        # argparse's type=bool treats any non-empty string as True;
        # parse the usual spellings explicitly.
        return str(value).lower() in ('1', 'true', 'yes', 'y')

    # The original built the parser twice, discarding the first instance.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--pretrained', type=str, help='path to model that has trained classifier but has not been trained through GAIN routine')
    parser.add_argument('--trained', type=str, help='path to model trained through GAIN')
    parser.add_argument('--device', type=int, default=-1, help='gpu id')
    parser.add_argument('--shuffle', type=_str2bool, default=False, help='whether to shuffle dataset')
    parser.add_argument('--whole', type=_str2bool, default=False, help='whether to test for the whole validation dataset')
    parser.add_argument('--no', type=int, default=20, help='if not whole, then no of images to visualize')
    parser.add_argument('--name', type=str, default='viz1', help='name of the subfolder or experiment under which to save')

    args = parser.parse_args()

    pretrained_file = args.pretrained
    trained_file = args.trained
    device = args.device
    shuffle = args.shuffle
    whole = args.whole
    name = args.name
    N = args.no

    dataset = MyTrainingDataset(split='train')
    iterator = SerialIterator(dataset, 1, shuffle=shuffle, repeat=False)
    converter = chainer.dataset.concat_examples
    os.makedirs('viz/' + name, exist_ok=True)

    pretrained = FCN8s_hand()
    trained = FCN8s_hand()
    load_npz(pretrained_file, pretrained)
    load_npz(trained_file, trained)

    if device >= 0:
        pretrained.to_gpu()
        trained.to_gpu()

    i = 0
    while not iterator.is_new_epoch:

        if not whole and i >= N:
            break

        image, labels, metadata = converter(iterator.next())
        image = Variable(image)
        if device >= 0:
            image.to_gpu()

        # NOTE: the original also referenced an undefined `class_id` when
        # device > 0 (a guaranteed NameError) and computed an unused label
        # vector; both removed.
        lbl1 = pretrained.predict(image)
        lbl1 = cp.asnumpy(lbl1[0].data)
        print(np.unique(lbl1))

        fig1 = plt.figure(figsize=(20, 10))
        ax1 = plt.subplot2grid((3, 9), (0, 0), colspan=3, rowspan=3)
        ax1.axis('off')
        ax1.imshow(cp.asnumpy(F.transpose(F.squeeze(image, 0), (1, 2, 0)).data) / 255.)

        ax2 = plt.subplot2grid((3, 9), (0, 3), colspan=3, rowspan=3)
        ax2.axis('off')
        ax2.imshow(cp.asnumpy(F.transpose(F.squeeze(image, 0), (1, 2, 0)).data) / 255.)
        ax2.imshow(lbl1, cmap='jet')
        del lbl1

        lbl2 = trained.predict(image)
        lbl2 = cp.asnumpy(lbl2[0].data)
        print(np.unique(lbl2))
        ax3 = plt.subplot2grid((3, 9), (0, 6), colspan=3, rowspan=3)
        ax3.axis('off')
        ax3.imshow(cp.asnumpy(F.transpose(F.squeeze(image, 0), (1, 2, 0)).data) / 255.)
        ax3.imshow(lbl2, cmap='jet')
        del lbl2

        fig1.savefig('viz/' + name + '/' + str(i) + '.png')
        plt.close()
        print(i)
        i += 1
# Ejemplo n.º 6
# 0
def main():
    """Evaluate the GAIN-trained classifier: per-class TP/TN/FP/FN on the val split.

    A class counts as predicted present when its raw score is >= 0
    (equivalent to a sigmoid probability >= 0.5). Counts are printed at the end.
    """
    def _str2bool(value):
        # argparse's type=bool treats any non-empty string as True;
        # parse the usual spellings explicitly.
        return str(value).lower() in ('1', 'true', 'yes', 'y')

    # The original built the parser twice, discarding the first instance.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--pretrained',
        type=str,
        help=
        'path to model that has trained classifier but has not been trained through GAIN routine',
        default='classifier_padding_1_model_594832')
    parser.add_argument(
        '--trained',
        type=str,
        help='path to model trained through GAIN',
        default='result/MYGAIN_5_to_1_padding_1_all_update_model_20000')
    parser.add_argument('--device', type=int, default=0, help='gpu id')
    parser.add_argument('--shuffle',
                        type=_str2bool,
                        default=False,
                        help='whether to shuffle dataset')
    parser.add_argument(
        '--whole',
        type=_str2bool,
        default=False,
        help='whether to test for the whole validation dataset')
    parser.add_argument('--no',
                        type=int,
                        default=5,
                        help='if not whole, then no of images to visualize')
    parser.add_argument(
        '--name',
        type=str,
        default='viz1',
        help='name of the subfolder or experiment under which to save')

    args = parser.parse_args()

    pretrained_file = args.pretrained
    trained_file = args.trained
    # Honor --device; the original silently overwrote it with 0.
    device = args.device
    shuffle = args.shuffle
    whole = args.whole
    name = args.name
    N = args.no

    dataset = MyTrainingDataset(split='val')
    iterator = SerialIterator(dataset, 1, shuffle=shuffle, repeat=False)
    converter = chainer.dataset.concat_examples
    os.makedirs('viz/' + name, exist_ok=True)
    # 20 foreground classes + the hand class -- TODO confirm label layout.
    no_of_classes = 21
    pretrained = FCN8s_hand()
    trained = FCN8s_hand()
    load_npz(pretrained_file, pretrained)
    load_npz(trained_file, trained)

    if device >= 0:
        pretrained.to_gpu()
        trained.to_gpu()
    i = 0

    # One counter slot per class.
    true_positive = [0 for j in range(no_of_classes)]
    true_negative = [0 for j in range(no_of_classes)]
    false_positive = [0 for j in range(no_of_classes)]
    false_negative = [0 for j in range(no_of_classes)]

    while not iterator.is_new_epoch:

        if not whole and i >= N:
            break

        image, labels, metadata = converter(iterator.next())
        image = Variable(image)
        if device >= 0:
            image.to_gpu()

        # Build the multi-hot ground-truth vector from the label image:
        # drop -1 (unlabelled) and 0 (background), shift to 0-based ids.
        xp = get_array_module(image.data)
        to_substract = np.array((-1, 0))
        noise_classes = np.unique(labels[0]).astype(np.int32)
        target = xp.asarray([[0] * (no_of_classes)])
        gt_labels = np.setdiff1d(noise_classes, to_substract) - 1
        target[0][gt_labels] = 1

        # Only the trained model's scores are evaluated here; the original
        # also ran pretrained.stream_cl and referenced an undefined
        # `class_id` when device > 0 (a NameError) -- both removed.
        gcam2, cl_scores2, class_id2 = trained.stream_cl(image)

        target = cp.asnumpy(target)
        cl_scores2 = cp.asnumpy(cl_scores2.data)
        for j in range(len(target[0])):
            if target[0][j] == 1:
                if cl_scores2[0][j] >= 0:
                    true_positive[j] += 1
                else:
                    false_negative[j] += 1
            else:
                if cl_scores2[0][j] <= 0:
                    true_negative[j] += 1
                else:
                    false_positive[j] += 1
        print(i)
        i += 1
    print("true postive {}".format(true_positive))
    print("true negative {}".format(true_negative))
    print("false positive {}".format(false_positive))
    print("false negative {}".format(false_negative))
# Ejemplo n.º 7
# 0
def main():
    """Visualize pointed-object bounding boxes derived from GAIN Grad-CAMs.

    Runs the GAIN-trained model's multi-class CAM stream on validation
    images, converts the CAMs to bounding boxes and displays them with
    OpenCV. Class id 20 is treated as the hand class.
    """
    def _str2bool(value):
        # argparse's type=bool treats any non-empty string as True;
        # parse the usual spellings explicitly.
        return str(value).lower() in ('1', 'true', 'yes', 'y')

    # The original built the parser twice, discarding the first instance.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--pretrained', type=str,
                        help='path to model that has trained classifier but has not been trained through GAIN routine',
                        default='classifier_padding_1_model_594832')
    parser.add_argument('--trained', type=str, help='path to model trained through GAIN',
                        default='result/MYGAIN_5_to_1_padding_1_all_update_model_20000')
    parser.add_argument('--device', type=int, default=0, help='gpu id')
    parser.add_argument('--shuffle', type=_str2bool, default=False, help='whether to shuffle dataset')
    parser.add_argument('--whole', type=_str2bool, default=False, help='whether to test for the whole validation dataset')
    parser.add_argument('--no', type=int, default=50, help='if not whole, then no of images to visualize')
    parser.add_argument('--name', type=str, default='viz1', help='name of the subfolder or experiment under which to save')

    args = parser.parse_args()

    trained_file = args.trained
    # Honor --device; the original silently overwrote it with 0.
    device = args.device
    shuffle = args.shuffle
    whole = args.whole
    name = args.name
    N = args.no

    dataset = MyTrainingDataset(split='val')
    iterator = SerialIterator(dataset, 1, shuffle=shuffle, repeat=False)
    converter = chainer.dataset.concat_examples
    os.makedirs('viz/' + name, exist_ok=True)

    # Only the GAIN-trained model is used; the original also instantiated an
    # FCN8s_hand named `pretrained` that was never loaded or queried.
    trained = FCN8s_hand()
    load_npz(trained_file, trained)

    if device >= 0:
        trained.to_gpu()

    i = 0
    while not iterator.is_new_epoch:

        if not whole and i >= N:
            break

        image, labels, metadata = converter(iterator.next())
        # Keep an HWC uint8 copy of the input for OpenCV drawing/display.
        np_input_img = np.uint8(image[0])
        np_input_img = np.transpose(np_input_img, (1, 2, 0))
        image = Variable(image)
        if device >= 0:
            image.to_gpu()

        # Multi-label CAM stream: one Grad-CAM per detected class.
        # (The original also referenced an undefined `class_id` when
        # device > 0 -- a NameError -- and computed an unused label
        # vector; both removed.)
        gcams2, cl_scores2, class_ids2 = trained.stream_cl_multi(image)

        print(np_input_img.shape)
        bboxes_per_class, pointed_bbox = gcams_to_bboxes(
            gcams2, class_ids2, input_image=np_input_img)

        display_img = cv2.cvtColor(np_input_img.copy(), cv2.COLOR_RGB2BGR)
        # If there's a hand (last class id == 20) plus at least one other
        # object, draw the pointed object's box (white) then the hand
        # boxes (green).
        if len(class_ids2) >= 2 and class_ids2[-1] == 20:
            cv2.rectangle(display_img,
                          (int(pointed_bbox[0]), int(pointed_bbox[1])),
                          (int(pointed_bbox[2]), int(pointed_bbox[3])),
                          [255, 255, 255], 2)
            for bbox in bboxes_per_class[-1]:
                cv2.rectangle(display_img,
                              (int(bbox[0]), int(bbox[1])),
                              (int(bbox[2]), int(bbox[3])),
                              [0, 255, 0], 2)
        cv2.imshow('input img', display_img)
        cv2.waitKey(0)

        print(i)
        i += 1