コード例 #1
0
ファイル: train.py プロジェクト: sunleilei1216/no_projections
        iter_start_time = time.time()
        epoch_iter += args.batch_size

        model.set_input(
            data['image'], data['segmentation_outlined'], data['image_name']
        )
        model.optimize_parameters()

        if epoch_iter % args.print_freq == 0:
            errors = model.get_current_errors()
            t = (time.time() - iter_start_time) / args.batch_size
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)

        if epoch_iter % args.display_freq == 0:
            visuals = model.get_current_visuals()
            visualizer.display_current_results(visuals, epoch)

        if epoch_iter % args.save_freq == 0:
            model.save("latest")
            print(
                "Saving current model to {}".format(checkpoints_dir)
            )

    if epoch - 1 != 0 and epoch % args.lr_decay_freq == 0:
        model.update_learning_rate(args.lr_decay)

    print('End of epoch %d / %d \t Time Taken: %d sec' %
          (epoch, args.num_epochs, time.time() - epoch_start_time))
    model.save(epoch)
    print(
        "Saving current model to {}".format(checkpoints_dir)
コード例 #2
0
    def train(self):
        """Run the full StarGAN-style training loop over ``self.train_loader``.

        Per iteration this performs one critic (``self.D``) update on the
        adversarial + label-classification (+ optional identity
        classification) losses, then a second critic update on the WGAN-GP
        gradient penalty alone. Every ``config.d_train_repeat`` iterations it
        also does one generator (``self.G``) update on the weighted sum of
        the enabled generator losses.

        The loop also handles periodic console/TensorBoard logging,
        sample-image dumps, checkpointing, optional visualizer output, and
        linear learning-rate decay over the last ``config.num_epochs_decay``
        epochs.

        Returns nothing. All results are side effects: updated network
        weights, and files written under ``config.sample_path`` and
        ``config.model_save_path``.

        NOTE(review): this uses the pre-0.4 PyTorch API (``Variable``,
        ``volatile=True``, ``tensor.data[0]``, ``size_average=``) and
        hard-codes ``.cuda()``. It likely will not run unchanged on a modern
        PyTorch release or on a CPU-only machine.
        """

        # if self.config.visualize:
        # NOTE(review): the guard above is commented out, so a Visualizer is
        # always constructed, even when config.visualize is off.
        visualizer = Visualizer()
        # NOTE(review): the string below follows a statement, so it is a
        # no-op expression rather than this method's docstring.
        """Train StarGAN within a single dataset."""

        # Set dataloader
        self.data_loader = self.train_loader

        # The number of iterations per epoch
        iters_per_epoch = len(self.data_loader)

        # Grab only the first batch (the loop breaks at i == 0) to serve as a
        # fixed debugging batch. ``imgs`` is indexed, so each batch item
        # appears to be a sequence of image tensors; only ``imgs[0]`` is kept.
        fixed_x = []
        real_c = []
        for i, (imgs, labels, _) in enumerate(self.data_loader):
            fixed_x.append(imgs[0])
            real_c.append(labels)
            if i == 0:
                break

        # Fixed inputs and target domain labels for debugging
        fixed_x = torch.cat(fixed_x, dim=0)
        # presumably wraps the tensor in a Variable; ``volatile=True`` is the
        # pre-0.4 PyTorch way of disabling autograd for inference-only data.
        fixed_x = self.to_var(fixed_x, volatile=True)
        # NOTE(review): ``real_c`` is reassigned inside the training loop
        # before it is ever read, so this value appears unused.
        real_c = torch.cat(real_c, dim=0)

        # Target labels used when dumping sample translations; the sampling
        # code below runs G once per entry.
        fixed_c_list = self.make_celeb_labels(self.config.batch_size)

        # lr cache for decaying
        # These are the current learning rates, lowered at the end of each
        # epoch once the decay phase starts.
        g_lr = self.config.g_lr
        d_lr = self.config.d_lr

        # Start with trained model if exists
        # The epoch is parsed from the prefix before the first '_' of
        # ``pretrained_model``. This matches the '{epoch}_{iter}_G.pth'
        # naming used when checkpoints are saved below. Subtracting 1 turns
        # the 1-based saved epoch into the 0-based loop index, so that epoch
        # is re-run from its start. No weights are loaded in this method.
        if self.config.pretrained_model:
            start = int(self.config.pretrained_model.split('_')[0]) - 1
        else:
            start = 0

        # Start training
        # ``self.loss`` maps a tag to its latest scalar loss value for
        # logging. Keys persist across iterations, so an entry keeps its
        # last value until it is overwritten.
        self.loss = {}
        start_time = time.time()

        for e in range(start, self.config.num_epochs):
            # self.test(e) is called at the start of every epoch, before any
            # training happens in that epoch.
            self.test(e)
            for i, (images, real_label,
                    identity) in enumerate(self.data_loader):

                real_x = images[0]

                # Extra images used only by the siamese loss below.
                if self.config.use_si:
                    real_ox = self.to_var(images[1])
                    real_oo = self.to_var(images[2])

                # presumably squeezes the identity targets to 1-D class
                # indices for a cross-entropy criterion; confirm against
                # self.id_cls_criterion.
                if self.config.id_cls_loss == 'cross':
                    identity = identity.squeeze()

                # Generate fake labels randomly (target domain labels)
                # by shuffling the real labels across the batch.
                rand_idx = torch.randperm(real_label.size(0))

                fake_label = real_label[rand_idx]

                real_c = real_label.clone()
                fake_c = fake_label.clone()

                # Convert tensor to variable
                real_x = self.to_var(real_x)
                real_c = self.to_var(real_c)  # input for the generator
                fake_c = self.to_var(fake_c)
                real_label = self.to_var(
                    real_label
                )  # this is same as real_c if dataset == 'CelebA'
                fake_label = self.to_var(fake_label)
                identity = self.to_var(identity)

                # ================== Train D ================== #

                # Compute loss with real images
                # When loss_id_cls is on, D also returns identity logits as a
                # third output.
                if self.config.loss_id_cls:
                    out_src, out_cls, out_id_real = self.D(real_x)
                else:
                    out_src, out_cls = self.D(real_x)

                # Critic term on real images (WGAN form: minimise -mean).
                d_loss_real = -torch.mean(out_src)

                # BCE-with-logits on the label vector. It is summed
                # (size_average=False) and then divided by the batch size.
                d_loss_cls = F.binary_cross_entropy_with_logits(
                    out_cls, real_label, size_average=False) / real_x.size(0)

                if self.config.loss_id_cls:
                    d_loss_id_cls = self.id_cls_criterion(
                        out_id_real, identity)
                    self.loss[
                        'D/loss_id_cls'] = self.config.lambda_id_cls * d_loss_id_cls.data[
                            0]
                else:
                    d_loss_id_cls = 0.0

                # Compute classification accuracy of the discriminator
                # NOTE(review): the "(20 classes)" text is hard-coded and not
                # derived from the config.
                if (i + 1) % self.config.log_step == 0:
                    accuracies = self.compute_accuracy(out_cls.detach(),
                                                       real_label,
                                                       self.config.dataset)
                    log = [
                        "{:.2f}".format(acc)
                        for acc in accuracies.data.cpu().numpy()
                    ]
                    print('Classification Acc (20 classes): ')
                    print(log)
                    print('\n')

                # Compute loss with fake images
                # With use_gpb, G returns (image, id_vector); only the image
                # is used here.
                if self.config.use_gpb:
                    fake_x, _ = self.G(real_x, fake_c)
                else:
                    fake_x = self.G(real_x, fake_c)
                # Re-wrap the raw data to cut the graph back into G, so the D
                # updates below do not propagate gradients into the generator.
                fake_x = Variable(fake_x.data)

                if self.config.loss_id_cls:
                    out_src, out_cls, _ = self.D(fake_x.detach())
                else:
                    out_src, out_cls = self.D(fake_x.detach())

                d_loss_fake = torch.mean(out_src)

                # Backward + Optimize
                # First D update: adversarial terms plus the weighted
                # classification terms. d_loss_id_cls is 0.0 when
                # loss_id_cls is off.
                d_loss = d_loss_real + d_loss_fake + self.config.lambda_cls * d_loss_cls + d_loss_id_cls * self.config.lambda_id_cls
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Compute gradient penalty
                # WGAN-GP: interpolate between real and fake images with a
                # per-sample random weight (alpha has shape (N, 1, 1, 1)
                # before expansion).
                # NOTE(review): ``.cuda()`` is hard-coded here and in the
                # grad_outputs below, so this requires a CUDA device.
                alpha = torch.rand(real_x.size(0), 1, 1,
                                   1).cuda().expand_as(real_x)
                interpolated = Variable(alpha * real_x.data +
                                        (1 - alpha) * fake_x.data,
                                        requires_grad=True)

                if self.config.loss_id_cls:
                    out, out_cls, _ = self.D(interpolated)
                else:
                    out, out_cls = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                # Penalise deviation of each sample's gradient L2 norm from 1.
                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize
                # Second, separate D update on the gradient penalty alone.
                d_loss = self.config.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging
                # ``.data[0]`` extracts a Python scalar in pre-0.4 PyTorch;
                # newer releases need ``.item()`` instead.
                self.loss['D/loss_real'] = d_loss_real.data[0]
                self.loss['D/loss_fake'] = d_loss_fake.data[0]
                self.loss[
                    'D/loss_cls'] = self.config.lambda_cls * d_loss_cls.data[0]
                self.loss[
                    'D/loss_gp'] = self.config.lambda_gp * d_loss_gp.data[0]

                # ================== Train G ================== #
                # G is updated once every ``d_train_repeat`` iterations.
                if (i + 1) % self.config.d_train_repeat == 0:

                    # Images handed to the visualizer for this G step.
                    self.img = {}
                    # Original-to-target and target-to-original domain
                    # rec_x is produced from fake_x.detach(), so the
                    # reconstruction loss only back-propagates through the
                    # second G call.
                    if self.config.use_gpb:
                        fake_x, id_vector_real_in_x = self.G(real_x, fake_c)
                        rec_x, id_vector_fake_in_x = self.G(
                            fake_x.detach(), real_c)
                    else:
                        fake_x = self.G(real_x, fake_c)
                        rec_x = self.G(fake_x.detach(), real_c)

                    # Compute losses
                    if self.config.loss_id_cls:
                        out_src, out_cls, out_id_fake = self.D(fake_x)
                    else:
                        out_src, out_cls = self.D(fake_x)

                    # Adversarial term (raise D's score on fakes) and L1
                    # reconstruction term.
                    g_loss_fake = -torch.mean(out_src)
                    g_loss_rec = torch.mean(torch.abs(real_x - rec_x))

                    ### siamese loss
                    # Contrastive-style form with margin 1.0:
                    #   0.5 * (d_ox**2 + max(0, 1 - d_oo)**2)
                    # Each d is a mean absolute difference. The "ox" result is
                    # pulled closer and the "oo" result is pushed beyond the
                    # margin.
                    # NOTE(review): g_loss_si is logged unweighted, unlike most
                    # other G terms, which are logged multiplied by their
                    # lambda.
                    if self.config.use_si:
                        if self.config.use_gpb:
                            # feedforward
                            # Compare id vectors from the forward translation.
                            # The ox/oo vectors are detached, so they act only
                            # as targets.
                            fake_ox, id_vector_ox = self.G(real_ox, fake_c)
                            fake_oo, id_vector_oo = self.G(real_oo, fake_c)

                            id_vector_ox = id_vector_ox.detach()
                            id_vector_oo = id_vector_oo.detach()

                            mdist = 1.0 - torch.mean(
                                torch.abs(id_vector_real_in_x - id_vector_oo))
                            mdist = torch.clamp(mdist, min=0.0)
                            g_loss_si = 0.5 * (torch.pow(
                                torch.mean(
                                    torch.abs(id_vector_real_in_x -
                                              id_vector_ox)), 2) +
                                               torch.pow(mdist, 2))

                            # backward
                            # Apply the same loss to id vectors from the
                            # reverse translation (fake images -> real_c).
                            _, id_vector_ox = self.G(fake_ox.detach(), real_c)
                            _, id_vector_oo = self.G(fake_oo.detach(), real_c)

                            id_vector_ox = id_vector_ox.detach()
                            id_vector_oo = id_vector_oo.detach()

                            mdist = 1.0 - torch.mean(
                                torch.abs(id_vector_fake_in_x - id_vector_oo))
                            mdist = torch.clamp(mdist, min=0.0)
                            g_loss_si += 0.5 * (torch.pow(
                                torch.mean(
                                    torch.abs(id_vector_fake_in_x -
                                              id_vector_ox)), 2) +
                                                torch.pow(mdist, 2))

                            self.loss['G/g_loss_si'] = g_loss_si.data[0]
                        else:
                            # Same loss, computed on generated images instead
                            # of id vectors.
                            fake_ox = self.G(real_ox, fake_c).detach()

                            # Labels for real_oo are fake_c rolled along axis 1
                            # by a random shift.
                            # NOTE(review): randint(c_dim) can return 0, which
                            # leaves the labels unchanged.
                            fake_ooc = fake_c.data.cpu().numpy().copy()
                            fake_ooc = np.roll(fake_ooc,
                                               np.random.randint(
                                                   self.config.c_dim),
                                               axis=1)
                            fake_ooc = self.to_var(torch.FloatTensor(fake_ooc))

                            fake_oo = self.G(real_oo, fake_ooc).detach()
                            mdist = 1.0 - torch.mean(
                                torch.abs(fake_x - fake_oo))
                            mdist = torch.clamp(mdist, min=0.0)

                            g_loss_si = 0.5 * (torch.pow(
                                torch.mean(torch.abs(fake_x - fake_ox)), 2) +
                                               torch.pow(mdist, 2))
                            self.loss['G/g_loss_si'] = g_loss_si.data[0]
                    else:
                        g_loss_si = 0.0

                    ### id cls loss
                    if self.config.loss_id_cls:
                        g_loss_id_cls = self.id_cls_criterion(
                            out_id_fake, identity)
                        self.loss[
                            'G/g_loss_id_cls'] = self.config.lambda_id_cls * g_loss_id_cls.data[
                                0]
                    else:
                        g_loss_id_cls = 0.0

                    ### sym loss
                    # Sum of self.find_sym_img_and_cal_loss over fake_x, rec_x
                    # and their self.take_laplacian outputs. Both helpers are
                    # defined outside this method.
                    if self.config.loss_symmetry:
                        g_loss_sym_fake = self.find_sym_img_and_cal_loss(
                            fake_x, fake_c,
                            True)  # cal. over samples w/ specific labels
                        g_loss_sym_rec = self.find_sym_img_and_cal_loss(
                            rec_x, real_c, True)

                        lap_fake_x = self.take_laplacian(fake_x)
                        lap_rec_x = self.take_laplacian(rec_x)
                        g_loss_sym_lap_fake = self.find_sym_img_and_cal_loss(
                            lap_fake_x, None, False)  # cal. over all samples
                        g_loss_sym_lap_rec = self.find_sym_img_and_cal_loss(
                            lap_rec_x, None, False)
                        sym_loss = (g_loss_sym_fake + g_loss_sym_rec +
                                    g_loss_sym_lap_fake + g_loss_sym_lap_rec)
                        self.loss[
                            'G/g_loss_sym'] = self.config.lambda_symmetry * sym_loss.data[
                                0]
                    else:
                        sym_loss = 0

                    ###id loss
                    # L1 between real_x and G(real_x, real_c): translating an
                    # image to its own labels is penalised for changing it.
                    if self.config.loss_id:
                        if self.config.use_gpb:
                            idx, _ = self.G(real_x, real_c)
                        else:
                            idx = self.G(real_x, real_c)
                        self.img['idx'] = idx

                        g_loss_id = torch.mean(torch.abs(real_x - idx))
                        self.loss[
                            'G/g_loss_id'] = self.config.lambda_idx * g_loss_id.data[
                                0]
                    else:
                        g_loss_id = 0

                    ###identity loss
                    # L1 between both outputs of self.get_feature for the real
                    # and the translated images.
                    if self.config.loss_identity:
                        real_x_f, real_x_p = self.get_feature(real_x)
                        fake_x_f, fake_x_p = self.get_feature(fake_x)
                        g_loss_identity = torch.mean(
                            torch.abs(real_x_f - fake_x_f))
                        g_loss_identity += torch.mean(
                            torch.abs(real_x_p - fake_x_p))

                        self.loss[
                            'G/g_loss_identity'] = self.config.lambda_identity * g_loss_identity.data[
                                0]
                    else:
                        g_loss_identity = 0

                    ###total var loss
                    # Mean of self.total_variation_loss over fake_x and rec_x.
                    if self.config.loss_tv:
                        g_tv_loss = (self.total_variation_loss(fake_x) +
                                     self.total_variation_loss(rec_x)) / 2
                        self.loss[
                            'G/tv_loss'] = self.config.lambda_tv * g_tv_loss.data[
                                0]
                    else:
                        g_tv_loss = 0

                    ### D's cls loss
                    # D's label prediction on the translated image is scored
                    # against the target labels.
                    g_loss_cls = F.binary_cross_entropy_with_logits(
                        out_cls, fake_label,
                        size_average=False) / fake_x.size(0)

                    # Backward + Optimize
                    # Optional terms that are disabled were set to 0 / 0.0
                    # above, so they drop out of this weighted sum.
                    g_loss = g_loss_fake +\
                             self.config.lambda_rec * g_loss_rec +\
                             self.config.lambda_cls * g_loss_cls+\
                             self.config.lambda_idx * g_loss_id+\
                             self.config.lambda_identity*g_loss_identity+\
                             self.config.lambda_tv*g_tv_loss+\
                             self.config.lambda_symmetry*sym_loss+\
                             self.config.lambda_id_cls * g_loss_id_cls+\
                             self.config.lambda_si * g_loss_si

                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging
                    self.img['real_x'] = real_x
                    self.img['fake_x'] = fake_x
                    self.img['rec_x'] = rec_x
                    self.loss['G/loss_fake'] = g_loss_fake.data[0]
                    self.loss[
                        'G/loss_rec'] = self.config.lambda_rec * g_loss_rec.data[
                            0]
                    self.loss[
                        'G/loss_cls'] = self.config.lambda_cls * g_loss_cls.data[
                            0]
                    #

                # Print out log info
                # Prints every entry in self.loss to the console. It is
                # optionally mirrored to TensorBoard with global step
                # e * iters_per_epoch + i + 1.
                if (i + 1) % self.config.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e + 1, self.config.num_epochs, i + 1,
                        iters_per_epoch)

                    for tag, value in self.loss.items():
                        log += ", {}: {}".format(tag, value)
                    print(log)

                    if self.config.use_tensorboard:
                        for tag, value in self.loss.items():
                            self.logger.scalar_summary(
                                tag, value, e * iters_per_epoch + i + 1)

                # Translate fixed images for debugging
                # NOTE(review): this tests ``i`` rather than ``i + 1``, unlike
                # the other schedules. It therefore also fires on the first
                # iteration of every epoch, while file names still use i + 1.
                # It runs one G pass per fixed target label and concatenates
                # the results along dim 3 (the width axis if tensors are NCHW;
                # confirm).
                if (i) % self.config.sample_step == 0:
                    fake_image_list = [fixed_x]
                    for fixed_c in fixed_c_list:
                        if self.config.use_gpb:
                            fake_image_list.append(self.G(fixed_x, fixed_c)[0])
                        else:
                            fake_image_list.append(self.G(fixed_x, fixed_c))
                    fake_images = torch.cat(fake_image_list, dim=3)

                    if not self.config.log_space:
                        save_image(self.denorm(fake_images.data),
                                   os.path.join(
                                       self.config.sample_path,
                                       '{}_{}_fake.png'.format(e + 1, i + 1)),
                                   nrow=1,
                                   padding=0)
                    else:
                        # Undo a log-space intensity encoding. With
                        # v = denormed value, e**(v * ln 256) - 1 == 256**v - 1,
                        # which maps v in [0, 1] back to [0, 255]. The result
                        # is then rescaled to [0, 1] and clamped.
                        fake_images = self.denorm(fake_images.data) * 255.0
                        fake_images = torch.pow(
                            2.71828182846,
                            fake_images / 255.0 * np.log(256.0)) - 1.0
                        fake_images = fake_images / 255.0
                        fake_images = fake_images.clamp(0.0, 1.0)
                        save_image(fake_images,
                                   os.path.join(
                                       self.config.sample_path,
                                       '{}_{}_fake.png'.format(e + 1, i + 1)),
                                   nrow=1,
                                   padding=0)

                    print('Translated images and saved into {}..!'.format(
                        self.config.sample_path))

                # Save model checkpoints
                # Saved as '{epoch}_{iter}_G.pth' / '{epoch}_{iter}_D.pth',
                # with both numbers 1-based.
                if (i + 1) % self.config.model_save_step == 0:
                    torch.save(
                        self.G.state_dict(),
                        os.path.join(self.config.model_save_path,
                                     '{}_{}_G.pth'.format(e + 1, i + 1)))
                    torch.save(
                        self.D.state_dict(),
                        os.path.join(self.config.model_save_path,
                                     '{}_{}_D.pth'.format(e + 1, i + 1)))
                # NOTE(review): self.img is only rebuilt on G-update
                # iterations. If display_f is not a multiple of
                # d_train_repeat, this shows images from an earlier G step.
                # It may also raise AttributeError if no G step has run yet,
                # unless self.img is initialised elsewhere.
                if self.config.visualize and (i +
                                              1) % self.config.display_f == 0:
                    visualizer.display_current_results(self.img)
                    visualizer.plot_current_errors(
                        e,
                        float(i + 1) / iters_per_epoch, self.loss)

            # Decay learning rate
            # Linear decay: in each of the last num_epochs_decay epochs, the
            # initial lr / num_epochs_decay is subtracted from the current lr.
            if (e + 1) > (self.config.num_epochs -
                          self.config.num_epochs_decay):
                g_lr -= (self.config.g_lr /
                         float(self.config.num_epochs_decay))
                d_lr -= (self.config.d_lr /
                         float(self.config.num_epochs_decay))
                self.update_lr(g_lr, d_lr)
                print('Decay learning rate to g_lr: {}, d_lr: {}.'.format(
                    g_lr, d_lr))
