Code example #1
def visualize_test_images(ckpt_list):
    #===========================================================================
    for ckpt_name in ckpt_list:
        try:
            # Step0 ============================================================
            # Parsing the hyper-parameters
            FILE_NAME_FORMAT = ckpt_name.split('.')[0]
            parsing_list = ckpt_name.split('.')[0].split('_')

            # Setting constants
            model_name            = parsing_list[0]
            dataset_name          = parsing_list[1]
            loss_type             = parsing_list[2]
            flag                  = parsing_list[-1]

            if 'attention' in flag:
                attention = True
            else:
                attention = False
            # Step1 ============================================================
            # Load dataset
            test_dataloader = CycleGAN_Dataloader(name=dataset_name,
                                                  train=False,
                                                  num_workers=8)
            print('==> DataLoader ready.')

            # Step2 ============================================================
            # Make the model
            if dataset_name == 'cityscapes':
                A_generator       = Generator(num_resblock=6)
                B_generator       = Generator(num_resblock=6)
                A_discriminator   = Discriminator()
                B_discriminator   = Discriminator()
            else:
                A_generator       = Generator(num_resblock=9)
                B_generator       = Generator(num_resblock=9)
                A_discriminator   = Discriminator()
                B_discriminator   = Discriminator()

            # Check DataParallel available
            if torch.cuda.device_count() > 1:
                A_generator = nn.DataParallel(A_generator)
                B_generator = nn.DataParallel(B_generator)
                A_discriminator = nn.DataParallel(A_discriminator)
                B_discriminator = nn.DataParallel(B_discriminator)

            # Check CUDA available
            if torch.cuda.is_available():
                A_generator.cuda()
                B_generator.cuda()
                A_discriminator.cuda()
                B_discriminator.cuda()
            print('==> Model ready.')

            # Step3 ============================================================
            # Test the model
            checkpoint = torch.load(os.path.join(CHECKPOINT_PATH, ckpt_name))
            A_generator.load_state_dict(checkpoint['A_generator_state_dict'])
            B_generator.load_state_dict(checkpoint['B_generator_state_dict'])
            A_discriminator.load_state_dict(checkpoint['A_discriminator_state_dict'])
            B_discriminator.load_state_dict(checkpoint['B_discriminator_state_dict'])
            train_epoch = checkpoint['epoch']

            val(test_dataloader, A_generator, B_generator,
                A_discriminator, B_discriminator, train_epoch,
                FILE_NAME_FORMAT, attention)

            #-------------------------------------------------------------------
            # Print the result on the console
            print("model   : {}".format(model_name))
            print("dataset : {}".format(dataset_name))
            print("loss    : {}".format(loss_type))
            print('-'*50)
        except Exception as e:
            print(e)
    print('==> Visualize test images done.')
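
The parsing at the top of this function assumes checkpoint file names of the form <model>_<dataset>_<loss>_..._<flag>.ckpt, matching FILE_NAME_FORMAT from code example #2. A minimal usage sketch follows; the checkpoint names are hypothetical placeholders and must correspond to real files under CHECKPOINT_PATH.

# Hypothetical driver sketch; the checkpoint names below are placeholders,
# not files shipped with the original code.
if __name__ == '__main__':
    ckpt_list = [
        'cyclegan_horse2zebra_lsgan_200_attention.ckpt',  # assumed example name
        'cyclegan_cityscapes_lsgan_200_base.ckpt',        # assumed example name
    ]
    visualize_test_images(ckpt_list)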
Code example #2
def main(args):
    # Step0 ====================================================================
    # Set GPU ids
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids

    # Set the file name format
    FILE_NAME_FORMAT = "{0}_{1}_{2}_{3:d}{4}".format(args.model, args.dataset,
                                                     args.loss, args.epochs,
                                                     args.flag)
    # Set the results file path
    RESULT_FILE_NAME = FILE_NAME_FORMAT + '_results.pkl'
    RESULT_FILE_PATH = os.path.join(RESULT_PATH, RESULT_FILE_NAME)
    # Set the checkpoint file path
    CHECKPOINT_FILE_NAME = FILE_NAME_FORMAT + '.ckpt'
    CHECKPOINT_FILE_PATH = os.path.join(CHECKPOINT_PATH, CHECKPOINT_FILE_NAME)
    BEST_CHECKPOINT_FILE_NAME = FILE_NAME_FORMAT + '_best.ckpt'
    BEST_CHECKPOINT_FILE_PATH = os.path.join(CHECKPOINT_PATH,
                                             BEST_CHECKPOINT_FILE_NAME)
    # Set the random seed same for reproducibility
    random.seed(190811)
    torch.manual_seed(190811)
    torch.cuda.manual_seed_all(190811)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Step1 ====================================================================
    # Load dataset
    train_dataloader = CycleGAN_Dataloader(name=args.dataset,
                                           num_workers=args.num_workers)
    test_dataloader = CycleGAN_Dataloader(name=args.dataset,
                                          train=False,
                                          num_workers=args.num_workers)
    print('==> DataLoader ready.')

    # Step2 ====================================================================
    # Make the model
    if args.dataset == 'cityscapes':
        A_generator = Generator(num_resblock=6)
        B_generator = Generator(num_resblock=6)
        A_discriminator = Discriminator()
        B_discriminator = Discriminator()
    else:
        A_generator = Generator(num_resblock=9)
        B_generator = Generator(num_resblock=9)
        A_discriminator = Discriminator()
        B_discriminator = Discriminator()

    # Check DataParallel available
    if torch.cuda.device_count() > 1:
        A_generator = nn.DataParallel(A_generator)
        B_generator = nn.DataParallel(B_generator)
        A_discriminator = nn.DataParallel(A_discriminator)
        B_discriminator = nn.DataParallel(B_discriminator)

    # Check CUDA available
    if torch.cuda.is_available():
        A_generator.cuda()
        B_generator.cuda()
        A_discriminator.cuda()
        B_discriminator.cuda()
    print('==> Model ready.')

    # Step3 ====================================================================
    # Set each loss function
    criterion_GAN = nn.MSELoss()
    criterion_cycle = nn.L1Loss()
    criterion_identity = nn.L1Loss()
    criterion_feature = nn.L1Loss()

    # Set each optimizer
    optimizer_G = optim.Adam(itertools.chain(A_generator.parameters(),
                                             B_generator.parameters()),
                             lr=args.lr,
                             betas=(0.5, 0.999))
    optimizer_D = optim.Adam(itertools.chain(A_discriminator.parameters(),
                                             B_discriminator.parameters()),
                             lr=args.lr,
                             betas=(0.5, 0.999))

    # Set learning rate scheduler
    # Linear decay: keep the initial learning rate for the first half of
    # training, then scale it linearly down towards zero over the second half.
    def lambda_rule(epoch):
        epoch_decay = args.epochs / 2
        lr_linear_scale = 1.0 - max(0, epoch + 1 - epoch_decay) \
                                / float(epoch_decay + 1)
        return lr_linear_scale

    scheduler_G = lr_scheduler.LambdaLR(optimizer_G, lr_lambda=lambda_rule)
    scheduler_D = lr_scheduler.LambdaLR(optimizer_D, lr_lambda=lambda_rule)
    print('==> Criterion and optimizer ready.')

    # Step4 ====================================================================
    # Train and validate the model
    start_epoch = 0
    best_metric = float("inf")

    # Initialize the result lists
    train_loss_G = []
    train_loss_D_A = []
    train_loss_D_B = []

    # Set image buffer
    A_buffer = ImageBuffer(args.buffer_size)
    B_buffer = ImageBuffer(args.buffer_size)

    if args.resume:
        assert os.path.exists(CHECKPOINT_FILE_PATH), 'No checkpoint file!'
        checkpoint = torch.load(CHECKPOINT_FILE_PATH)
        A_generator.load_state_dict(checkpoint['A_generator_state_dict'])
        B_generator.load_state_dict(checkpoint['B_generator_state_dict'])
        A_discriminator.load_state_dict(
            checkpoint['A_discriminator_state_dict'])
        B_discriminator.load_state_dict(
            checkpoint['B_discriminator_state_dict'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
        scheduler_G.load_state_dict(checkpoint['scheduler_G_state_dict'])
        scheduler_D.load_state_dict(checkpoint['scheduler_D_state_dict'])
        start_epoch = checkpoint['epoch']
        train_loss_G = checkpoint['train_loss_G']
        train_loss_D_A = checkpoint['train_loss_D_A']
        train_loss_D_B = checkpoint['train_loss_D_B']
        best_metric = checkpoint['best_metric']

    # Save the training information
    result_data = {}
    result_data['model'] = args.model
    result_data['dataset'] = args.dataset
    result_data['loss'] = args.loss
    result_data['target_epoch'] = args.epochs
    result_data['batch_size'] = args.batch_size

    # Check the directory of the file path
    if not os.path.exists(os.path.dirname(RESULT_FILE_PATH)):
        os.makedirs(os.path.dirname(RESULT_FILE_PATH))
    if not os.path.exists(os.path.dirname(CHECKPOINT_FILE_PATH)):
        os.makedirs(os.path.dirname(CHECKPOINT_FILE_PATH))
    print('==> Train ready.')

    for epoch in range(args.epochs):
        # start after the checkpoint epoch
        if epoch < start_epoch:
            continue

        print("\n[Epoch: {:3d}/{:3d}]".format(epoch + 1, args.epochs))
        epoch_time = time.time()
        #=======================================================================
        # train and validate the model
        tloss_G, tloss_D = train(
            train_dataloader, A_generator, B_generator, A_discriminator,
            B_discriminator, criterion_GAN, criterion_cycle,
            criterion_identity, optimizer_G, optimizer_D, A_buffer, B_buffer,
            args.loss, args.lambda_cycle, args.lambda_identity,
            criterion_feature, args.lambda_feature, args.attention)
        train_loss_G.append(tloss_G)
        train_loss_D_A.append(tloss_D['A'])
        train_loss_D_B.append(tloss_D['B'])

        if (epoch + 1) % 10 == 0:
            val(test_dataloader, A_generator, B_generator, A_discriminator,
                B_discriminator, epoch + 1, FILE_NAME_FORMAT, args.attention)

        # Update the optimizer's learning rate
        current_lr = optimizer_G.param_groups[0]['lr']
        scheduler_G.step()
        scheduler_D.step()
        #=======================================================================
        current = time.time()

        # Save the current result
        result_data['current_epoch'] = epoch
        result_data['train_loss_G'] = train_loss_G
        result_data['train_loss_D_A'] = train_loss_D_A
        result_data['train_loss_D_B'] = train_loss_D_B

        # Save result_data as pkl file
        with open(RESULT_FILE_PATH, 'wb') as pkl_file:
            pickle.dump(result_data,
                        pkl_file,
                        protocol=pickle.HIGHEST_PROTOCOL)

        # Save the best checkpoint
        # if train_loss_G < best_metric:
        #     best_metric = train_loss_G
        #     torch.save({
        #         'epoch': epoch+1,
        #         'A_generator_state_dict': A_generator.state_dict(),
        #         'B_generator_state_dict': B_generator.state_dict(),
        #         'A_discriminator_state_dict': A_discriminator.state_dict(),
        #         'B_discriminator_state_dict': B_discriminator.state_dict(),
        #         'optimizer_G_state_dict': optimizer_G.state_dict(),
        #         'optimizer_D_state_dict': optimizer_D.state_dict(),
        #         'scheduler_G_state_dict': scheduler_G.state_dict(),
        #         'scheduler_D_state_dict': scheduler_D.state_dict(),
        #         'train_loss_G': train_loss_G,
        #         'train_loss_D_A': train_loss_D_A,
        #         'train_loss_D_B': train_loss_D_B,
        #         'best_metric': best_metric,
        #         }, BEST_CHECKPOINT_FILE_PATH)

        # Save the current checkpoint
        torch.save(
            {
                'epoch': epoch + 1,
                'A_generator_state_dict': A_generator.state_dict(),
                'B_generator_state_dict': B_generator.state_dict(),
                'A_discriminator_state_dict': A_discriminator.state_dict(),
                'B_discriminator_state_dict': B_discriminator.state_dict(),
                'optimizer_G_state_dict': optimizer_G.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
                'scheduler_G_state_dict': scheduler_G.state_dict(),
                'scheduler_D_state_dict': scheduler_D.state_dict(),
                'train_loss_G': train_loss_G,
                'train_loss_D_A': train_loss_D_A,
                'train_loss_D_B': train_loss_D_B,
                'best_metric': best_metric,
            }, CHECKPOINT_FILE_PATH)

        if (epoch + 1) % 10 == 0:
            CHECKPOINT_FILE_NAME_epoch = FILE_NAME_FORMAT + '_{0}.ckpt'.format(epoch + 1)
            CHECKPOINT_FILE_PATH_epoch = os.path.join(
                CHECKPOINT_PATH, FILE_NAME_FORMAT, CHECKPOINT_FILE_NAME_epoch)
            if not os.path.exists(os.path.dirname(CHECKPOINT_FILE_PATH_epoch)):
                os.makedirs(os.path.dirname(CHECKPOINT_FILE_PATH_epoch))
            torch.save(
                {
                    'epoch': epoch + 1,
                    'A_generator_state_dict': A_generator.state_dict(),
                    'B_generator_state_dict': B_generator.state_dict(),
                    'A_discriminator_state_dict': A_discriminator.state_dict(),
                    'B_discriminator_state_dict': B_discriminator.state_dict(),
                    'optimizer_G_state_dict': optimizer_G.state_dict(),
                    'optimizer_D_state_dict': optimizer_D.state_dict(),
                    'scheduler_G_state_dict': scheduler_G.state_dict(),
                    'scheduler_D_state_dict': scheduler_D.state_dict(),
                    'train_loss_G': train_loss_G,
                    'train_loss_D_A': train_loss_D_A,
                    'train_loss_D_B': train_loss_D_B,
                    'best_metric': best_metric,
                }, CHECKPOINT_FILE_PATH_epoch)

        # Print the information on the console
        print("model                : {}".format(args.model))
        print("dataset              : {}".format(args.dataset))
        print("loss                 : {}".format(args.loss))
        print("batch_size           : {}".format(args.batch_size))
        print("current lrate        : {:f}".format(current_lr))
        print("G loss               : {:f}".format(tloss_G))
        print("D A/B loss           : {:f}/{:f}".format(
            tloss_D['A'], tloss_D['B']))
        print("epoch time           : {0:.3f} sec".format(current -
                                                          epoch_time))
        print("Current elapsed time : {0:.3f} sec".format(current - start))
    print('==> Train done.')

    print(' '.join(['Results have been saved at', RESULT_FILE_PATH]))
    print(' '.join(['Checkpoints have been saved at', CHECKPOINT_FILE_PATH]))
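
main(args) reads all of its settings from an argparse namespace (model, dataset, loss, epochs, flag, gpu_ids, lr, num_workers, batch_size, buffer_size, lambda_cycle, lambda_identity, lambda_feature, attention, resume) and from the module-level constants RESULT_PATH and CHECKPOINT_PATH plus the timer start. A minimal sketch of a matching entry point is shown below; the default values are assumptions for illustration, not the author's original configuration.

# Sketch of an entry point consistent with the arguments used in main().
# Default values are assumptions, not the original configuration; the other
# imports (torch, etc.) come from the original training script.
import argparse
import time

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='CycleGAN training')
    parser.add_argument('--model', default='cyclegan')
    parser.add_argument('--dataset', default='horse2zebra')
    parser.add_argument('--loss', default='lsgan')
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--flag', default='_base')  # appended directly after the epoch count
    parser.add_argument('--gpu_ids', default='0')
    parser.add_argument('--lr', type=float, default=2e-4)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--buffer_size', type=int, default=50)
    parser.add_argument('--lambda_cycle', type=float, default=10.0)
    parser.add_argument('--lambda_identity', type=float, default=5.0)
    parser.add_argument('--lambda_feature', type=float, default=1.0)
    parser.add_argument('--attention', action='store_true')
    parser.add_argument('--resume', action='store_true')
    args = parser.parse_args()

    start = time.time()  # referenced by main() when printing elapsed time
    main(args)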
Code example #3
class Model:
    def __init__(self, base_path='', epochs=10, learning_rate=0.0002, image_size=256, leaky_relu=0.2,
                 betas=(0.5, 0.999), lamda=100, image_format='png'):
        self.image_size = image_size
        self.leaky_relu_threshold = leaky_relu

        self.epochs = epochs
        self.lr = learning_rate
        self.betas = betas
        self.lamda = lamda
        self.base_path = base_path
        self.image_format = image_format
        self.count = 1

        self.gen = None
        self.dis = None
        self.gen_optim = None
        self.dis_optim = None
        self.model_type = None
        self.residual_blocks = 9
        self.layer_size = 64
        self.lr_policy = None
        self.lr_schedule_gen = None
        self.lr_schedule_dis = None

        self.device = self.get_device()
        self.create_folder_structure()

    def create_folder_structure(self):
        checkpoint_folder = self.base_path + '/checkpoints'
        loss_folder = self.base_path + '/Loss_Checkpoints'
        training_folder = self.base_path + '/Training Images'
        test_folder = self.base_path + '/Test Images'
        if not os.path.exists(checkpoint_folder):
            os.makedirs(checkpoint_folder)
        if not os.path.exists(loss_folder):
            os.makedirs(loss_folder)
        if not os.path.exists(training_folder):
            os.makedirs(training_folder)
        if not os.path.exists(test_folder):
            os.makedirs(test_folder)

    def get_device(self):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print('Using device:', device)

        if device.type == 'cuda':
            # Query CUDA properties only when a GPU is present; these calls
            # raise on CPU-only machines.
            print(torch.cuda.get_device_name(0))
            print('Memory Usage -')
            print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1), 'GB')
            print('Reserved: ', round(torch.cuda.memory_reserved(0) / 1024 ** 3, 1), 'GB')
            return device
        else:
            return None

    def initialize_model(self, lr_schedular_options, model_type='unet', residual_blocks=9, layer_size=64):

        all_models = ['unet', 'resnet', 'inception', 'unet2', 'unet_large', 'unet_fusion']
        if model_type not in all_models:
            raise Exception('This model type is not available!')

        self.dis = Discriminator(image_size=self.image_size, leaky_relu=self.leaky_relu_threshold)
        if model_type == 'unet':
            self.gen = Generator_Unet(image_size=self.image_size, ngf=layer_size)
        elif model_type == 'resnet':
            self.gen = Generator_RESNET(residual_blocks=residual_blocks, ngf=layer_size)
        elif model_type == 'inception':
            self.gen = Generator_InceptionNet(ngf=layer_size)
        elif model_type == 'unet2':
            self.gen = Generator_Unet_2(image_size=self.image_size, ngf=layer_size)
        elif model_type == 'unet_large':
            self.gen = Generator_Unet_Large(image_size=self.image_size, ngf=layer_size)
        elif model_type == 'unet_fusion':
            self.gen = Generator_Unet_Fusion(image_size=self.image_size, ngf=layer_size)

        if self.device is not None:
            self.gen.cuda()
            self.dis.cuda()

        self.gen_optim = optim.Adam(self.gen.parameters(), lr=self.lr, betas=self.betas)
        self.dis_optim = optim.Adam(self.dis.parameters(), lr=self.lr, betas=self.betas)

        # Wrap each optimizer with its own learning-rate scheduler
        self.lr_schedule_gen = self.get_learning_schedule(self.gen_optim, lr_schedular_options)
        self.lr_schedule_dis = self.get_learning_schedule(self.dis_optim, lr_schedular_options)

        self.model_type = model_type
        self.layer_size = layer_size
        self.residual_blocks = residual_blocks
        self.lr_policy = lr_schedular_options
        print('Model Initialized !\nGenerator Model Type : {} and Layer Size : {}'.format(model_type, layer_size))
        print('Model Parameters are:\nEpochs : {}\nLearning rate : {}\nLeaky Relu Threshold : {}\nLamda : {}\nBeta : {}'
              .format(self.epochs, self.lr, self.leaky_relu_threshold, self.lamda, self.betas))

    def train_model(self, trainloader, average_loss, eval=(False, None, None), save_model=(False, 25),
                    display_test_image=(False, None, 25)):

        print('We will be using L1 loss with a perceptual loss (L1 on VGG16 features)!')
        mean_loss = nn.BCELoss()
        l1_loss = nn.L1Loss()
        # Frozen, pretrained VGG16 feature extractor for the perceptual loss
        vgg16 = models.vgg16(pretrained=True)
        vgg16_conv = nn.Sequential(*list(vgg16.features.children())[:-3]).eval()
        for param in vgg16_conv.parameters():
            param.requires_grad = False
        if self.device is not None:
            vgg16_conv = vgg16_conv.cuda()

        self.gen.train()
        self.dis.train()

        batches = len(trainloader)
        print('Total number of batches in an epoch are : {}'.format(batches))

        sample_img_test = None
        if display_test_image[0]:
            sample_img_test, rgb_test_images = next(iter(display_test_image[1]))
            save_image((rgb_test_images[0].detach().cpu() + 1) / 2,
                       '{}/Training Images/real_img.{}'.format(self.base_path, self.image_format))
            if self.device is not None:
                sample_img_test = sample_img_test.cuda()

        for i in range(self.epochs):

            if eval[0] and (i % eval[2] == 0):
                self.evaluate_L1_loss_dataset(eval[1], train=False)
                self.evaluate_L1_loss_dataset(trainloader, train=True)
                self.gen.train()

            running_gen_loss = 0
            running_dis_loss = 0

            for gray_img, real_img in trainloader:

                batch_size = len(gray_img)
                zero_label = torch.zeros(batch_size)
                one_label = torch.ones(batch_size)

                if self.device is not None:
                    gray_img = gray_img.cuda()
                    real_img = real_img.cuda()
                    zero_label = zero_label.cuda()
                    one_label = one_label.cuda()

                # Discriminator loss
                self.dis_optim.zero_grad()
                fake_img = self.gen(gray_img)

                dis_real_loss = mean_loss(self.dis(real_img), one_label)
                dis_fake_loss = mean_loss(self.dis(fake_img), zero_label)

                total_dis_loss = dis_fake_loss + dis_real_loss
                total_dis_loss.backward()
                self.dis_optim.step()

                # Generator loss
                self.gen_optim.zero_grad()

                fake_img = self.gen(gray_img)
                gen_adv_loss = mean_loss(self.dis(fake_img), one_label)
                gen_l1_loss = l1_loss(fake_img.view(batch_size, -1), real_img.view(batch_size, -1))
                gen_pre_train = l1_loss(vgg16_conv(fake_img), vgg16_conv(real_img))
                total_gen_loss = gen_adv_loss + self.lamda * gen_l1_loss + self.lamda * gen_pre_train
                total_gen_loss.backward()
                self.gen_optim.step()

                running_dis_loss += total_dis_loss.item()
                running_gen_loss += total_gen_loss.item()

            running_dis_loss /= (batches * 1.0)
            running_gen_loss /= (batches * 1.0)
            print('Epoch : {}, Generator Loss : {} and Discriminator Loss : {}'.format(i + 1, running_gen_loss,
                                                                                       running_dis_loss))
            if display_test_image[0] and i % display_test_image[2] == 0:
                self.gen.eval()
                out_result = self.gen(sample_img_test)
                out_result = out_result.detach().cpu()
                out_result = (out_result[0] + 1) / 2
                save_image(out_result, '{}/Training Images/epoch_{}.{}'.format(self.base_path, i,
                                                                               self.image_format))
                self.gen.train()

            save_tuple = ([running_gen_loss], [running_dis_loss])
            average_loss.add_loss(save_tuple)

            if save_model[0] and i % save_model[1] == 0:
                self.save_checkpoint('checkpoint_epoch_{}'.format(i), self.model_type)
                average_loss.save('checkpoint_avg_loss', save_index=0)

            self.lr_schedule_gen.step()
            self.lr_schedule_dis.step()
            for param_grp in self.dis_optim.param_groups:
                print('Learning rate after {} epochs is : {}'.format(i + 1, param_grp['lr']))

        self.save_checkpoint('checkpoint_train_final', self.model_type)
        average_loss.save('checkpoint_avg_loss_final', save_index=0)

    def get_learning_schedule(self, optimizer, option):

        schedular = None
        if option['lr_policy'] == 'linear':
            def lambda_rule(epoch):
                lr_l = 1.0 - max(0, epoch - option['n_epochs']) / float(option['n_epoch_decay'] + 1)
                return lr_l

            schedular = lr_schedular.LambdaLR(optimizer, lr_lambda=lambda_rule)
        elif option['lr_policy'] == 'plateau':
            schedular = lr_schedular.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
        elif option['lr_policy'] == 'step':
            schedular = lr_schedular.StepLR(optimizer, step_size=option['step_size'], gamma=0.1)
        elif option['lr_policy'] == 'cosine':
            schedular = lr_schedular.CosineAnnealingLR(optimizer, T_max=option['n_epochs'], eta_min=0)
        else:
            raise Exception('LR Policy not implemented!')

        return schedular

    def evaluate_model(self, loader, save_filename, no_of_images=1):
        # Considering that we have batch size of 1 for test set
        if self.gen is None or self.dis is None:
            raise Exception('Model has not been initialized and hence cannot be evaluated!')

        counter_images_generated = 0
        while counter_images_generated < no_of_images:
            gray, rgb = next(iter(loader))

            if self.device is not None:
                gray = gray.cuda()

            filename = '{}/Test Images/{}_{}.{}'.format(self.base_path, save_filename, self.count, self.image_format)
            real_filename = '{}/Test Images/{}_{}_real.{}'.format(self.base_path, save_filename, self.count,
                                                                  self.image_format)
            real_gray_filename = '{}/Test Images/{}_{}_real_gray.{}'.format(self.base_path, save_filename, self.count,
                                                                            self.image_format)
            self.count += 1

            self.gen.eval()
            out = self.gen(gray)
            out = out[0].detach().cpu()
            out = (out + 1) / 2
            save_image(out, filename)

            gray_img = gray[0].detach().cpu()
            save_image(gray_img, real_gray_filename)

            real_img = (rgb[0].detach().cpu() + 1) / 2
            save_image(real_img, real_filename)

            counter_images_generated += 1

    def evaluate_L1_loss_dataset(self, loader, train=False):

        if self.gen is None or self.dis is None:
            raise Exception('Model has not been initialized and hence cannot be evaluated!')

        loss_function = nn.L1Loss()
        self.gen.eval()
        total_loss = 0.0
        iterations = 0
        for gray, real in loader:
            iterations += 1
            if self.device is not None:
                gray = gray.cuda()
                real = real.cuda()

            gen_out = self.gen(gray)
            iteration_loss = loss_function(gen_out, real)
            total_loss += iteration_loss.item()
        total_loss = total_loss / (iterations * 1.0)
        train_test = 'test'
        if train:
            train_test = 'train'
        print('Total L1 loss over {} set is : {}'.format(train_test, total_loss))
        return total_loss

    def change_params(self, epochs=None, learning_rate=None, leaky_relu=None, betas=None, lamda=None):
        if epochs is not None:
            self.epochs = epochs
            print('Changed the number of epochs to {}!'.format(self.epochs))
        if learning_rate is not None:
            self.lr = learning_rate
            print('Changed the learning rate to {}!'.format(self.lr))
        if leaky_relu is not None:
            self.leaky_relu_threshold = leaky_relu
            print('Changed the threshold for leaky relu to {}!'.format(self.leaky_relu_threshold))
        if betas is not None:
            self.betas = betas
            print('Changed the betas for Adams Optimizer!')
        if betas is not None or learning_rate is not None:
            self.gen_optim = optim.Adam(self.gen.parameters(), lr=self.lr, betas=self.betas)
            self.dis_optim = optim.Adam(self.dis.parameters(), lr=self.lr, betas=self.betas)

        if lamda is not None:
            self.lamda = lamda
            print('Lamda value has been changed to {}!'.format(self.lamda))

    def set_all_params(self, epochs, lr, leaky_thresh, lamda, beta):
        self.epochs = epochs
        self.lr = lr
        self.leaky_relu_threshold = leaky_thresh
        self.lamda = lamda
        self.betas = beta
        self.gen_optim = optim.Adam(self.gen.parameters(), lr=self.lr, betas=self.betas)
        self.dis_optim = optim.Adam(self.dis.parameters(), lr=self.lr, betas=self.betas)

        print('Model Parameters are:\nEpochs : {}\nLearning rate : {}\nLeaky Relu Threshold : {}\nLamda : {}\nBeta : {}'
              .format(self.epochs, self.lr, self.leaky_relu_threshold, self.lamda, self.betas))

    def run_model_on_dataset(self, loader, save_folder, save_path=None):
        if self.gen is None or self.dis is None:
            raise Exception('Model has not been initialized and hence cannot be run on the dataset!')
        index = 1
        if save_path is None:
            save_path = self.base_path
        for gray, dummy in loader:

            if self.device is not None:
                gray = gray.cuda()

            filename = '{}/{}/{}.{}'.format(save_path, save_folder, index, self.image_format)
            index += 1

            self.gen.eval()
            out = self.gen(gray)
            out = out[0].detach().cpu()
            out = (out + 1) / 2
            save_image(out, filename)

    def save_checkpoint(self, filename, model_type='unet'):
        if self.gen is None or self.dis is None:
            raise Exception('The model has not been initialized and hence cannot be saved !')

        filename = '{}/checkpoints/{}.pth'.format(self.base_path, filename)
        save_dict = {'model_type': model_type, 'dis_dict': self.dis.state_dict(), 'gen_dict': self.gen.state_dict(),
                     'lr': self.lr,
                     'epochs': self.epochs, 'betas': self.betas, 'image_size': self.image_size,
                     'leaky_relu_thresh': self.leaky_relu_threshold, 'lamda': self.lamda, 'base_path': self.base_path,
                     'count': self.count, 'image_format': self.image_format, 'device': self.device,
                     'residual_blocks': self.residual_blocks, 'layer_size': self.layer_size,
                     'lr_policy': self.lr_policy}

        torch.save(save_dict, filename)

        print('The model checkpoint has been saved !')

    def load_checkpoint(self, filename):
        filename = '{}/checkpoints/{}.pth'.format(self.base_path, filename)
        if not pathlib.Path(filename).exists():
            raise Exception('This checkpoint does not exist!')

        self.gen = None
        self.dis = None

        save_dict = torch.load(filename)

        self.betas = save_dict['betas']
        self.image_size = save_dict['image_size']
        self.epochs = save_dict['epochs']
        self.leaky_relu_threshold = save_dict['leaky_relu_thresh']
        self.lamda = save_dict['lamda']
        self.lr = save_dict['lr']
        self.base_path = save_dict['base_path']
        self.count = save_dict['count']
        self.image_format = save_dict['image_format']
        self.device = save_dict['device']
        self.residual_blocks = save_dict['residual_blocks']
        self.layer_size = save_dict['layer_size']
        self.lr_policy = save_dict['lr_policy']

        device = self.get_device()
        if device != self.device:
            if self.device is None:
                # The checkpoint was produced on CPU; keep running on CPU even if a GPU is available.
                print('The model was trained on CPU and will therefore be continued on CPU only!')
            else:
                raise Exception('The model was trained on GPU and cannot be loaded on a CPU machine!')

        self.initialize_model(model_type=save_dict['model_type'], residual_blocks=self.residual_blocks,
                              layer_size=self.layer_size, lr_schedular_options=self.lr_policy)

        self.gen.load_state_dict(save_dict['gen_dict'])
        self.dis.load_state_dict(save_dict['dis_dict'])

        print('The model checkpoint has been restored!')
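
A possible way to drive this class, sketched under the assumption that paired (grayscale, RGB) DataLoaders and the loss-tracking helper passed as average_loss exist with the interfaces used above; the scheduler options dictionary only needs the keys consumed by get_learning_schedule.

# Hypothetical usage sketch; train_loader, test_loader and AverageLoss are
# assumed helpers providing the interfaces train_model() relies on.
model = Model(base_path='./runs/colorization', epochs=100, learning_rate=0.0002)

scheduler_options = {
    'lr_policy': 'linear',  # one of: linear, plateau, step, cosine
    'n_epochs': 50,         # epochs at the initial learning rate (linear / cosine policies)
    'n_epoch_decay': 50,    # epochs over which the rate decays to zero (linear policy)
    'step_size': 25,        # only used by the 'step' policy
}
model.initialize_model(scheduler_options, model_type='resnet', residual_blocks=9, layer_size=64)

model.train_model(train_loader,
                  average_loss=AverageLoss(),                  # assumed loss-tracking helper
                  eval=(True, test_loader, 10),                # evaluate L1 loss every 10 epochs
                  save_model=(True, 25),                       # checkpoint every 25 epochs
                  display_test_image=(True, test_loader, 25))  # save a sample image every 25 epochs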