class Combined_GAN(BaseAgent):
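    """
    Agent that trains a High-to-Low degradation GAN and a Low-to-High
    super-resolution GAN: for the first 200 epochs the two GANs are trained
    independently, after which they are trained jointly as a combined
    pipeline (the High-to-Low output feeds the Low-to-High generator).
    """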
    def __init__(self, config):
        super().__init__(config)

        # define models (generator and discriminator)
        self.h2l_G = HighToLowGenerator()
        self.h2l_D = HighToLowDiscriminator()
        self.l2h_G = LowToHighGenerator()
        self.l2h_D = LowToHighDiscriminator()

        # define loss
        #self.loss = GANLoss()
        #self.loss = HingeEmbeddingLoss()
        self.criterion_GAN = torch.nn.BCEWithLogitsLoss()
        self.criterion_MSE = MSELoss()

        # define optimizers for both generator and discriminator
        self.l2h_optimG = torch.optim.Adam(self.l2h_G.parameters(),
                                           lr=self.config.learning_rate,
                                           betas=(self.config.beta1,
                                                  self.config.beta2))
        self.l2h_optimD = torch.optim.Adam(self.l2h_D.parameters(),
                                           lr=self.config.learning_rate,
                                           betas=(self.config.beta1,
                                                  self.config.beta2))
        self.h2l_optimG = torch.optim.Adam(self.h2l_G.parameters(),
                                           lr=self.config.learning_rate,
                                           betas=(self.config.beta1,
                                                  self.config.beta2))
        self.h2l_optimD = torch.optim.Adam(self.h2l_D.parameters(),
                                           lr=self.config.learning_rate,
                                           betas=(self.config.beta1,
                                                  self.config.beta2))

        # initialize counter
        self.current_epoch = 0
        self.current_iteration = 0
        self.best_valid_mean_iou = 0

        self.real_label = 1
        self.fake_label = -1
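        # NOTE: BCEWithLogitsLoss normally expects targets in [0, 1]; the -1
        # fake label still pushes fake logits downward, but the resulting term
        # is not a standard cross-entropy (it is unbounded below).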

        # set cuda flag
        self.is_cuda = torch.cuda.is_available()
        if self.is_cuda and not self.config.cuda:
            self.logger.info(
                "WARNING: You have a CUDA device, so you should probably enable CUDA"
            )

        self.cuda = self.is_cuda and self.config.cuda

        # set the manual seed for torch
        self.manual_seed = random.randint(1, 10000)
        self.logger.info('seed:{}'.format(self.manual_seed))
        random.seed(self.manual_seed)

        self.test_file = self.config.output_path
        if not os.path.exists(self.test_file):
            os.makedirs(self.test_file)

        if self.cuda:
            self.device = torch.device("cuda")
            torch.cuda.set_device(self.config.gpu_device)
            torch.cuda.manual_seed_all(self.manual_seed)
            self.logger.info("Program will run on *****GPU-CUDA***** ")
            print_cuda_statistics()
        else:
            self.device = torch.device("cpu")
            torch.manual_seed(self.manual_seed)
            self.logger.info("Program will run on *****CPU***** ")

        self.l2h_G = self.l2h_G.to(self.device)
        self.l2h_D = self.l2h_D.to(self.device)
        self.h2l_G = self.h2l_G.to(self.device)
        self.h2l_D = self.h2l_D.to(self.device)
        self.criterion_GAN = self.criterion_GAN.to(self.device)
        self.criterion_MSE = self.criterion_MSE.to(self.device)

        # Summary Writer
        self.summary_writer_l2h = SummaryWriter(
            log_dir=self.config.summary_dir_l2h, comment='Low-To-High GAN')
        self.summary_writer_h2l = SummaryWriter(
            log_dir=self.config.summary_dir_h2l, comment='High-To-Low GAN')

    def load_checkpoint(self, file_name, model):
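        """
        Load the training state of the given sub-model ('l2h', 'h2l', or
        'combined') from its configured checkpoint directory; if no checkpoint
        exists, training starts from scratch.
        """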
        if model == 'l2h':
            checkpoint_dir = self.config.checkpoint_l2h_dir
        elif model == 'h2l':
            checkpoint_dir = self.config.checkpoint_h2l_dir
        elif model == 'combined':
            checkpoint_dir = self.config.checkpoint_combined_dir
        else:
            raise ValueError("Unknown model '{}'".format(model))

        filename = os.path.join(checkpoint_dir, file_name)
        try:
            self.logger.info("Loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename, map_location=self.device)

            self.current_epoch = checkpoint['epoch']
            self.current_iteration = checkpoint['iteration']
            self.manual_seed = checkpoint['manual_seed']

            if model == 'h2l':
                self.h2l_G.load_state_dict(checkpoint['h2l_G_state_dict'])
                self.h2l_optimG.load_state_dict(checkpoint['h2l_G_optimizer'])
                self.h2l_D.load_state_dict(checkpoint['h2l_D_state_dict'])
                self.h2l_optimD.load_state_dict(checkpoint['h2l_D_optimizer'])

            elif model == 'l2h':
                self.l2h_G.load_state_dict(checkpoint['l2h_G_state_dict'])
                self.l2h_optimG.load_state_dict(checkpoint['l2h_G_optimizer'])
                self.l2h_D.load_state_dict(checkpoint['l2h_D_state_dict'])
                self.l2h_optimD.load_state_dict(checkpoint['l2h_D_optimizer'])

            elif model == 'combined':
                self.h2l_G.load_state_dict(checkpoint['h2l_G_state_dict'])
                self.h2l_optimG.load_state_dict(checkpoint['h2l_G_optimizer'])
                self.h2l_D.load_state_dict(checkpoint['h2l_D_state_dict'])
                self.h2l_optimD.load_state_dict(checkpoint['h2l_D_optimizer'])

                self.l2h_G.load_state_dict(checkpoint['l2h_G_state_dict'])
                self.l2h_optimG.load_state_dict(checkpoint['l2h_G_optimizer'])
                self.l2h_D.load_state_dict(checkpoint['l2h_D_state_dict'])
                self.l2h_optimD.load_state_dict(checkpoint['l2h_D_optimizer'])

        except OSError:
            self.logger.info(
                "No checkpoint exists from '{}'. Skipping...".format(
                    checkpoint_dir))
            self.logger.info("**First time to train**")

    def save_checkpoint(self, file_name, model, is_best=0):
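        """
        Save the current training state (epoch, iteration, seed, and the
        model/optimizer state dicts of the given sub-model) to its checkpoint
        directory.
        """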
        state = {
            'epoch': self.current_epoch,
            'iteration': self.current_iteration,
            'manual_seed': self.manual_seed
        }
        if model == 'l2h':
            state['l2h_G_state_dict'] = self.l2h_G.state_dict()
            state['l2h_G_optimizer'] = self.l2h_optimG.state_dict()
            state['l2h_D_state_dict'] = self.l2h_D.state_dict()
            state['l2h_D_optimizer'] = self.l2h_optimD.state_dict()

            checkpoint_dir = self.config.checkpoint_l2h_dir

        elif model == 'h2l':
            state['h2l_G_state_dict'] = self.h2l_G.state_dict()
            state['h2l_G_optimizer'] = self.h2l_optimG.state_dict()
            state['h2l_D_state_dict'] = self.h2l_D.state_dict()
            state['h2l_D_optimizer'] = self.h2l_optimD.state_dict()

            checkpoint_dir = self.config.checkpoint_h2l_dir

        elif model == 'combined':
            state['l2h_G_state_dict'] = self.l2h_G.state_dict()
            state['l2h_G_optimizer'] = self.l2h_optimG.state_dict()
            state['l2h_D_state_dict'] = self.l2h_D.state_dict()
            state['l2h_D_optimizer'] = self.l2h_optimD.state_dict()

            state['h2l_G_state_dict'] = self.h2l_G.state_dict()
            state['h2l_G_optimizer'] = self.h2l_optimG.state_dict()
            state['h2l_D_state_dict'] = self.h2l_D.state_dict()
            state['h2l_D_optimizer'] = self.h2l_optimD.state_dict()

            checkpoint_dir = self.config.checkpoint_combined_dir
        else:
            raise ValueError("Unknown model '{}'".format(model))
        # Save the state
        torch.save(state, os.path.join(checkpoint_dir, file_name))
        # If it is the best, copy it to 'model_best.pth.tar'
        if is_best:
            shutil.copyfile(os.path.join(checkpoint_dir, file_name),
                            os.path.join(checkpoint_dir, 'model_best.pth.tar'))

    def run(self):
        """
        This function will the operator
        :return:
        """
        try:
            self.train()
        except KeyboardInterrupt:
            self.logger.info("You have entered CTRL+C.. Wait to finalize")

    def train(self):
        # Load models from the latest checkpoint; if none is found, start from scratch.
        self.load_checkpoint(self.config.checkpoint_file_h2l, 'h2l')
        if self.current_epoch <= 200:
            self.load_checkpoint(self.config.checkpoint_file_l2h, 'l2h')
        else:
            self.load_checkpoint(self.config.checkpoint_file_combined,
                                 'combined')
        if self.current_epoch != 0 and self.current_epoch <= 200:
            self.logger.info(
                "Checkpoint loaded successfully from '{}' and '{}' at epoch {}, iteration {}\n"
                .format(self.config.checkpoint_l2h_dir,
                        self.config.checkpoint_h2l_dir, self.current_epoch,
                        self.current_iteration))

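        # Training schedule: for the first 200 epochs the two GANs are trained
        # independently; afterwards they are trained end-to-end as a combined
        # pipeline.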
        for epoch in range(self.current_epoch, self.config.max_epoch):
            self.current_epoch = epoch
            if epoch <= 200:
                self.train_one_epoch_h2l()
                self.train_one_epoch_l2h()
                self.save_checkpoint(self.config.checkpoint_file_l2h, 'l2h')
                self.save_checkpoint(self.config.checkpoint_file_h2l, 'h2l')
            else:
                self.train_one_epoch_combined()
                self.save_checkpoint(self.config.checkpoint_file_combined,
                                     'combined')

    def to_var(self, data):
        # Move a batch to the active device and return it with its batch size.
        batchsize = data.size(0)
        inp = data.to(self.device)
        return inp, batchsize

    def train_one_epoch_h2l(self):
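        """Train the High-to-Low degradation GAN for one epoch."""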
        train_loader = get_loader(self.config.HighToLow_hr_datapath,
                                  self.config.HighToLow_lr_datapath,
                                  self.config.batch_size)

        self.h2l_G.train()
        self.h2l_D.train()

        for curr_it, data_dict in enumerate(train_loader):
            data_low = data_dict['lr']
            data_high = data_dict['hr']
            data_input_low, batchsize = self.to_var(data_low)
            data_input_high, _ = self.to_var(data_high)

            y = torch.randn(data_low.size(0), )
            y, _ = self.to_var(y)

            ##################
            #  Train Generator
            ##################

            self.h2l_optimG.zero_grad()

            # Generate a degraded low-resolution image from the high-resolution input
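            # The extra noise input lets the generator produce varied
            # degradations for the same high-resolution image.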
            noise = torch.randn(data_high.size(0), 1)
            noise, _ = self.to_var(noise)
            gen_lr = self.h2l_G(data_input_high, noise)

            # Measure pixel-wise loss against the real low-resolution image
            loss_pixel = self.criterion_MSE(gen_lr, data_input_low)

            # Extract validity predictions from discriminator
            pred_real = self.h2l_D(data_input_high).detach()
            pred_fake = self.h2l_D(gen_lr)

            # Adversarial loss (relativistic average GAN)
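            # Relativistic average formulation: the generator is rewarded when
            # its fakes score higher than the discriminator's average
            # prediction on real images, rather than being judged in isolation.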
            y.fill_(self.real_label)
            loss_G_GAN = self.criterion_GAN(
                pred_fake - pred_real.mean(0, keepdim=True), y)

            # Total generator loss
            loss_G = (self.config.beta * loss_G_GAN) + (self.config.alpha *
                                                        loss_pixel)

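            # retain_graph=True keeps the graph alive because pred_fake is
            # reused below when computing the discriminator's relativistic loss.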
            loss_G.backward(retain_graph=True)
            self.h2l_optimG.step()

            ######################
            #  Train Discriminator
            ######################

            self.h2l_optimD.zero_grad()

            # Adversarial loss for real and fake images (relativistic average GAN)
            pred_real = self.h2l_D(data_input_high)
            y.fill_(self.real_label)
            loss_D_real = self.criterion_GAN(
                pred_real - pred_fake.mean(0, keepdim=True), y)
            loss_D_real.backward(retain_graph=True)

            pred_fake = self.h2l_D(gen_lr.detach())
            y.fill_(self.fake_label)
            loss_D_fake = self.criterion_GAN(
                pred_fake - pred_real.mean(0, keepdim=True), y)
            loss_D_fake.backward()
            # Total loss
            loss_D = (loss_D_real + loss_D_fake) / 2
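            # loss_D is computed for logging only; its two terms were already
            # backpropagated individually above.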

            self.h2l_optimD.step()

            self.current_iteration += 1

            self.summary_writer_h2l.add_scalar("epoch/Generator_loss",
                                               loss_G.item(),
                                               self.current_iteration)
            self.summary_writer_h2l.add_scalar("epoch/Discriminator_loss_real",
                                               loss_D_real.item(),
                                               self.current_iteration)
            self.summary_writer_h2l.add_scalar("epoch/Discriminator_loss_fake",
                                               loss_D_fake.item(),
                                               self.current_iteration)

            path = os.path.join(
                self.test_file, 'batch' + str(curr_it) + '_epoch' +
                str(self.current_epoch) + '_h2l.jpg')
            vutils.save_image(gen_lr.data, path, normalize=True)

            # --------------
            #  Log Progress
            # --------------

            self.logger.info(
                "High-To-Low GAN: [Epoch %d/%d] [Batch %d/%d] [D loss: %f, real: %f, fake: %f] [G loss: %f, adv: %f, pixel: %f]"
                % (
                    self.current_epoch + 1,
                    self.config.max_epoch,
                    curr_it + 1,
                    len(train_loader),
                    loss_D.item(),
                    loss_D_real.item(),
                    loss_D_fake.item(),
                    loss_G.item(),
                    loss_G_GAN.item(),
                    loss_pixel.item(),
                ))

    def train_one_epoch_l2h(self):
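        """Train the Low-to-High super-resolution GAN for one epoch."""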
        train_loader = get_loader(self.config.LowToHigh_datapath, None,
                                  self.config.batch_size)

        self.l2h_G.train()
        self.l2h_D.train()

        for curr_it, data_dict in enumerate(train_loader):
            data_low = data_dict['img16']
            data_high = data_dict['img64']
            data_input_low, batchsize = self.to_var(data_low)
            data_input_high, _ = self.to_var(data_high)

            y = torch.randn(data_low.size(0), )
            y, _ = self.to_var(y)

            ##################
            #  Train Generator
            ##################

            self.l2h_optimG.zero_grad()

            # Generate a high resolution image from low resolution input
            gen_hr = self.l2h_G(data_input_low)

            # Measure pixel-wise loss against ground truth
            loss_pixel = self.criterion_MSE(gen_hr, data_input_high)

            # Extract validity predictions from discriminator
            pred_real = self.l2h_D(data_input_high).detach()
            pred_fake = self.l2h_D(gen_hr)

            # Adversarial loss (relativistic average GAN)
            y.fill_(self.real_label)
            loss_G_GAN = self.criterion_GAN(
                pred_fake - pred_real.mean(0, keepdim=True), y)

            # Total generator loss
            loss_G = (self.config.beta * loss_G_GAN) + (self.config.alpha *
                                                        loss_pixel)

            loss_G.backward(retain_graph=True)
            self.l2h_optimG.step()

            ######################
            #  Train Discriminator
            ######################

            self.l2h_optimD.zero_grad()

            # Adversarial loss for real and fake images (relativistic average GAN)
            pred_real = self.l2h_D(data_input_high)
            y.fill_(self.real_label)
            loss_D_real = self.criterion_GAN(
                pred_real - pred_fake.mean(0, keepdim=True), y)
            loss_D_real.backward(retain_graph=True)

            pred_fake = self.l2h_D(gen_hr.detach())
            y.fill_(self.fake_label)
            loss_D_fake = self.criterion_GAN(
                pred_fake - pred_real.mean(0, keepdim=True), y)
            loss_D_fake.backward()
            # Total loss
            loss_D = (loss_D_real + loss_D_fake) / 2

            self.l2h_optimD.step()

            self.current_iteration += 1

            self.summary_writer_l2h.add_scalar("epoch/Generator_loss",
                                               loss_G.item(),
                                               self.current_iteration)
            self.summary_writer_l2h.add_scalar("epoch/Discriminator_loss_real",
                                               loss_D_real.item(),
                                               self.current_iteration)
            self.summary_writer_l2h.add_scalar("epoch/Discriminator_loss_fake",
                                               loss_D_fake.item(),
                                               self.current_iteration)
            self.summary_writer_l2h.add_scalar("epoch/Discriminator_loss",
                                               loss_D.item(),
                                               self.current_iteration)

            path = os.path.join(
                self.test_file, 'batch' + str(curr_it) + '_epoch' +
                str(self.current_epoch) + '_l2h.jpg')
            vutils.save_image(gen_hr.data, path, normalize=True)

            # --------------
            #  Log Progress
            # --------------

            self.logger.info(
                "Low-To-High GAN: [Epoch %d/%d] [Batch %d/%d] [D loss: %f, real: %f, fake: %f] [G loss: %f, adv: %f, pixel: %f]"
                % (
                    self.current_epoch + 1,
                    self.config.max_epoch,
                    curr_it + 1,
                    len(train_loader),
                    loss_D.item(),
                    loss_D_real.item(),
                    loss_D_fake.item(),
                    loss_G.item(),
                    loss_G_GAN.item(),
                    loss_pixel.item(),
                ))

    def train_one_epoch_combined(self):
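        """Train both GANs end-to-end: HR -> (High-to-Low) -> LR -> (Low-to-High) -> HR."""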
        train_loader = get_loader(self.config.HighToLow_hr_datapath,
                                  self.config.HighToLow_lr_datapath,
                                  self.config.batch_size)

        self.h2l_G.train()
        self.h2l_D.train()
        self.l2h_G.train()
        self.l2h_D.train()

        for curr_it, data_dict in enumerate(train_loader):
            data_low = data_dict['lr']
            data_high = data_dict['hr']
            data_input_low, batchsize = self.to_var(data_low)
            data_input_high, _ = self.to_var(data_high)

            y = torch.randn(data_low.size(0), )
            y, _ = self.to_var(y)

            ##############################
            #  Train High-To-Low Generator
            ##############################

            self.h2l_optimG.zero_grad()

            # Generate a degraded low-resolution image from the high-resolution input
            noise = torch.randn(data_high.size(0), 1)
            noise, _ = self.to_var(noise)
            h2l_gen_lr = self.h2l_G(data_input_high, noise)

            # Measure pixel-wise loss against the real low-resolution image
            h2l_loss_pixel = self.criterion_MSE(h2l_gen_lr, data_input_low)

            # Extract validity predictions from discriminator
            h2l_pred_real = self.h2l_D(data_input_high).detach()
            h2l_pred_fake = self.h2l_D(h2l_gen_lr)

            # Adversarial loss (relativistic average GAN)
            y.fill_(self.real_label)
            h2l_loss_G_GAN = self.criterion_GAN(
                h2l_pred_fake - h2l_pred_real.mean(0, keepdim=True), y)

            # Total generator loss
            h2l_loss_G = (self.config.beta * h2l_loss_G_GAN) + (
                self.config.alpha * h2l_loss_pixel)

            h2l_loss_G.backward(retain_graph=True)
            self.h2l_optimG.step()

            ##################################
            #  Train High-To-Low Discriminator
            ##################################

            self.h2l_optimD.zero_grad()

            # Adversarial loss for real and fake images (relativistic average GAN)
            h2l_pred_real = self.h2l_D(data_input_high)
            y.fill_(self.real_label)
            h2l_loss_D_real = self.criterion_GAN(
                h2l_pred_real - h2l_pred_fake.mean(0, keepdim=True), y)
            h2l_loss_D_real.backward(retain_graph=True)

            h2l_pred_fake = self.h2l_D(h2l_gen_lr.detach())
            y.fill_(self.fake_label)
            h2l_loss_D_fake = self.criterion_GAN(
                h2l_pred_fake - h2l_pred_real.mean(0, keepdim=True), y)
            h2l_loss_D_fake.backward()
            # Total loss
            h2l_loss_D = (h2l_loss_D_real + h2l_loss_D_fake) / 2

            self.h2l_optimD.step()

            self.current_iteration += 1

            self.summary_writer_h2l.add_scalar("epoch/Generator_loss",
                                               h2l_loss_G.item(),
                                               self.current_iteration)
            self.summary_writer_h2l.add_scalar("epoch/Discriminator_loss_real",
                                               h2l_loss_D_real.item(),
                                               self.current_iteration)
            self.summary_writer_h2l.add_scalar("epoch/Discriminator_loss_fake",
                                               h2l_loss_D_fake.item(),
                                               self.current_iteration)
            self.summary_writer_h2l.add_scalar("epoch/Discriminator_loss",
                                               h2l_loss_D.item(),
                                               self.current_iteration)

            path = os.path.join(
                self.test_file, 'batch' + str(curr_it) + '_epoch' +
                str(self.current_epoch) + '_combined_intermediate.jpg')
            vutils.save_image(h2l_gen_lr.data, path, normalize=True)

            # --------------
            #  Log Progress
            # --------------

            self.logger.info(
                "Combined model: High-To-Low GAN: [Epoch %d/%d] [Batch %d/%d] [D loss: %f, real: %f, fake: %f] [G loss: %f, adv: %f, pixel: %f]"
                % (
                    self.current_epoch + 1,
                    self.config.max_epoch,
                    curr_it + 1,
                    len(train_loader),
                    h2l_loss_D.item(),
                    h2l_loss_D_real.item(),
                    h2l_loss_D_fake.item(),
                    h2l_loss_G.item(),
                    h2l_loss_G_GAN.item(),
                    h2l_loss_pixel.item(),
                ))

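            # Combined stage: the degraded image produced by the High-to-Low
            # generator becomes the input of the Low-to-High generator, so the
            # two networks are trained as one end-to-end pipeline.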
            data_input_low = h2l_gen_lr

            y = torch.randn(data_input_low.size(0), )
            y, _ = self.to_var(y)

            ##############################
            #  Train Low-To-High Generator
            ##############################

            self.l2h_optimG.zero_grad()

            # Generate a high resolution image from low resolution input
            l2h_gen_hr = self.l2h_G(data_input_low)

            # Measure pixel-wise loss against ground truth
            l2h_loss_pixel = self.criterion_MSE(l2h_gen_hr, data_input_high)

            # Extract validity predictions from discriminator
            l2h_pred_real = self.l2h_D(data_input_high).detach()
            l2h_pred_fake = self.l2h_D(l2h_gen_hr)

            # Adversarial loss (relativistic average GAN)
            y.fill_(self.real_label)
            l2h_loss_G_GAN = self.criterion_GAN(
                l2h_pred_fake - l2h_pred_real.mean(0, keepdim=True), y)

            # Total generator loss
            l2h_loss_G = (self.config.beta * l2h_loss_G_GAN) + (
                self.config.alpha * l2h_loss_pixel)

            l2h_loss_G.backward(retain_graph=True)
            self.l2h_optimG.step()

            ##################################
            #  Train Low-To-High Discriminator
            ##################################

            self.l2h_optimD.zero_grad()

            # Adversarial loss for real and fake images (relativistic average GAN)
            l2h_pred_real = self.l2h_D(data_input_high)
            y.fill_(self.real_label)
            l2h_loss_D_real = self.criterion_GAN(
                l2h_pred_real - l2h_pred_fake.mean(0, keepdim=True), y)
            l2h_loss_D_real.backward(retain_graph=True)

            l2h_pred_fake = self.l2h_D(l2h_gen_hr.detach())
            y.fill_(self.fake_label)
            l2h_loss_D_fake = self.criterion_GAN(
                l2h_pred_fake - l2h_pred_real.mean(0, keepdim=True), y)
            l2h_loss_D_fake.backward()
            # Total loss
            l2h_loss_D = (l2h_loss_D_real + l2h_loss_D_fake) / 2

            self.l2h_optimD.step()

            self.current_iteration += 1

            self.summary_writer_l2h.add_scalar("epoch/Generator_loss",
                                               l2h_loss_G.item(),
                                               self.current_iteration)
            self.summary_writer_l2h.add_scalar("epoch/Discriminator_loss_real",
                                               l2h_loss_D_real.item(),
                                               self.current_iteration)
            self.summary_writer_l2h.add_scalar("epoch/Discriminator_loss_fake",
                                               l2h_loss_D_fake.item(),
                                               self.current_iteration)
            self.summary_writer_l2h.add_scalar("epoch/Discriminator_loss",
                                               l2h_loss_D.item(),
                                               self.current_iteration)

            path = os.path.join(
                self.test_file, 'batch' + str(curr_it) + '_epoch' +
                str(self.current_epoch) + '_combined_final.jpg')
            vutils.save_image(l2h_gen_hr.data, path, normalize=True)

            # --------------
            #  Log Progress
            # --------------

            self.logger.info(
                "Combined model: Low-To-High GAN: [Epoch %d/%d] [Batch %d/%d] [D loss: %f, real: %f, fake: %f] [G loss: %f, adv: %f, pixel: %f]"
                % (
                    self.current_epoch + 1,
                    self.config.max_epoch,
                    curr_it + 1,
                    len(train_loader),
                    l2h_loss_D.item(),
                    l2h_loss_D_real.item(),
                    l2h_loss_D_fake.item(),
                    l2h_loss_G.item(),
                    l2h_loss_G_GAN.item(),
                    l2h_loss_pixel.item(),
                ))

    def validate(self):
        pass

    def finalize(self):
        """
        Finalize all operations of the agent: save the current state and close
        the summary writers.
        :return:
        """
        self.logger.info(
            "Please wait while finalizing the operation.. Thank you")
        self.save_checkpoint(self.config.checkpoint_file_combined, 'combined')
        # export_scalars_to_json is provided by tensorboardX's SummaryWriter;
        # drop these calls if torch.utils.tensorboard is used instead.
        self.summary_writer_l2h.export_scalars_to_json(
            "{}all_scalars.json".format(self.config.summary_dir_l2h))
        self.summary_writer_h2l.export_scalars_to_json(
            "{}all_scalars.json".format(self.config.summary_dir_h2l))
        self.summary_writer_l2h.close()
        self.summary_writer_h2l.close()
class TestCode:
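    """
    Inference-only agent: loads trained generator weights from a checkpoint
    and writes the generated test images to the configured output path.
    """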
    def __init__(self, config, mode):
        super().__init__()

        self.config = config
        self.mode = mode

        self.h2l_G = HighToLowGenerator()
        self.l2h_G = LowToHighGenerator()

        self.is_cuda = torch.cuda.is_available()

        if self.is_cuda and not self.config.cuda:
            print(
                "WARNING: You have a CUDA device, so you should probably enable CUDA"
            )

        self.cuda = self.is_cuda and self.config.cuda
        self.manual_seed = random.randint(1, 10000)
        print('seed:{}'.format(self.manual_seed))
        random.seed(self.manual_seed)

        self.test_file = self.config.output_path
        if not os.path.exists(self.test_file):
            os.makedirs(self.test_file)

        if self.cuda:
            self.device = torch.device("cuda")
            torch.cuda.set_device(self.config.gpu_device)
            torch.cuda.manual_seed_all(self.manual_seed)
            print("Program will run on *****GPU-CUDA***** ")
            print_cuda_statistics()
        else:
            self.device = torch.device("cpu")
            torch.manual_seed(self.manual_seed)
            print("Program will run on *****CPU***** ")

        self.l2h_G = self.l2h_G.to(self.device)
        self.h2l_G = self.h2l_G.to(self.device)

    def to_var(self, data):
        # Move a batch to the active device and return it with its batch size.
        batchsize = data.size(0)
        inp = data.to(self.device)
        return inp, batchsize

    def load_checkpoint(self, file_name, model):
        if model == 'l2h':
            checkpoint_dir = self.config.checkpoint_l2h_dir
        elif model == 'h2l':
            checkpoint_dir = self.config.checkpoint_h2l_dir
        elif model == 'combined':
            checkpoint_dir = self.config.checkpoint_combined_dir
        else:
            raise ValueError("Unknown model '{}'".format(model))

        filename = os.path.join(checkpoint_dir, file_name)
        try:
            print("Loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename, map_location=self.device)

            # Load the generator weights and switch them to eval mode; the
            # test() method below dispatches to the actual test routine.
            if model == 'h2l':
                self.h2l_G.load_state_dict(checkpoint['h2l_G_state_dict'])
                self.h2l_G.eval()

            elif model == 'l2h':
                self.l2h_G.load_state_dict(checkpoint['l2h_G_state_dict'])
                self.l2h_G.eval()

            elif model == 'combined':
                self.h2l_G.load_state_dict(checkpoint['h2l_G_state_dict'])
                self.l2h_G.load_state_dict(checkpoint['l2h_G_state_dict'])
                self.h2l_G.eval()
                self.l2h_G.eval()

        except OSError:
            print("No checkpoint exists at '{}'. Skipping...".format(
                checkpoint_dir))

    def save_checkpoint(self, file_name, model, is_best=0):
        state = {}

        if model == 'l2h':
            state['l2h_G_state_dict'] = self.l2h_G.state_dict()
            checkpoint_dir = self.config.checkpoint_l2h_dir

        elif model == 'h2l':
            state['h2l_G_state_dict'] = self.h2l_G.state_dict()
            checkpoint_dir = self.config.checkpoint_h2l_dir

        elif model == 'combined':
            state['l2h_G_state_dict'] = self.l2h_G.state_dict()
            state['h2l_G_state_dict'] = self.h2l_G.state_dict()
            checkpoint_dir = self.config.checkpoint_combined_dir

        else:
            raise ValueError("Unknown model '{}'".format(model))

        # Save the state
        torch.save(state, os.path.join(checkpoint_dir, file_name))

    def test(self):
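        """Dispatch to the test routine selected by self.mode."""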
        try:
            if self.mode == 'h2l':
                self.load_checkpoint(self.config.checkpoint_file_h2l, 'h2l')
                self.test_h2l()
            elif self.mode == 'l2h':
                self.load_checkpoint(self.config.checkpoint_file_l2h, 'l2h')
                self.test_l2h()
            elif self.mode == 'combined':
                self.load_checkpoint(self.config.checkpoint_file_combined,
                                     'combined')
                self.test_combined()
        except KeyboardInterrupt:
            print("You have entered CTRL+C.. Wait to finalize")

    def test_h2l(self):
        test_loader = get_loader_h2l(self.config.test_hr_datapath,
                                     self.config.batch_size)
        for curr_it, data_dict in enumerate(test_loader):
            data_high = data_dict['hr']
            img_name = data_dict['imgpath'][0]
            img_name = os.path.basename(img_name)
            data_input_high, _ = self.to_var(data_high)
            noise = torch.randn(data_high.size(0), 1)
            noise, _ = self.to_var(noise)
            gen_lr = self.h2l_G(data_input_high, noise)
            path = os.path.join(self.test_file,
                                img_name.split('.')[0] + '_h2l.jpg')
            vutils.save_image(gen_lr.data, path, normalize=True)

    def test_l2h(self):
        test_loader = get_loader_l2h(self.config.test_lr_datapath,
                                     self.config.batch_size)
        for curr_it, data_dict in enumerate(test_loader):
            data_low = data_dict['lr']
            img_name = data_dict['imgpath'][0]
            img_name = os.path.basename(img_name)
            data_input_low, _ = self.to_var(data_low)
            gen_hr = self.l2h_G(data_input_low)
            path = os.path.join(self.test_file,
                                img_name.split('.')[0] + '_l2h.jpg')
            vutils.save_image(gen_hr.data, path, normalize=True)

    def test_combined(self):
        test_loader = get_loader_combined(self.config.test_hr_datapath,
                                          self.config.batch_size)
        for curr_it, data_dict in enumerate(test_loader):
            data_high = data_dict['hr']
            data_low = data_dict['lr']
            img_name = data_dict['imgpath'][0]
            img_name = os.path.basename(img_name)
            data_input_high, _ = self.to_var(data_high)
            data_input_low, _ = self.to_var(data_low)
            noise = torch.randn(data_high.size(0), 1)
            noise, _ = self.to_var(noise)
            # NOTE: We discovered a mistake at the very last moment in the
            # high-to-low generator's pixel loss: it should compare the
            # generated LR image against a downsampled version of the original
            # image, but we compared it against the actual LR image instead.
            # As a result we do not have usable weights for the high-to-low
            # model. As a workaround (which is not correct, but produces
            # outputs), we skip the high-to-low generator and feed the real LR
            # image directly to the low-to-high generator.
            #gen_int_lr = self.h2l_G(data_input_high, noise)
            gen_hr = self.l2h_G(data_input_low)
            path_int = os.path.join(self.test_file,
                                    img_name.split('.')[0] + '_int_lr.jpg')
            path_final = os.path.join(self.test_file,
                                      img_name.split('.')[0] + '_final_hr.jpg')
            vutils.save_image(data_input_low.data, path_int, normalize=True)