コード例 #3
0
                ('content_loss0_edge',
                 content_loss0_edge.detach().cpu().float().numpy()),
                ('content_loss1_edge',
                 content_loss1_edge.detach().cpu().float().numpy()),
                ('L_img_edge', L_edge.detach().cpu().float().numpy()),
                ('my_psnr', my_psnr / (ccnt))
            ])
            t = (time.time() - iter_start_time) / opt.batchSize
            trainLogger.write(
                visualizer.print_current_losses(epoch, epoch_iter, losses, t,
                                                t_data) + '\n')
            visualizer.print_current_losses(epoch, epoch_iter, losses, t,
                                            t_data)
            r = float(epoch_iter) / (dataset_size * opt.batchSize)
            if opt.display_port != -1:
                visualizer.display_current_results(current_visuals, epoch,
                                                   False)
                visualizer.plot_current_losses(epoch, r, opt, losses)

            netG.train()
            netEdge.train()

    # save model and test performance in validation set
    if epoch % 1 == 0:

        print('hit')
        my_file = open("./" + opt.name + "_" + "evaluation.txt", 'a+')
        torch.save(netG.state_dict(),
                   '%s/netG_epoch_%d.pth' % (opt.exp, epoch))
        torch.save(netEdge.state_dict(),
                   '%s/netEdge_epoch_%d.pth' % (opt.exp, epoch))
        vcnt = 0
コード例 #4
0
        visualizer.reset()              # reset the visualizer: make sure it saves the results to HTML at least once every epoch
        model.update_learning_rate()    # update learning rates in the beginning of every epoch.
        for i, data in enumerate(dataset):  # inner loop within one epoch
            iter_start_time = time.time()  # timer for computation per iteration
            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time

            total_iters += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(data)         # unpack data from dataset and apply preprocessing
            model.optimize_parameters()   # calculate loss functions, get gradients, update network weights

            if total_iters % opt.display_freq == 0:   # display images on visdom and save images to a HTML file
                save_result = total_iters % opt.update_html_freq == 0
                model.compute_visuals()
                visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

            if total_iters % opt.print_freq == 0:    # print training losses and save logging information to the disk
                losses = model.get_current_losses()
                t_comp = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)

            if total_iters % opt.save_latest_freq == 0:   # cache our latest model every <save_latest_freq> iterations
                print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
                save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
                model.save_networks(save_suffix)

            iter_data_time = time.time()
        if epoch % opt.save_epoch_freq == 0:              # cache our model every <save_epoch_freq> epochs
コード例 #5
0
    def train(self):
        self.genA2B.train(), self.genB2A.train(), self.disGA.train(
        ), self.disGB.train(), self.disLA.train(), self.disLB.train()

        start_iter = 1
        if self.resume:
            model_list = glob(
                os.path.join(self.result_dir, self.dataset, 'model', '*.pt'))
            if not len(model_list) == 0:
                model_list.sort()
                start_iter = int(model_list[-1].split('_')[-1].split('.')[0])
                self.load(os.path.join(self.result_dir, self.dataset, 'model'),
                          start_iter)
                print(" [*] Load SUCCESS")
                if self.decay_flag and start_iter > (self.iteration // 2):
                    self.G_optim.param_groups[0]['lr'] -= (
                        self.lr /
                        (self.iteration // 2)) * (start_iter -
                                                  self.iteration // 2)
                    self.D_optim.param_groups[0]['lr'] -= (
                        self.lr /
                        (self.iteration // 2)) * (start_iter -
                                                  self.iteration // 2)

        # training loop
        print('training start !')

        # visualize training process...
        visualizer = Visualizer(
            self.opt)  # create a visualizer that display/save images and plots

        start_time = time.time()
        for step in range(start_iter, self.iteration + 1):
            if self.decay_flag and step > (self.iteration // 2):
                self.G_optim.param_groups[0]['lr'] -= (self.lr /
                                                       (self.iteration // 2))
                self.D_optim.param_groups[0]['lr'] -= (self.lr /
                                                       (self.iteration // 2))

            try:
                self.real_A, _ = trainA_iter.next()
            except:
                trainA_iter = iter(self.trainA_loader)
                self.real_A, _ = trainA_iter.next()

            try:
                self.real_B, _ = trainB_iter.next()
            except:
                trainB_iter = iter(self.trainB_loader)
                self.real_B, _ = trainB_iter.next()

            self.real_A, self.real_B = self.real_A.to(
                self.device), self.real_B.to(self.device)

            # Update D
            self.D_optim.zero_grad()

            self.fake_A2B, _, _ = self.genA2B(self.real_A)
            self.fake_B2A, _, _ = self.genB2A(self.real_B)

            real_GA_logit, real_GA_cam_logit, _ = self.disGA(self.real_A)
            real_LA_logit, real_LA_cam_logit, _ = self.disLA(self.real_A)
            real_GB_logit, real_GB_cam_logit, _ = self.disGB(self.real_B)
            real_LB_logit, real_LB_cam_logit, _ = self.disLB(self.real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(self.fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(self.fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(self.fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(self.fake_A2B)

            D_ad_loss_GA = self.MSE_loss(
                real_GA_logit,
                torch.ones_like(real_GA_logit).to(
                    self.device)) + self.MSE_loss(
                        fake_GA_logit,
                        torch.zeros_like(fake_GA_logit).to(self.device))
            D_ad_cam_loss_GA = self.MSE_loss(
                real_GA_cam_logit,
                torch.ones_like(real_GA_cam_logit).to(
                    self.device)) + self.MSE_loss(
                        fake_GA_cam_logit,
                        torch.zeros_like(fake_GA_cam_logit).to(self.device))
            D_ad_loss_LA = self.MSE_loss(
                real_LA_logit,
                torch.ones_like(real_LA_logit).to(
                    self.device)) + self.MSE_loss(
                        fake_LA_logit,
                        torch.zeros_like(fake_LA_logit).to(self.device))
            D_ad_cam_loss_LA = self.MSE_loss(
                real_LA_cam_logit,
                torch.ones_like(real_LA_cam_logit).to(
                    self.device)) + self.MSE_loss(
                        fake_LA_cam_logit,
                        torch.zeros_like(fake_LA_cam_logit).to(self.device))
            D_ad_loss_GB = self.MSE_loss(
                real_GB_logit,
                torch.ones_like(real_GB_logit).to(
                    self.device)) + self.MSE_loss(
                        fake_GB_logit,
                        torch.zeros_like(fake_GB_logit).to(self.device))
            D_ad_cam_loss_GB = self.MSE_loss(
                real_GB_cam_logit,
                torch.ones_like(real_GB_cam_logit).to(
                    self.device)) + self.MSE_loss(
                        fake_GB_cam_logit,
                        torch.zeros_like(fake_GB_cam_logit).to(self.device))
            D_ad_loss_LB = self.MSE_loss(
                real_LB_logit,
                torch.ones_like(real_LB_logit).to(
                    self.device)) + self.MSE_loss(
                        fake_LB_logit,
                        torch.zeros_like(fake_LB_logit).to(self.device))
            D_ad_cam_loss_LB = self.MSE_loss(
                real_LB_cam_logit,
                torch.ones_like(real_LB_cam_logit).to(
                    self.device)) + self.MSE_loss(
                        fake_LB_cam_logit,
                        torch.zeros_like(fake_LB_cam_logit).to(self.device))

            D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA +
                                          D_ad_loss_LA + D_ad_cam_loss_LA)
            D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB +
                                          D_ad_loss_LB + D_ad_cam_loss_LB)

            Discriminator_loss = D_loss_A + D_loss_B
            Discriminator_loss.backward()
            self.D_optim.step()

            # Update G
            self.G_optim.zero_grad()

            self.fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(self.real_A)
            self.fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(self.real_B)

            self.fake_A2B2A, _, _ = self.genB2A(self.fake_A2B)
            self.fake_B2A2B, _, _ = self.genA2B(self.fake_B2A)

            self.fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(self.real_A)
            self.fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(self.real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(self.fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(self.fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(self.fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(self.fake_A2B)

            G_ad_loss_GA = self.MSE_loss(
                fake_GA_logit,
                torch.ones_like(fake_GA_logit).to(self.device))
            G_ad_cam_loss_GA = self.MSE_loss(
                fake_GA_cam_logit,
                torch.ones_like(fake_GA_cam_logit).to(self.device))
            G_ad_loss_LA = self.MSE_loss(
                fake_LA_logit,
                torch.ones_like(fake_LA_logit).to(self.device))
            G_ad_cam_loss_LA = self.MSE_loss(
                fake_LA_cam_logit,
                torch.ones_like(fake_LA_cam_logit).to(self.device))
            G_ad_loss_GB = self.MSE_loss(
                fake_GB_logit,
                torch.ones_like(fake_GB_logit).to(self.device))
            G_ad_cam_loss_GB = self.MSE_loss(
                fake_GB_cam_logit,
                torch.ones_like(fake_GB_cam_logit).to(self.device))
            G_ad_loss_LB = self.MSE_loss(
                fake_LB_logit,
                torch.ones_like(fake_LB_logit).to(self.device))
            G_ad_cam_loss_LB = self.MSE_loss(
                fake_LB_cam_logit,
                torch.ones_like(fake_LB_cam_logit).to(self.device))

            G_recon_loss_A = self.L1_loss(self.fake_A2B2A, self.real_A)
            G_recon_loss_B = self.L1_loss(self.fake_B2A2B, self.real_B)

            G_identity_loss_A = self.L1_loss(self.fake_A2A, self.real_A)
            G_identity_loss_B = self.L1_loss(self.fake_B2B, self.real_B)

            G_cam_loss_A = self.BCE_loss(
                fake_B2A_cam_logit,
                torch.ones_like(fake_B2A_cam_logit).to(
                    self.device)) + self.BCE_loss(
                        fake_A2A_cam_logit,
                        torch.zeros_like(fake_A2A_cam_logit).to(self.device))
            G_cam_loss_B = self.BCE_loss(
                fake_A2B_cam_logit,
                torch.ones_like(fake_A2B_cam_logit).to(
                    self.device)) + self.BCE_loss(
                        fake_B2B_cam_logit,
                        torch.zeros_like(fake_B2B_cam_logit).to(self.device))

            G_loss_A = self.adv_weight * (
                G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA +
                G_ad_cam_loss_LA
            ) + self.cycle_weight * G_recon_loss_A + self.identity_weight * G_identity_loss_A + self.cam_weight * G_cam_loss_A
            G_loss_B = self.adv_weight * (
                G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB +
                G_ad_cam_loss_LB
            ) + self.cycle_weight * G_recon_loss_B + self.identity_weight * G_identity_loss_B + self.cam_weight * G_cam_loss_B

            Generator_loss = G_loss_A + G_loss_B
            Generator_loss.backward()
            self.G_optim.step()

            # clip parameter of AdaILN and ILN, applied after optimizer step
            self.genA2B.apply(self.Rho_clipper)
            self.genB2A.apply(self.Rho_clipper)

            print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" %
                  (step, self.iteration, time.time() - start_time,
                   Discriminator_loss, Generator_loss))
            if step % self.print_freq == 0:
                train_sample_num = 5
                test_sample_num = 5
                A2B = np.zeros((self.img_size * 7, 0, 3))
                B2A = np.zeros((self.img_size * 7, 0, 3))

                self.genA2B.eval(), self.genB2A.eval(), self.disGA.eval(
                ), self.disGB.eval(), self.disLA.eval(), self.disLB.eval()
                for _ in range(train_sample_num):
                    try:
                        self.real_A, _ = trainA_iter.next()
                    except:
                        trainA_iter = iter(self.trainA_loader)
                        self.real_A, _ = trainA_iter.next()

                    try:
                        self.real_B, _ = trainB_iter.next()
                    except:
                        trainB_iter = iter(self.trainB_loader)
                        self.real_B, _ = trainB_iter.next()
                    self.real_A, self.real_B = self.real_A.to(
                        self.device), self.real_B.to(self.device)

                    self.fake_A2B, _, fake_A2B_heatmap = self.genA2B(
                        self.real_A)
                    self.fake_B2A, _, fake_B2A_heatmap = self.genB2A(
                        self.real_B)

                    self.fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(
                        self.fake_A2B)
                    self.fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(
                        self.fake_B2A)

                    self.fake_A2A, _, fake_A2A_heatmap = self.genB2A(
                        self.real_A)
                    self.fake_B2B, _, fake_B2B_heatmap = self.genA2B(
                        self.real_B)

                    A2B = np.concatenate(
                        (A2B,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(self.real_A[0]))),
                              cam(tensor2numpy(fake_A2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(self.fake_A2A[0]))),
                              cam(tensor2numpy(fake_A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(self.fake_A2B[0]))),
                              cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(
                                  self.fake_A2B2A[0])))), 0)), 1)

                    B2A = np.concatenate(
                        (B2A,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(self.real_B[0]))),
                              cam(tensor2numpy(fake_B2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(self.fake_B2B[0]))),
                              cam(tensor2numpy(fake_B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(self.fake_B2A[0]))),
                              cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(
                                  self.fake_B2A2B[0])))), 0)), 1)

                for _ in range(test_sample_num):
                    try:
                        self.real_A, _ = testA_iter.next()
                    except:
                        testA_iter = iter(self.testA_loader)
                        self.real_A, _ = testA_iter.next()

                    try:
                        self.real_B, _ = testB_iter.next()
                    except:
                        testB_iter = iter(self.testB_loader)
                        self.real_B, _ = testB_iter.next()
                    self.real_A, self.real_B = self.real_A.to(
                        self.device), self.real_B.to(self.device)

                    self.fake_A2B, _, fake_A2B_heatmap = self.genA2B(
                        self.real_A)
                    self.fake_B2A, _, fake_B2A_heatmap = self.genB2A(
                        self.real_B)

                    self.fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(
                        self.fake_A2B)
                    self.fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(
                        self.fake_B2A)

                    self.fake_A2A, _, fake_A2A_heatmap = self.genB2A(
                        self.real_A)
                    self.fake_B2B, _, fake_B2B_heatmap = self.genA2B(
                        self.real_B)

                    A2B = np.concatenate(
                        (A2B,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(self.real_A[0]))),
                              cam(tensor2numpy(fake_A2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(self.fake_A2A[0]))),
                              cam(tensor2numpy(fake_A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(self.fake_A2B[0]))),
                              cam(tensor2numpy(fake_A2B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(
                                  self.fake_A2B2A[0])))), 0)), 1)

                    B2A = np.concatenate(
                        (B2A,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(self.real_B[0]))),
                              cam(tensor2numpy(fake_B2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(self.fake_B2B[0]))),
                              cam(tensor2numpy(fake_B2A_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(self.fake_B2A[0]))),
                              cam(tensor2numpy(fake_B2A2B_heatmap[0]),
                                  self.img_size),
                              RGB2BGR(tensor2numpy(denorm(
                                  self.fake_B2A2B[0])))), 0)), 1)

                cv2.imwrite(
                    os.path.join(self.result_dir, self.dataset, 'img',
                                 'A2B_%07d.png' % step), A2B * 255.0)
                cv2.imwrite(
                    os.path.join(self.result_dir, self.dataset, 'img',
                                 'B2A_%07d.png' % step), B2A * 255.0)
                self.genA2B.train(), self.genB2A.train(), self.disGA.train(
                ), self.disGB.train(), self.disLA.train(), self.disLB.train()

            if step % self.opt.display_freq == 0:
                # self.compute_visuals() # in pytorch_cyclegan,only used in colorization_model
                visualizer.display_current_results(self.get_current_visuals(),
                                                   step, True)

            if step % self.save_freq == 0:
                self.save(os.path.join(self.result_dir, self.dataset, 'model'),
                          step)

            if step % 1000 == 0:
                params = {}
                params['genA2B'] = self.genA2B.state_dict()
                params['genB2A'] = self.genB2A.state_dict()
                params['disGA'] = self.disGA.state_dict()
                params['disGB'] = self.disGB.state_dict()
                params['disLA'] = self.disLA.state_dict()
                params['disLB'] = self.disLB.state_dict()
                torch.save(
                    params,
                    os.path.join(self.result_dir,
                                 self.dataset + '_params_latest.pt'))

                visualizer.reset(
                )  # reset the visualizer: make sure it saves the results to HTML at least once every epoch
コード例 #6
0
ファイル: main.py プロジェクト: FlyingCarrot/examples
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        if i % 100 == 0:
            visuals = OrderedDict()
            visuals["real"] = real_cpu.detach()
            visuals["fake"] = fake.detach()
            losses = OrderedDict()
            losses["D_loss"] = errD.detach()
            losses["G_loss"] = errG.detach()
            losses["D"] = D_x
            losses["D_G_z1"] = D_G_z1
            losses["D_G_z2"] = D_G_z2
            vis.display_current_results(visuals, epoch, True)
            vis.plot_current_losses(epoch, i / len(dataloader), losses)
            print(
                '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                % (epoch, opt.niter, i, len(dataloader), errD.item(),
                   errG.item(), D_x, D_G_z1, D_G_z2))

    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))
コード例 #7
0
def test_sample(folder_name, frame_no, list_of_boxes):
    """Run point-cloud completion on every annotated box of one LiDAR frame.

    The scan is read as float32 records of 4 values per point, of which the
    first three (x, y, z) are used. For each box:

    1. Crop the scan points that lie inside the box footprint, enlarged by 30%.
    2. Move the crop into a canonical frame: centre it on the box, undo the
       box yaw, then rotate a further quarter turn about z.
    3. Remove ground points and over-height points.
    4. Shift and scale the crop, then pass it to ``main(args, input_data=...)``.
    5. Map the predicted completion back into the original sensor frame.

    Every intermediate cloud is also sent to the global ``vis`` visualizer,
    which is created on first use.

    Args:
        folder_name: name of the sequence folder under
            ``<ROOT_DIR>/../test_dataset``.
        frame_no: frame index (anything ``int()`` accepts). It is formatted
            as ``{:010d}`` for KITTI and ``{:03d}`` for Argoverse file names.
        list_of_boxes: sequence of box dicts with keys ``'width'``,
            ``'length'``, ``'height'``, ``'angle'`` and ``'center'`` (a dict
            with ``'x'`` and ``'y'``).

    Returns:
        list of numpy arrays, one completed point cloud per box in input
        order, expressed in the original LiDAR frame.

    Raises:
        ValueError: if neither the ``Kitti`` nor the ``Argo`` flag is set.
            Previously this surfaced as an UnboundLocalError on ``data_dir``.
    """
    global vis

    if not vis:
        from visualizer import Visualizer
        settings = Object()
        settings.display_winsize = 8
        settings.display_id = 1
        settings.print_freq = 200
        settings.port_id = 8097
        settings.name = 'pcn'
        vis = Visualizer(settings)

    args = get_parameters()
    use_cuda = torch.cuda.is_available()
    # Device used for the chamfer loss.
    args.device = torch.device("cuda:%d" % args.gpu_id if use_cuda else "cpu")
    args.checkpoints_dir = args.save_path
    if use_cuda:
        # torch.cuda.set_device() fails on machines without a usable GPU, so
        # only pin the device when CUDA is actually available.
        torch.cuda.set_device(args.gpu_id)

    frame_idx = int(frame_no)
    if Kitti:
        data_dir = os.path.join(
            ROOT_DIR, "..", "test_dataset/{}/bin_data/{:010d}.bin".format(
                folder_name, frame_idx))
    elif Argo:
        data_dir = os.path.join(
            ROOT_DIR, "..",
            "test_dataset/{}/sample/argoverse/lidar/{:03d}.bin".format(
                folder_name, frame_idx))
    else:
        raise ValueError(
            "test_sample: neither the Kitti nor the Argo dataset flag is set")
    print("data_dir ", data_dir)
    print(args)

    total_time = 0
    total_points = 0
    plot_id = 0

    def rot_z(theta):
        """Return the 3x3 matrix that rotates by *theta* radians about z."""
        c, s = np.cos(theta), np.sin(theta)
        return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1.0]])

    # The canonical frame adds a quarter turn about z ("rotate from the y
    # axis to the x axis"). The previous 90 * 3.14 / 180 fell ~0.05 degrees
    # short of a right angle.
    quarter_turn = np.pi / 2

    # Empirical normalisation constants, presumably tuned to the scale of
    # the data the completion network was trained on.
    # NOTE(review): the output transform (divide by 0.295, shift z by -1.8)
    # is not the exact inverse of the input transform (multiply by 0.38,
    # shift z by +2.1). Confirm that the asymmetry is intentional.
    input_z_shift = 1.6 + 0.5
    input_scale = 0.38
    output_scale = 0.295
    output_z_shift = 1.3 + 0.5

    points = np.fromfile(data_dir, dtype=np.float32).reshape((-1, 4))[:, :3]
    print("Points look like: ")
    print(points[:10, :])

    concat_complete = []
    for cur_box in list_of_boxes:
        # Box corners, with width and length enlarged by 30% so that points
        # on the box edges are not cut off.
        bbox = np.array(
            list(
                getCorners(0,
                           cur_box['width'] + 0.3 * cur_box['width'],
                           cur_box['length'] + 0.3 * cur_box['length'],
                           cur_box['center']['x'],
                           cur_box['center']['y'],
                           cur_box['height'] / 2,
                           -cur_box['angle'],
                           rotation=True)))

        # Crop: keep the scan points that fall inside the enlarged box.
        partial_xyz = np.array(
            [point for point in points if within_bbox(point, bbox)])
        total_points += partial_xyz.shape[0]

        # Scope 0: canonical frame. Centre on the box, undo its yaw (taken
        # from the corner-0 -> corner-3 edge), then add the quarter turn.
        center = (bbox.min(0) + bbox.max(0)) / 2
        bbox -= center
        yaw = np.arctan2(bbox[3, 1] - bbox[0, 1], bbox[3, 0] - bbox[0, 0])
        rotation = rot_z(yaw)
        partial_xyz_center = np.dot(partial_xyz - center, rotation)
        partial_xyz_center = np.dot(partial_xyz_center, rot_z(quarter_turn))

        # Scope 1: ground removal. Drop points within 0.2 of the lowest z.
        partial_xyz_center_gr = partial_xyz_center[
            partial_xyz_center[:, 2] > partial_xyz_center[:, 2].min() + 0.2]

        # Scope 2: height filter. Keep points less than 1.6 above the lowest
        # remaining point.
        base_pt = np.min(partial_xyz_center_gr[:, 2])
        partial_xyz_center_ht = partial_xyz_center_gr[
            partial_xyz_center_gr[:, 2] < base_pt + 1.6]

        # Scope 3: shift up in z (in place, so the Scope 2 plot shows the
        # shifted cloud), then scale by the object's vertical extent.
        obj_height = (np.max(partial_xyz_center_ht[:, 2]) -
                      np.min(partial_xyz_center_ht[:, 2]))
        partial_xyz_center_ht[:, 2] += input_z_shift
        partial_xyz_center_gr_ht_scaled = (
            partial_xyz_center_ht / obj_height) * input_scale

        # Run the completion network.
        start = time.time()
        completion_xyz_center_out = main(
            args, input_data=partial_xyz_center_gr_ht_scaled)
        if completion_xyz_center_out.shape[0] == 3:
            # The output appears to be channel-first in some cases;
            # transpose it to one point per row.
            completion_xyz_center_out = completion_xyz_center_out.transpose()
        total_time += time.time() - start

        # Map the prediction back. Undo the scaling and z shift, undo the
        # quarter turn, then undo the yaw and re-centre on the box.
        completion_xyz_center = (
            completion_xyz_center_out / output_scale) * obj_height
        completion_xyz_center[:, 2] -= output_z_shift
        completion_xyz_center = np.dot(completion_xyz_center,
                                       rot_z(-quarter_turn))
        completion_xyz = np.dot(completion_xyz_center, rotation.T) + center

        if vis:
            print("plotting")
            both_scaled = np.concatenate(
                (partial_xyz_center_gr_ht_scaled, completion_xyz_center_out))
            both_final = np.concatenate((partial_xyz_center, completion_xyz))
            debug_views = [
                # Scope 0: check the x/y axes and that the crop sits at (0, 0).
                ('Partial_pc partial_xyz_center ', partial_xyz_center),
                # Scope 1: ground-removed points.
                ('Partial_pc partial_xyz_center_gr ', partial_xyz_center_gr),
                # Scope 2: height-filtered (and z-shifted) points.
                ('Partial_pc partial_xyz_center_ht ', partial_xyz_center_ht),
                # Scope 3: scaled network input.
                ('Partial_pc partial_xyz_center_gr_ht_scaled ',
                 partial_xyz_center_gr_ht_scaled),
                # Raw network output.
                ('Complete Predicted_pc ', completion_xyz_center_out),
                # Output after undoing the scaling.
                ('Complete Predicted_pc ', completion_xyz_center),
                # Output moved back to the original position.
                ('Complete Predicted_pc ', completion_xyz),
                ('Both pcs scaled', both_scaled),
                ('Both pcs original', both_final),
            ]
            for label, cloud in debug_views:
                vis.display_current_results(OrderedDict([(label, cloud)]), 0,
                                            plot_id)
                plot_id += 1

        concat_complete.append(completion_xyz)

    n_boxes = len(list_of_boxes)
    if n_boxes:
        # Guard against ZeroDivisionError when a frame has no boxes.
        print('Average # input points:', total_points / n_boxes)
        print('Average time:', total_time / n_boxes)
    return concat_complete
コード例 #8
0
ファイル: main.py プロジェクト: luohongming/SR_baseline
def _parse_gpu_ids(gpu_ids):
    """Parse a comma-separated GPU id string such as ``'0,1'`` into ints.

    Negative ids (e.g. ``'-1'``) are dropped. Empty tokens are also skipped,
    so an empty string or a trailing comma no longer crashes ``int()``.
    """
    parsed = []
    for token in gpu_ids.split(','):
        token = token.strip()
        if not token:
            continue
        gpu_id = int(token)
        if gpu_id >= 0:
            parsed.append(gpu_id)
    return parsed


def _build_model_and_datasets(opt):
    """Create the SR model named by ``opt.SR_name`` and its train/test datasets.

    Raises:
        NotImplementedError: if ``opt.SR_name`` is not IDN, RCAN, EDSR or RDN.
    """
    if opt.SR_name == 'IDN':
        model = IDNModel()
        # IDN uses the non-RGB dataset variant with input_up=True.
        dataset_kwargs = {'rgb': False, 'input_up': True}
    elif opt.SR_name == 'RCAN':
        model = RCANModel()
        dataset_kwargs = {'rgb': True}
    elif opt.SR_name == 'EDSR':
        model = EDSRModel()
        dataset_kwargs = {'rgb': True}
    elif opt.SR_name == 'RDN':
        model = RDNModel()
        dataset_kwargs = {'rgb': True}
    else:
        raise NotImplementedError('%s is not supported!' % opt.SR_name)

    train_dataset = DatasetFromNpyTrain(opt, **dataset_kwargs)
    test_dataset = DatasetFromNpyTest(opt, **dataset_kwargs)
    return model, train_dataset, test_dataset


def _current_lr(scheduler):
    """Return the learning rate of the first param group after the last step."""
    # get_last_lr() (PyTorch >= 1.4) reports the LR actually in use. Calling
    # get_lr() outside of step() is deprecated and can return a wrong value
    # for some schedulers, so it is only a fallback for old or custom ones.
    if hasattr(scheduler, 'get_last_lr'):
        return scheduler.get_last_lr()[0]
    return scheduler.get_lr()[0]


def _evaluate(model, test_loader, opt):
    """Evaluate on the test set, save images and return mean (PSNR, SSIM)."""
    model.set_mode(train=False)
    psnr_values = []
    ssim_values = []
    for i, data in enumerate(test_loader):
        model.set_eval_input(data)
        outputs = model.eval()
        psnr_, ssim_ = model.comput_PSNR_SSIM(
            outputs['output'], outputs['target'], shave_border=opt.sr_factor)
        psnr_values.append(psnr_)
        ssim_values.append(ssim_)

        # Look images up by name. The former zip(img_type, outputs) paired
        # these names with the dict's own key order, which mislabels files
        # whenever model.eval() returns its keys in a different order.
        for name in ('input', 'output', 'target'):
            if name in outputs:
                save_name = os.path.join(model.result_dir,
                                         '%d_%s.png' % (i, name))
                model.save_image(outputs[name], save_name)
    return np.average(psnr_values), np.average(ssim_values)


def main():
    """Train the SR network selected on the command line and evaluate it.

    Every ``opt.save_epoch`` epochs the test set is evaluated (PSNR/SSIM and
    saved images), the result is appended to ``model.log_file``, and a
    checkpoint is written. The log file is closed even if training fails.
    """
    opt = parser.parse_args()
    torch.manual_seed(opt.seed)

    opt.gpu_ids = _parse_gpu_ids(opt.gpu_ids)
    if opt.gpu_ids:
        torch.cuda.set_device(opt.gpu_ids[0])

    cudnn.benchmark = True
    print(opt)

    model, train_dataset, test_dataset = _build_model_and_datasets(opt)

    visualizer = Visualizer()

    print(len(train_dataset))
    # NOTE(review): the training loader does not shuffle. Confirm that
    # DatasetFromNpyTrain randomises its samples itself.
    train_loader = DataLoader(train_dataset, batch_size=opt.batch_size)
    test_loader = DataLoader(test_dataset, batch_size=1)

    model.initialize(opt)
    try:
        if opt.epoch > 0:
            # Resume from the checkpoint saved at epoch opt.epoch.
            model.load_model(opt.epoch, opt.SR_name)

        for epoch in range(opt.epoch + 1, opt.nEpochs + 1):
            model.set_mode(train=True)

            for i, data in enumerate(train_loader, 1):
                model.set_input(data)
                loss = model.train()
                if i % 50 == 0:
                    images = {'input': model.input, 'output': model.output,
                              'target': model.target}
                    visualizer.display_current_results(images, k=0)
                if i % 10 == 0:
                    loss_value = loss.item()
                    print('epoch: {}, iteration: {}/{}, loss: {}'.format(
                        epoch, i, len(train_loader), loss_value))
                    visualizer.plot_current_loss(loss_value)

            # The explicit epoch argument (deprecated in recent PyTorch) is
            # kept, presumably so the schedule stays aligned after resuming.
            model.scheduler.step(epoch)  # update learning rate
            print('Learning rate: %f' % _current_lr(model.scheduler))

            if epoch % opt.save_epoch == 0:
                average_psnr, average_ssim = _evaluate(model, test_loader,
                                                       opt)
                log = 'Epoch %d: Average psnr: %f , ssim: %f \n' % (
                    epoch, average_psnr, average_ssim)
                print(log)
                model.log_file.write(log)
                model.save_model(epoch, opt.SR_name)
    finally:
        model.log_file.close()