Example #1
File: start.py Project: jotoy/sr
def train():
    if FLAGS.load_model is not None:
        # strip an optional "checkpoints/" prefix from the flag value
        checkpoints_dir = "checkpoints/" + \
            FLAGS.load_model.split("checkpoints/", 1)[-1]
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        try:
            os.makedirs(checkpoints_dir)
        except os.error:
            pass

    graph = tf.Graph()
    with graph.as_default():  # make this graph the default graph
        cycle_sr = CycleSR(input_ture_LR_x=FLAGS.X_LR_INPUT,
                           input_ture_HR_y=FLAGS.Y_HR_INPUT,
                           dis_true_HR_x=FLAGS.X_HR_DIS,
                           dis_true_LR_y=FLAGS.Y_LR_DIS,
                           batch_size=FLAGS.batch_size,
                           image_size=FLAGS.image_size,
                           use_lsgan=FLAGS.use_lsgan,
                           norm=FLAGS.norm,
                           lambda1=FLAGS.lambda1,
                           lambda2=FLAGS.lambda2,
                           learning_rate=FLAGS.learning_rate,
                           beta1=FLAGS.beta1,
                           ngf=FLAGS.ngf)
        (G_loss, D_Y_loss, F_loss, D_X_loss,
         dis_fake_HR_x, dis_fake_LR_y) = cycle_sr.model()

        optimizers = cycle_sr.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:  # create the session
        if FLAGS.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            step = int(meta_graph_path.split("-")[2].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            dis_fake_HR_x_pool = ImagePool(FLAGS.pool_size)
            dis_fake_LR_y_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # get previously generated images (fetch into *_val names so
                # the numpy results do not shadow the graph tensors)
                dis_fake_HR_x_val, dis_fake_LR_y_val = sess.run(
                    [dis_fake_HR_x, dis_fake_LR_y])

                # train
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [
                            optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                            summary_op
                        ],
                        feed_dict={
                            cycle_sr.dis_fake_HR_x:
                            dis_fake_HR_x_pool.query(dis_fake_HR_x_val),
                            cycle_sr.dis_fake_LR_y:
                            dis_fake_LR_y_pool.query(dis_fake_LR_y_val)
                        }))

                train_writer.add_summary(summary, step)
                train_writer.flush()

                if step % 100 == 0:
                    logging.info('-----------Step %d:-------------' % step)
                    logging.info('  G_loss   : {}'.format(G_loss_val))
                    logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info('  F_loss   : {}'.format(F_loss_val))
                    logging.info('  D_X_loss : {}'.format(D_X_loss_val))

                if step % 10000 == 0:
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess,
                                   checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            logging.info("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
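All of the examples on this page recycle previously generated images through an ImagePool before handing them to the discriminator. For reference, here is a minimal sketch of the CycleGAN-style buffer these projects appear to rely on, assuming the common behaviour of replaying a stored image with 50% probability once the pool is full; the actual ImagePool in each project may differ (for instance, the constructor in Example #3 also takes a replay probability).

import random


class ImagePool:
    """Minimal sketch of a CycleGAN-style image buffer (assumed behaviour)."""

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, image):
        if self.pool_size == 0:        # pool disabled: pass the image straight through
            return image
        if len(self.images) < self.pool_size:
            self.images.append(image)  # fill the buffer first
            return image
        if random.random() > 0.5:      # half of the time, swap in a stored image
            idx = random.randrange(self.pool_size)
            old, self.images[idx] = self.images[idx], image
            return old
        return image                   # otherwise return the freshly generated image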
Example #2
def train():
  if FLAGS.load_model is not None:
    checkpoints_dir = "checkpoints/" + FLAGS.load_model
  else:
    current_time = datetime.now().strftime("%Y%m%d-%H%M")
    checkpoints_dir = "checkpoints/{}".format(current_time)
    try:
      os.makedirs(checkpoints_dir)
    except os.error:
      pass

  graph = tf.Graph()
  with graph.as_default():
    cycle_gan = CycleGAN(
        X_train_file=FLAGS.X,
        Y_train_file=FLAGS.Y,
        batch_size=FLAGS.batch_size,
        image_size_w=FLAGS.image_size_w,
        image_size_h=FLAGS.image_size_h,
        use_lsgan=FLAGS.use_lsgan,
        norm=FLAGS.norm,
        lambda1=FLAGS.lambda1,
        lambda2=FLAGS.lambda2,
        beta1=FLAGS.beta1,
        ngf=FLAGS.ngf,
    )
    G_loss, C_loss, fake_y, fake_x = cycle_gan.model()
    G_optimizer, C_optimizer = cycle_gan.optimize(G_loss, C_loss)

    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
    saver = tf.train.Saver()

  with tf.Session(graph=graph) as sess:
    if FLAGS.load_model is not None:
      checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
      meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
      restore = tf.train.import_meta_graph(meta_graph_path)
      restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
      step = int(meta_graph_path.split("-")[2].split(".")[0])
    else:
      sess.run(tf.global_variables_initializer())
      step = 0

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
      fake_Y_pool = ImagePool(FLAGS.pool_size)
      fake_X_pool = ImagePool(FLAGS.pool_size)

      while not coord.should_stop():
        # get previously generated images
        fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

        # train
        adjusted_lr = (FLAGS.learning_rate *
                       0.5 ** max(0, (step / FLAGS.decay_step) - 2))
        feed_ = {cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                 cycle_gan.fake_x: fake_X_pool.query(fake_x_val),
                 cycle_gan.learning_rate: adjusted_lr}
        # update D 5 times before updating G
        for i in range(5):
          _ = sess.run(C_optimizer, feed_dict=feed_)
        _ = sess.run(G_optimizer, feed_dict=feed_)

        G_loss_val, C_loss_val, summary = sess.run(
            [G_loss, C_loss, summary_op], feed_dict=feed_)

        train_writer.add_summary(summary, step)
        train_writer.flush()

        if step % 100 == 0:
          logging.info('-----------Step %d:-------------' % step)
          logging.info('  G_loss   : {}'.format(G_loss_val))
          logging.info('  C_loss   : {}'.format(C_loss_val))

        if step % 10000 == 0:
          save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
          logging.info("Model saved in file: %s" % save_path)

        step += 1

    except KeyboardInterrupt:
      logging.info('Interrupted')
      coord.request_stop()
    except Exception as e:
      coord.request_stop(e)
    finally:
      save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
      logging.info("Model saved in file: %s" % save_path)
      # When done, ask the threads to stop.
      coord.request_stop()
      coord.join(threads)
Example #3
class SimpleGAN(Trainer):
    def __init__(self, parsed_args, parsed_groups):
        super().__init__(**parsed_groups['trainer arguments'])

        self.gen = UNetGenerator(**parsed_groups['generator arguments'])
        self.disc = NLayerDiscriminator(
            **parsed_groups['discriminator arguments'])

        init_weights(self.gen)
        init_weights(self.disc)

        self.real_label = torch.tensor(1.0)
        self.fake_label = torch.tensor(0.0)

        self.crit = torch.nn.MSELoss()  # LSGAN
        self.image_pool = ImagePool(parsed_args.pool_size,
                                    parsed_args.replay_prob)

        self.sel_ind = 0
        self.un_normalize = lambda x: 255. * (1 + x.clamp(min=-1, max=1)) / 2.

        self.parsed_args = parsed_args
        self.n_vis = parsed_args.n_vis
        self.vis = Visualizer()

    def configure_optimizers(self):
        opt1 = torch.optim.Adam(self.disc.parameters(), lr=self.parsed_args.lr)
        opt2 = torch.optim.Adam(self.gen.parameters(), lr=self.parsed_args.lr)

        N = self.parsed_args.n_epochs
        N_start = int(N * self.parsed_args.frac_decay_start)
        sched_lamb = lambda x: 1.0 - max(0, x - N_start) / (N - N_start)

        sched1 = torch.optim.lr_scheduler.LambdaLR(opt1, lr_lambda=sched_lamb)
        sched2 = torch.optim.lr_scheduler.LambdaLR(opt2, lr_lambda=sched_lamb)

        return opt1, opt2, sched1, sched2

    def on_fit_start(self):
        self.vis.start()

    def on_fit_end(self):
        self.vis.stop()

    def _shared_step(self, batch, save_img, is_train):
        res = Result()

        X, Y_real = batch
        Y_fake = self.gen(X)

        if is_train:
            Y_pool = self.image_pool.query(Y_fake.detach())
        else:
            Y_pool = Y_fake
            res.recon_error = self.crit(Y_real, Y_fake)

        real_predict = self.disc(Y_real)
        fake_predict = self.disc(Y_pool)

        real_label = self.real_label.expand_as(real_predict)
        fake_label = self.fake_label.expand_as(fake_predict)

        disc_loss = 0.5 * (self.crit(real_predict, real_label) + \
                           self.crit(fake_predict, fake_label))

        if is_train:
            res.step(disc_loss)
        res.disc_loss = disc_loss

        gen_predict = self.disc(Y_fake)
        gen_loss = self.crit(gen_predict, real_label)

        if is_train:
            res.step(gen_loss)
        res.gen_loss = gen_loss

        if save_img:
            res.img = [
                self.un_normalize(X[:self.n_vis]),
                self.un_normalize(Y_fake[:self.n_vis]),
                self.un_normalize(Y_real[:self.n_vis])
            ]
        return res

    def training_step(self, batch, batch_idx):
        res = self._shared_step(batch,
                                save_img=(batch_idx == 0),
                                is_train=True)
        return res

    def validation_step(self, batch, batch_idx):
        res = self._shared_step(batch,
                                save_img=(batch_idx == self.sel_ind),
                                is_train=False)
        return res

    def _shared_end(self, result_outputs, is_train):
        phase = 'Train' if is_train else 'Valid'

        self.vis.plot('Gen. Loss', phase + ' Loss', self.current_epoch,
                      torch.mean(torch.stack(result_outputs.gen_loss)))
        self.vis.plot('Disc. Loss', phase + ' Loss', self.current_epoch,
                      torch.mean(torch.stack(result_outputs.disc_loss)))

        collated_imgs = torch.cat([*torch.cat(result_outputs.img[0], dim=3)],
                                  dim=1)
        self.vis.show_image(phase + ' Images', collated_imgs)

    def training_epoch_end(self, training_outputs):
        self._shared_end(training_outputs, is_train=True)

    def validation_epoch_end(self, validation_outputs):
        self._shared_end(validation_outputs, is_train=False)
        self.sel_ind = random.randint(0, len(self.validation_loader) - 1)

        return torch.mean(torch.stack(validation_outputs.recon_error))
Example #4
def main():

    args = args_initialize()

    save_freq = args.save_freq
    epochs = args.num_epoch
    cuda = args.cuda

    train_dataset = UnalignedDataset(is_train=True)
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0
    )

    net_G_A = ResNetGenerator(input_nc=3, output_nc=3)
    net_G_B = ResNetGenerator(input_nc=3, output_nc=3)
    net_D_A = Discriminator()
    net_D_B = Discriminator()

    if args.cuda:
        net_G_A = net_G_A.cuda()
        net_G_B = net_G_B.cuda()
        net_D_A = net_D_A.cuda()
        net_D_B = net_D_B.cuda()

    fake_A_pool = ImagePool(50)
    fake_B_pool = ImagePool(50)

    criterionGAN = GANLoss(cuda=cuda)
    criterionCycle = torch.nn.L1Loss()
    criterionIdt = torch.nn.L1Loss()

    optimizer_G = torch.optim.Adam(
        itertools.chain(net_G_A.parameters(), net_G_B.parameters()),
        lr=args.lr,
        betas=(args.beta1, 0.999)
    )
    optimizer_D_A = torch.optim.Adam(net_D_A.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
    optimizer_D_B = torch.optim.Adam(net_D_B.parameters(), lr=args.lr, betas=(args.beta1, 0.999))

    log_dir = './logs'
    checkpoints_dir = './checkpoints'
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoints_dir, exist_ok=True)

    writer = SummaryWriter(log_dir)

    for epoch in range(epochs):

        running_loss = np.zeros((8))
        for batch_idx, data in enumerate(train_loader):

            input_A = data['A']
            input_B = data['B']

            if cuda:
                input_A = input_A.cuda()
                input_B = input_B.cuda()

            real_A = Variable(input_A)
            real_B = Variable(input_B)


            """
            Backward net_G
            """
            optimizer_G.zero_grad()
            lambda_idt = 0.5
            lambda_A = 10.0
            lambda_B = 10.0

            # Feed each generator an image that is already in its output domain;
            # ideally it should return the image unchanged (identity mapping)
            idt_B = net_G_A(real_B)
            loss_idt_A = criterionIdt(idt_B, real_B) * lambda_B * lambda_idt

            idt_A = net_G_B(real_A)
            loss_idt_B = criterionIdt(idt_A, real_A) * lambda_A * lambda_idt

            # GAN loss = D_A(G_A(A))
            # G_A wants the discriminator to judge its generated fake images as real (True)
            fake_B = net_G_A(real_A)
            pred_fake = net_D_A(fake_B)
            loss_G_A = criterionGAN(pred_fake, True)

            fake_A = net_G_B(real_B)
            pred_fake = net_D_B(fake_A)
            loss_G_B = criterionGAN(pred_fake, True)

            rec_A = net_G_B(fake_B)
            loss_cycle_A = criterionCycle(rec_A, real_A) * lambda_A

            rec_B = net_G_A(fake_A)
            loss_cycle_B = criterionCycle(rec_B, real_B) * lambda_B

            loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
            loss_G.backward()

            optimizer_G.step()

            """
            update D_A
            """
            optimizer_D_A.zero_grad()
            fake_B = fake_B_pool.query(fake_B.data)

            pred_real = net_D_A(real_B)
            loss_D_real = criterionGAN(pred_real, True)

            pred_fake = net_D_A(fake_B.detach())
            loss_D_fake = criterionGAN(pred_fake, False)

            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()

            optimizer_D_A.step()

            """
            update D_B
            """
            optimizer_D_B.zero_grad()
            fake_A = fake_A_pool.query(fake_A.data)

            pred_real = net_D_B(real_A)
            loss_D_real = criterionGAN(pred_real, True)

            pred_fake = net_D_B(fake_A.detach())
            loss_D_fake = criterionGAN(pred_fake, False)

            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()


            optimizer_D_B.step()

            ret_loss = np.array([
                loss_G_A.data.detach().cpu().numpy(), loss_D_A.data.detach().cpu().numpy(),
                loss_G_B.data.detach().cpu().numpy(), loss_D_B.data.detach().cpu().numpy(),
                loss_cycle_A.data.detach().cpu().numpy(), loss_cycle_B.data.detach().cpu().numpy(),
                loss_idt_A.data.detach().cpu().numpy(), loss_idt_B.data.detach().cpu().numpy()
            ])
            running_loss += ret_loss

            """
            Save checkpoints
            """
            if (epoch + 1) % save_freq == 0:
                save_network(net_G_A, 'G_A', str(epoch + 1))
                save_network(net_D_A, 'D_A', str(epoch + 1))
                save_network(net_G_B, 'G_B', str(epoch + 1))
                save_network(net_D_B, 'D_B', str(epoch + 1))

        running_loss /= len(train_loader)
        losses = running_loss
        print('epoch %d, losses: %s' % (epoch + 1, running_loss))

        writer.add_scalar('loss_G_A', losses[0], epoch)
        writer.add_scalar('loss_D_A', losses[1], epoch)
        writer.add_scalar('loss_G_B', losses[2], epoch)
        writer.add_scalar('loss_D_B', losses[3], epoch)
        writer.add_scalar('loss_cycle_A', losses[4], epoch)
        writer.add_scalar('loss_cycle_B', losses[5], epoch)
        writer.add_scalar('loss_idt_A', losses[6], epoch)
        writer.add_scalar('loss_idt_B', losses[7], epoch)
Example #5
def train(args):
    print(args)

    # net
    netG = Generator()
    netG = netG.cuda()
    netD = Discriminator()
    netD = netD.cuda()

    # loss
    l1_loss = nn.L1Loss().cuda()
    l2_loss = nn.MSELoss().cuda()
    bce_loss = nn.BCELoss().cuda()

    # opt
    optimizerG = optim.Adam(netG.parameters(), lr=args.glr)
    optimizerD = optim.Adam(netD.parameters(), lr=args.dlr)

    # lr
    schedulerG = lr_scheduler.StepLR(optimizerG, args.lr_step_size,
                                     args.lr_gamma)
    schedulerD = lr_scheduler.StepLR(optimizerD, args.lr_step_size,
                                     args.lr_gamma)

    # utility for saving models, parameters and logs
    save = SaveData(args.save_dir, args.exp, True)
    save.save_params(args)

    # netG, _ = save.load_model(netG)

    dataset = MyDataset(args.data_dir, is_train=True)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=int(args.n_threads))

    real_label = Variable(
        torch.ones([1, 1, args.patch_gan, args.patch_gan],
                   dtype=torch.float)).cuda()
    fake_label = Variable(
        torch.zeros([1, 1, args.patch_gan, args.patch_gan],
                    dtype=torch.float)).cuda()

    image_pool = ImagePool(args.pool_size)

    vgg = Vgg16(requires_grad=False)
    vgg.cuda()

    for epoch in range(args.epochs):
        print("* Epoch {}/{}".format(epoch + 1, args.epochs))

        schedulerG.step()
        schedulerD.step()

        d_total_real_loss = 0
        d_total_fake_loss = 0
        d_total_loss = 0

        g_total_res_loss = 0
        g_total_per_loss = 0
        g_total_gan_loss = 0
        g_total_loss = 0

        netG.train()
        netD.train()

        for batch, images in tqdm(enumerate(dataloader)):
            input_image, target_image = images
            input_image = Variable(input_image.cuda())
            target_image = Variable(target_image.cuda())
            output_image = netG(input_image)

            # Update D
            netD.requires_grad(True)
            netD.zero_grad()

            ## real image
            real_output = netD(target_image)
            d_real_loss = bce_loss(real_output, real_label)
            d_real_loss.backward()
            d_real_loss = d_real_loss.data.cpu().numpy()
            d_total_real_loss += d_real_loss

            ## fake image
            fake_image = output_image.detach()
            fake_image = Variable(image_pool.query(fake_image.data))
            fake_output = netD(fake_image)
            d_fake_loss = bce_loss(fake_output, fake_label)
            d_fake_loss.backward()
            d_fake_loss = d_fake_loss.data.cpu().numpy()
            d_total_fake_loss += d_fake_loss

            ## loss
            d_total_loss += d_real_loss + d_fake_loss

            optimizerD.step()

            # Update G
            netD.requires_grad(False)
            netG.zero_grad()

            ## reconstruction loss
            g_res_loss = l1_loss(output_image, target_image)
            g_res_loss.backward(retain_graph=True)
            g_res_loss = g_res_loss.data.cpu().numpy()
            g_total_res_loss += g_res_loss

            ## perceptual loss
            g_per_loss = args.p_factor * l2_loss(vgg(output_image),
                                                 vgg(target_image))
            g_per_loss.backward(retain_graph=True)
            g_per_loss = g_per_loss.data.cpu().numpy()
            g_total_per_loss += g_per_loss

            ## gan loss
            output = netD(output_image)
            g_gan_loss = args.g_factor * bce_loss(output, real_label)
            g_gan_loss.backward()
            g_gan_loss = g_gan_loss.data.cpu().numpy()
            g_total_gan_loss += g_gan_loss

            ## loss
            g_total_loss += g_res_loss + g_per_loss + g_gan_loss

            optimizerG.step()

        d_total_real_loss = d_total_real_loss / (batch + 1)
        d_total_fake_loss = d_total_fake_loss / (batch + 1)
        d_total_loss = d_total_loss / (batch + 1)
        save.add_scalar('D/real', d_total_real_loss, epoch)
        save.add_scalar('D/fake', d_total_fake_loss, epoch)
        save.add_scalar('D/total', d_total_loss, epoch)

        g_total_res_loss = g_total_res_loss / (batch + 1)
        g_total_per_loss = g_total_per_loss / (batch + 1)
        g_total_gan_loss = g_total_gan_loss / (batch + 1)
        g_total_loss = g_total_loss / (batch + 1)
        save.add_scalar('G/res', g_total_res_loss, epoch)
        save.add_scalar('G/per', g_total_per_loss, epoch)
        save.add_scalar('G/gan', g_total_gan_loss, epoch)
        save.add_scalar('G/total', g_total_loss, epoch)

        if epoch % args.period == 0:
            log = "Train d_loss: {:.5f} \t g_loss: {:.5f}".format(
                d_total_loss, g_total_loss)
            print(log)
            save.save_log(log)
            save.save_model(netG, epoch)
Example #6
class CycleGANModel():
    def __init__(self, opt):
        self.opt = opt
        self.dynamic = opt.dynamic
        self.isTrain = opt.istrain
        self.Tensor = torch.cuda.FloatTensor

        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = GModel(opt).cuda()
        self.netG_B = GModel(opt).cuda()
        self.netF_A = Fmodel().cuda()
        self.dataF = Fdata.get_loader()

        if self.isTrain:
            self.netD_A = DModel(opt).cuda()
            self.netD_B = DModel(opt).cuda()

        if self.isTrain:
            self.fake_A_pool = ImagePool(pool_size=128)
            self.fake_B_pool = ImagePool(pool_size=128)
            # define loss functions
            self.criterionGAN = GANLoss(tensor=self.Tensor).cuda()
            if opt.loss == 'l1':
                self.criterionCycle = torch.nn.L1Loss()
                self.criterionIdt = torch.nn.L1Loss()
            elif opt.loss == 'l2':
                self.criterionCycle = torch.nn.MSELoss()
                self.criterionIdt = torch.nn.MSELoss()
            # initialize optimizers
            # self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()))
            self.optimizer_G = torch.optim.Adam([
                {'params': self.netG_A.parameters(), 'lr': 1e-3},
                {'params': self.netF_A.parameters(), 'lr': 0.0},
                {'params': self.netG_B.parameters(), 'lr': 1e-3},
            ])
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters())
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters())
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                # self.schedulers.append(networks.get_scheduler(optimizer, opt))
                self.schedulers.append(optimizer)

        print('---------- Networks initialized -------------')
        # networks.print_network(self.netG_A)
        # networks.print_network(self.netG_B)
        # if self.isTrain:
        #     networks.print_network(self.netD_A)
        #     networks.print_network(self.netD_B)
        print('-----------------------------------------------')

    def train_forward(self):
        optimizer = torch.optim.Adam(self.netF_A.parameters(), lr=1e-3)
        loss_fn = torch.nn.MSELoss()
        loss_fn = torch.nn.L1Loss()
        for epoch in range(100):
            epoch_loss = 0
            for i, item in enumerate(self.dataF):
                state0 = item[0].float().cuda()
                action = item[1].float().cuda()
                state1 = item[2].float().cuda()
                pred = self.netF_A(state0, action)
                loss = loss_fn(pred, state1)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
            print('epoch:{} loss:{:.7f}'.format(epoch,
                                                epoch_loss / len(self.dataF)))
        print('forward model has been trained!')

    def set_input(self, input):
        # AtoB = self.opt.which_direction == 'AtoB'
        # input_A = input['A' if AtoB else 'B']
        # input_B = input['B' if AtoB else 'A']
        self.input_A = input[0]
        self.input_Bt0 = input[1][0]
        self.input_Bt1 = input[1][2]
        self.action = input[1][1]

    def forward(self):
        self.real_A = Variable(self.input_A).float().cuda()
        self.real_Bt0 = Variable(self.input_Bt0).float().cuda()
        self.real_Bt1 = Variable(self.input_Bt1).float().cuda()
        self.action = Variable(self.action).float().cuda()

    def test(self):
        self.forward()
        real_A = Variable(self.input_A, volatile=True).float().cuda()
        fake_B = self.netG_A(real_A)
        self.rec_A = self.netG_B(fake_B).data
        self.fake_B = fake_B.data

        real_B = Variable(self.input_Bt0, volatile=True).float().cuda()
        fake_A = self.netG_B(real_B)
        self.rec_B = self.netG_A(fake_A).data
        self.fake_A = fake_A.data

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_Bt0, fake_B)
        self.loss_D_A = loss_D_A.item()

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_At0)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.item()

    def backward_G(self):
        lambda_idt = 0.5
        lambda_A = 100.0
        lambda_B = 100.0
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_Bt0)
            loss_idt_A = self.criterionIdt(
                idt_A, self.real_Bt0) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(idt_B,
                                           self.real_A) * lambda_A * lambda_idt

            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.item()
            self.loss_idt_B = loss_idt_B.item()
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        lambda_G = 1.0

        # --------first cycle-----------#
        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True) * lambda_G
        # Forward cycle loss
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A

        # ---------second cycle---------#
        # GAN loss D_B(G_B(B))
        fake_At0 = self.netG_B(self.real_Bt0)
        pred_fake = self.netD_B(fake_At0)
        loss_G_Bt0 = self.criterionGAN(pred_fake, True) * lambda_G
        # Backward cycle loss
        rec_Bt0 = self.netG_A(fake_At0)
        loss_cycle_Bt0 = self.criterionCycle(rec_Bt0, self.real_Bt0) * lambda_B

        # ---------third cycle---------#
        # GAN loss D_B(G_B(B))
        fake_At1 = self.netF_A(fake_At0, self.action)
        pred_fake = self.netD_B(fake_At1)
        loss_G_Bt1 = self.criterionGAN(pred_fake, True) * lambda_G
        # Backward cycle loss
        rec_Bt1 = self.netG_A(fake_At1)
        loss_cycle_Bt1 = self.criterionCycle(rec_Bt1, self.real_Bt1) * lambda_B

        # combined loss
        loss_G = loss_idt_A + loss_idt_B
        loss_G = loss_G + loss_G_A + loss_cycle_A
        loss_G = loss_G + loss_G_Bt0 + loss_cycle_Bt0
        if self.dynamic:
            loss_G = loss_G + loss_G_Bt1 + loss_cycle_Bt1
        loss_G.backward()

        self.fake_B = fake_B.data
        self.fake_At0 = fake_At0.data
        self.fake_At1 = fake_At1.data
        self.rec_A = rec_A.data
        self.rec_Bt0 = rec_Bt0.data
        self.rec_Bt1 = rec_Bt1.data

        self.loss_G_A = loss_G_A.item()
        self.loss_G_Bt0 = loss_G_Bt0.item()
        self.loss_G_Bt1 = loss_G_Bt1.item()
        self.loss_cycle_A = loss_cycle_A.item()
        self.loss_cycle_Bt0 = loss_cycle_Bt0.item()
        self.loss_cycle_Bt1 = loss_cycle_Bt1.item()

    def optimize_parameters(self):
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        ret_errors = OrderedDict([('D_A', self.loss_D_A),
                                  ('G_A', self.loss_G_A),
                                  ('Cyc_A', self.loss_cycle_A),
                                  ('D_B', self.loss_D_B),
                                  ('G_B', self.loss_G_Bt0),
                                  ('Cyc_B', self.loss_cycle_Bt0)])
        # if self.opt.identity > 0.0:
        ret_errors['idt_A'] = self.loss_idt_A
        ret_errors['idt_B'] = self.loss_idt_B
        return ret_errors

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, path):
        save_filename = 'model_{}.pth'.format(network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    def save(self, path):
        self.save_network(self.netG_A, 'G_A', path)
        self.save_network(self.netD_A, 'D_A', path)
        self.save_network(self.netG_B, 'G_B', path)
        self.save_network(self.netD_B, 'D_B', path)

    def load_network(self, network, network_label, path):
        weight_filename = 'model_{}.pth'.format(network_label)
        weight_path = os.path.join(path, weight_filename)
        network.load_state_dict(torch.load(weight_path))

    def load(self, path):
        self.load_network(self.netG_A, 'G_A', path)
        self.load_network(self.netG_B, 'G_B', path)

    def plot_points(self, item, label):
        item = item.cpu().data.numpy()
        plt.scatter(item[:, 0], item[:, 1], label=label)

    def visual(self, path):
        plt.rcParams['figure.figsize'] = (8.0, 3.0)
        plt.xlim(-4, 4)
        plt.ylim(-1.5, 1.5)
        self.plot_points(self.real_A, 'realA')
        self.plot_points(self.fake_B, 'fake_B')
        self.plot_points(self.rec_A, 'rec_A')
        # self.plot_points(self.real_B,'real_B')
        for p1, p2 in zip(self.real_A, self.fake_B):
            p1, p2 = p1.cpu().data.numpy(), p2.cpu().data.numpy()
            plt.plot([p1[0], p2[0]], [p1[1], p2[1]])
        plt.legend()
        plt.savefig(path)
        plt.cla()
        plt.clf()
Example #7
class cycleGan():

    # def __init__(self,  g_conv_dim=64, d_conv_dim=64,res_blocks=4,lr=0.001, beta1=0.5, beta2=0.999):
    def __init__(self, opt):
        super(cycleGan, self).__init__()
        self.opt = opt
        self.G_XtoY = RensetGenerator(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.norm, not opt.no_dropout,
                                      opt.n_blocks,
                                      opt.padding_type).to(device)
        # self.G_YtoX = CycleGenerator(d_conv_dim, res_blocks)
        self.G_YtoX = RensetGenerator(opt.output_nc, opt.input_nc, opt.ngf,
                                      opt.norm, not opt.no_dropout,
                                      opt.n_blocks,
                                      opt.padding_type).to(device)
        # self.D_X = CycleDiscriminator(d_conv_dim)
        self.D_X = PatchDiscriminator(opt.output_nc, opt.ndf, opt.n_layers_D,
                                      opt.norm).to(device)
        # self.D_Y = CycleDiscriminator(d_conv_dim)
        self.D_Y = PatchDiscriminator(opt.input_nc, opt.ndf, opt.n_layers_D,
                                      opt.norm).to(device)
        # self.G_XtoY.apply(init_weights)
        # self.G_YtoX.apply(init_weights)
        # self.D_X.apply(init_weights)
        # self.D_Y.apply(init_weights)
        print(self.G_XtoY)
        print("Parameters: ", len(list(self.G_XtoY.parameters())))
        print(self.G_YtoX)
        print("Parameters: ", len(list(self.G_YtoX.parameters())))
        print(self.D_X)
        print("Parameters: ", len(list(self.D_X.parameters())))
        print(self.D_Y)
        print("Parameters: ", len(list(self.D_Y.parameters())))
        self.fake_X_pool = ImagePool(opt.pool_size)
        self.fake_Y_pool = ImagePool(opt.pool_size)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()  #For implementing Identity Loss
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.G_XtoY.parameters(), self.G_YtoX.parameters()),
                                            lr=opt.lr,
                                            betas=[opt.beta1, 0.999])
        self.optimizer_DX = torch.optim.Adam(self.D_X.parameters(),
                                             lr=opt.lr,
                                             betas=[opt.beta1, 0.999])
        self.optimizer_DY = torch.optim.Adam(self.D_Y.parameters(),
                                             lr=opt.lr,
                                             betas=[opt.beta1, 0.999])

    def get_input(self, inputX, inputY):
        self.inputX = inputX
        self.inputY = inputY

    def forward(self):
        self.fake_X = self.G_YtoX(self.inputY).to(device)
        self.rec_Y = self.G_XtoY(self.fake_X).to(device)
        self.fake_Y = self.G_XtoY(self.inputX).to(device)
        self.rec_X = self.G_YtoX(self.fake_Y).to(device)

    def backward_D_basic(self, netD, real, fake):
        pred_real = netD(real)
        loss_D_real = real_mse_loss(pred_real)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = fake_mse_loss(pred_fake)
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_X(self):
        """Calculate GAN loss for discriminator D_A"""
        if self.opt.pool == True:
            fake_X = self.fake_X_pool.query(self.fake_X)
            self.loss_D_X = self.backward_D_basic(self.D_X, self.inputX,
                                                  fake_X)
        else:
            self.loss_D_X = self.backward_D_basic(self.D_X, self.inputX,
                                                  self.fake_X)

    def backward_D_Y(self):
        """Calculate GAN loss for discriminator D_A"""
        if self.opt.pool == True:
            fake_Y = self.fake_Y_pool.query(self.fake_Y)
            self.loss_D_Y = self.backward_D_basic(self.D_Y, self.inputY,
                                                  fake_Y)
        else:
            self.loss_D_Y = self.backward_D_basic(self.D_Y, self.inputY,
                                                  self.fake_Y)

    def backward_G(self):
        # Not implemented identity loss

        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B

        self.loss_G_X = real_mse_loss(self.D_X(self.fake_X))
        # GAN loss D_B(G_B(B))
        self.loss_G_Y = real_mse_loss(self.D_Y(self.fake_Y))
        # Forward cycle loss || G_B(G_A(A)) - A||
        self.loss_cycle_X = self.criterionCycle(self.rec_X,
                                                self.inputX) * lambda_A
        # Backward cycle loss || G_A(G_B(B)) - B||
        self.loss_cycle_Y = self.criterionCycle(self.rec_Y,
                                                self.inputY) * lambda_B
        # combined loss and calculate gradients
        self.loss_G = self.loss_G_X + self.loss_G_Y + self.loss_cycle_X + self.loss_cycle_Y
        self.loss_G.backward()

    def change_lr(self, new_lr):
        self.opt.lr = new_lr
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.G_XtoY.parameters(), self.G_YtoX.parameters()),
                                            lr=self.opt.lr,
                                            betas=[self.opt.beta1, 0.999])
        self.optimizer_DX = torch.optim.Adam(self.D_X.parameters(),
                                             lr=self.opt.lr,
                                             betas=[self.opt.beta1, 0.999])
        self.optimizer_DY = torch.optim.Adam(self.D_Y.parameters(),
                                             lr=self.opt.lr,
                                             betas=[self.opt.beta1, 0.999])

    def set_requires_grad(self, nets, requires_grad=False):

        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def optimize(self):

        self.set_requires_grad([self.D_X, self.D_Y], False)
        self.forward()
        self.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero
        self.backward_G()  # calculate gradients for G_A and G_B
        self.optimizer_G.step()  # update G_A and G_B's weights

        self.set_requires_grad([self.D_X, self.D_Y], True)
        self.forward()
        self.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero
        self.backward_G()  # calculate gradients for G_A and G_B
        self.optimizer_G.step()

        # D_A and D_B
        self.optimizer_DX.zero_grad()
        self.optimizer_DY.zero_grad()  # set D_A and D_B's gradients to zero
        self.backward_D_X()  # calculate gradients for D_A
        self.backward_D_Y()  # calculate graidents for D_B
        self.optimizer_DX.step()  # update D_A and D_B's weights
        self.optimizer_DY.step()  # update D_A and D_B's weights
Example #8
def main():

    num_epoch = 100000
    pool_size = 20
    batch_size = 1
    oldpath = FLAGS.buckets
    RealPicPath = 'picF'
    AnimaPicPaht = 'picG'
    useCopyfile = True

    if useCopyfile:
        trainfiles = ['picf1.zip', 'picf2.zip', 'picg1.zip']
        # trainfiles.extend(['picf3.zip','picf4.zip','picg2.zip'])

        print(trainfiles)

        for f in trainfiles:
            fn = utils.pai_copy(f, oldpath)
            utils.Unzip(fn)

        RealPicPath = os.path.join('temp', RealPicPath)
        AnimaPicPaht = os.path.join('temp', AnimaPicPaht)

    print(RealPicPath)
    print(AnimaPicPaht)

    sess = tf.InteractiveSession(
        config=tf.ConfigProto(allow_soft_placement=True))

    cycle_gan = CycleGAN(
        X_train_file=AnimaPicPaht,
        Y_train_file=RealPicPath,
        batch_size=batch_size,
        image_size=(270, 480),
        lossfunc='wgan',
        norm='instance',
        learning_rate=2e-4,
        start_decay_step=10000,
        decay_steps=100000,
        optimizer='RMSProp'
    )

    Ga2b_loss, Da2b_loss, Gb2a_loss, Db2a_loss, fake_a, fake_b, real_a, real_b = cycle_gan.build()

    optimizers = cycle_gan.optimize(Ga2b_loss, Da2b_loss, Gb2a_loss, Db2a_loss)

    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.checkpointDir)
    saver = tf.train.Saver(max_to_keep=0)

    sess.run([tf.global_variables_initializer(),
              tf.local_variables_initializer()])

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # save_path = saver.save(sess,os.path.join(FLAGS.checkpointDir,"model_pre.ckpt"))
    # print("Model saved in file: %s" % save_path)

    fake_a_pool = ImagePool(pool_size)
    fake_b_pool = ImagePool(pool_size)
    print('start train')
    start_time = time.time()

    for step in range(1, num_epoch + 1):
        # get previously generated images
        fake_a_val, fake_b_val = sess.run([fake_a, fake_b])

        # train
        _, Ga2b_loss_val, Da2b_loss_val, Gb2a_loss_val, Db2a_loss_val, real_a_val, real_b_val, summary = (
            sess.run(
                [optimizers, Ga2b_loss, Da2b_loss, Gb2a_loss, Db2a_loss, real_a, real_b, summary_op],
                feed_dict={cycle_gan.fake_a: fake_a_pool.query(fake_a_val),
                           cycle_gan.fake_b: fake_b_pool.query(fake_b_val)}
            )
        )

        elapsed_time = time.time() - start_time
        start_time = time.time()

        if step % 25 == 0:
            print('Ga2b_loss_val : %s--Da2b_loss_val : %s--Gb2a_loss_val : %s--Db2a_loss_val : %s--' %
                  (Ga2b_loss_val, Da2b_loss_val, Gb2a_loss_val, Db2a_loss_val))

            print('step : %s --elapsed_time : %s' % (step, elapsed_time))
            print('adding summary...')
            train_writer.add_summary(summary, step)
            train_writer.flush()

        # if step % 100 == 0:
        #     print('-----------Step %d:-------------' % step)
        #     print('  G_loss   : {}'.format(G_loss_val))
        #     print('  D_Y_loss : {}'.format(D_Y_loss_val))
        #     print('  F_loss   : {}'.format(F_loss_val))
        #     print('  D_X_loss : {}'.format(D_X_loss_val))

        if step % 1000 == 0:
            save_path = saver.save(sess,
                                   os.path.join(FLAGS.checkpointDir, "model.ckpt"),
                                   global_step=step, write_meta_graph=False)
            print("Model saved in file: %s" % save_path)

    coord.request_stop()
    coord.join(threads)
Example #9
class Model():

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        parser.set_defaults(no_dropout=True)  # default CycleGAN did not use dropout
        if is_train:
            parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
            parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')

        return parser

    def __init__(self, opt):
        # BaseModel.__init__(self, opt)
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device(
            'cpu')  # get device name: CPU or GPU
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
        if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = None  # used for learning rate policy 'plateau'
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
        # + ['perception_G_A', 'perception_G_B', 'image_G_A', 'image_G_B',
        #    'tv_G_A', 'tv_G_B', 'rl_G_A', 'rl_G_B']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:  # if identity loss is used, we also visualize idt_A=G_A(B) and idt_B=G_B(A)
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')

        self.visual_names = visual_names_A + visual_names_B  # combine visualizations for A and B
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']

        self.netG_A = define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        self.netG_B = define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:  # define discriminators
            self.netD_A = define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netD_B = define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
                assert(opt.input_nc == opt.output_nc)
            self.fake_A_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            self.fake_B_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)  # define GAN loss.
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()

            vgg = vgg16(pretrained=True)
            loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()

            for param in loss_network.parameters():
                param.requires_grad = False
            loss_network.cuda()
            self.criterionLossnetwork = loss_network
            self.criterionMse = torch.nn.MSELoss()
            self.criterionTv = TVLoss()
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.A_paths = input['A_paths'][0]
        self.B_paths = input['B_paths'][0]
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG_A(self.real_A)  # G_A(A)
        self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
        self.fake_A = self.netG_B(self.real_B)  # G_B(B)
        self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def realistic_loss_grad(self, image, laplacian_m):
        img = image.squeeze(0)
        channel, height, width = img.size()
        loss = 0
        for i in range(channel):
            # print(laplacian_m.size())
            # print(img[i, :, :].size())
            # print(img[i, :, :].reshape(-1, 1).size())
            grad = torch.mm(laplacian_m, img[i, :, :].reshape(-1, 1))
            loss += torch.mm(img[i, :, :].reshape(1, -1), grad)
        return loss

    def backward_G(self):
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed: ||G_A(B) - B||
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed: ||G_B(A) - A||
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss || G_B(G_A(A)) - A||
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss || G_A(G_B(B)) - B||
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B

        # Perception Loss
        self.loss_perception_G_A = self.criterionMse(
            self.criterionLossnetwork(self.fake_A),
            self.criterionLossnetwork(self.real_A)) * 0.5
        self.loss_perception_G_B = self.criterionMse(
            self.criterionLossnetwork(self.fake_B),
            self.criterionLossnetwork(self.real_B)) * 0.5
        # Image Loss
        self.loss_image_G_A = self.criterionMse(self.fake_A, self.real_A) * 20.0
        self.loss_image_G_B = self.criterionMse(self.fake_B, self.real_B) * 20.0
        # TV Loss
        self.loss_tv_G_A = self.criterionTv(self.fake_A) * 2e-8
        self.loss_tv_G_B = self.criterionTv(self.fake_B) * 2e-8
        # realistic loss
        print('Computing Laplacian matrix of content image')
        # print(self.real_A.size())
        # image2 = cv2.imread(self.A_paths)
        # print(image2.shape)
        
        self.loss_rl_G_A = 0
        self.loss_rl_G_B = 0
        
        for i in range(self.real_A.size()[0]):
            L_A = compute_lap(self.real_A[i])
            L_B = compute_lap(self.real_B[i])
            self.loss_rl_G_A += self.realistic_loss_grad(self.fake_A[i], L_A) * 0.00001
            self.loss_rl_G_B += self.realistic_loss_grad(self.fake_B[i], L_B) * 0.00001
        
        self.loss_rl_G_A = torch.div(self.loss_rl_G_A, float(self.real_A.size()[0]))
        self.loss_rl_G_B = torch.div(self.loss_rl_G_B, float(self.real_B.size()[0]))


        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + \
                      self.loss_idt_B
                      # + self.loss_perception_G_A + self.loss_perception_G_B + self.loss_image_G_A + \
                      # self.loss_image_G_B + self.loss_tv_G_A + self.loss_tv_G_B + self.loss_rl_G_A + self.loss_rl_G_B

        self.loss_G.backward()

    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # forward
        self.forward()      # compute fake images and reconstruction images.
        # G_A and G_B
        self.set_requires_grad([self.netD_A, self.netD_B], False)  # Ds require no gradients when optimizing Gs
        self.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero
        self.backward_G()             # calculate gradients for G_A and G_B
        self.optimizer_G.step()       # update G_A and G_B's weights
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()   # set D_A and D_B's gradients to zero
        self.backward_D_A()      # calculate gradients for D_A
        self.backward_D_B()      # calculate graidents for D_B
        self.optimizer_D.step()  # update D_A and D_B's weights
        return self.real_A, self.fake_A, self.real_B, self.fake_B, self.loss_G_A, self.loss_G_B, self.loss_D_A, \
               self.loss_D_B, self.loss_cycle_A, self.loss_cycle_B, self.loss_idt_A, self.loss_idt_B
               # self.loss_perception_G_A, self.loss_perception_G_B, self.loss_image_G_A, self.loss_image_G_B, \
               # self.loss_tv_G_A, self.loss_tv_G_B, self.loss_rl_G_A, self.loss_rl_G_B

    def setup(self, opt):
        if self.isTrain:
            self.schedulers = [get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.continue_train:
            load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
            self.load_networks(load_suffix)
        self.print_networks(opt.verbose)

    def eval(self):
        """Make models eval mode during test time"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    def test(self):
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        pass

    def get_image_paths(self):
        return self.image_paths

    def update_learning_rate(self):
        for scheduler in self.schedulers:
            scheduler.step(self.metric)
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    def get_current_visuals(self):
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_current_losses(self):
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = float(
                    getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
        return errors_ret

    def save_networks(self, epoch):
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)

                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    def load_networks(self, epoch):
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata

                # patch InstanceNorm checkpoints prior to 0.4
                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)

    def print_networks(self, verbose):
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def set_requires_grad(self, nets, requires_grad=False):
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
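
The base-model plumbing above (setup, optimize_parameters, get_current_losses, update_learning_rate, save_networks) is normally driven by an outer training loop. A minimal sketch of such a driver, assuming a model instance of a class like the one above plus a hypothetical dataset iterable and opt namespace:

# Hypothetical driver loop; `dataset` and the opt.* fields used here
# (n_epochs, print_freq, save_epoch_freq) are assumptions, not part of the snippet above.
def run_training(model, dataset, opt):
    model.setup(opt)                      # build schedulers / load networks as defined above
    total_iters = 0
    for epoch in range(opt.n_epochs):
        for data in dataset:
            model.set_input(data)         # unpack one batch from the dataloader
            model.optimize_parameters()   # forward, G step with Ds frozen, then D step
            total_iters += 1
            if total_iters % opt.print_freq == 0:
                losses = model.get_current_losses()
                print(', '.join('%s: %.3f' % (k, v) for k, v in losses.items()))
        model.update_learning_rate()      # step the LR schedulers once per epoch
        if epoch % opt.save_epoch_freq == 0:
            model.save_networks(epoch)
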
Example #10
0
        loss_vertex_A = criterionCycle(fake_B, real_B)
        loss_vertex_B = criterionCycle(fake_A, real_A)

        loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B - cc_A * lambda_cc - cc_B * lambda_cc + loss_vertex_A * lambda_vertex + loss_vertex_B * lambda_vertex
        """ calculate gradients for G_A and G_B """
        loss_G.backward()

        optimizer_G_A.step()  # update G_A and G_B's weights
        optimizer_G_B.step()  # update G_A and G_B's weights

        # train D_A and D_B
        set_requires_grad([netD_A, netD_B], True)
        optimizer_D_A.zero_grad()  # set D_A and D_B's gradients to zero
        optimizer_D_B.zero_grad()  # set D_A and D_B's gradients to zero
        """Calculate GAN loss for discriminator D_A"""
        fake_B = fake_B_pool.query(fake_B)
        loss_D_A = backward_D_basic(netD_A, real_B, fake_B, 0.1)
        """Calculate GAN loss for discriminator D_B"""
        fake_A = fake_A_pool.query(fake_A)
        loss_D_B = backward_D_basic(netD_B, real_A, fake_A, 0.1)

        optimizer_D_A.step()  # update D_A and D_B's weights
        optimizer_D_B.step()  # update D_A and D_B's weights

        print(
            "[{}:{}/{}] IDT_A={:.4}, IDT_B={:.4}, G_A={:.4}, G_B={:.4}, CYCLE_A={:.4}, CYCLE_B={:.4}, D_A={:.4}, D_B={:.4}, CC_A={:.4}, CC_B={:.4}"
            .format(epoch, batch_idx, len(train_dataloader), loss_idt_A,
                    loss_idt_B, loss_G_A, loss_G_B, loss_cycle_A, loss_cycle_B,
                    loss_D_A, loss_D_B, cc_A, cc_B))

        writer.add_scalars('Train/IDT_loss', {
Example #11
0
def train():
    if FLAGS.load_model is not None:
        checkpoint_dir = 'checkpoint/' + FLAGS.load_model
    else:
        current_time = datetime.now().strftime('%Y%m%d-%H%M')
        checkpoint_dir = 'checkpoint/{}'.format(current_time)
        try:
            os.makedirs(checkpoint_dir)
        except os.error:
            pass

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(X_train_file=FLAGS.X,
                             Y_train_file=FLAGS.Y,
                             batch_size=FLAGS.batch_size,
                             image_size=FLAGS.image_size,
                             use_lsgan=FLAGS.use_lsgan,
                             norm=FLAGS.norm,
                             lambda1=FLAGS.lambda1,
                             lambda2=FLAGS.lambda2,
                             learning_rate=FLAGS.learning_rate,
                             beta1=FLAGS.beta1,
                             ngf=FLAGS.ngf)
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoint_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + '.meta'
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))
            step = int(meta_graph_path.split('-')[2].split('.')[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)

            print('Begin to train...')
            while not coord.should_stop():
                # get previously generated images
                print('tf.Session().Run [fake_y, fake_x] ')
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

                #train
                print('Calculate loss...')
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [
                            optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                            summary_op
                        ],
                        feed_dict={
                            cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val)
                        }))
                if step % 100 == 0:
                    train_writer.add_summary(summary, step)
                    train_writer.flush()

                if step % 100 == 0:
                    logging.info('-------------Step %d------------' % step)
                    logging.info('G_loss: {}'.format(G_loss_val))
                    logging.info('D_Y_loss: {}'.format(D_Y_loss_val))
                    logging.info('F_loss: {}'.format(F_loss_val))
                    logging.info('D_X_loss:{}'.format(D_X_loss_val))
                    logging.info('********************************')

                if step % 10000 == 0:
                    save_path = saver.save(sess,
                                           checkpoint_dir + '/model.ckpt',
                                           global_step=step)
                    logging.info('* Model saved in file %s' % save_path)

                step += 1

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess,
                                   checkpoint_dir + '/model.ckpt',
                                   global_step=step)
            logging.info('Model saved in file %s' % save_path)
            coord.request_stop()
            coord.join(threads)
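
All of the TensorFlow examples resume training the same way: re-import the meta graph, restore the latest checkpoint, and parse the global step out of the checkpoint filename. A condensed sketch of just that pattern (the function name and layout are assumptions, not taken from the snippets):

import tensorflow as tf  # TensorFlow 1.x, matching the API used above

def restore_or_init(sess, checkpoints_dir):
    """Restore the latest checkpoint if one exists, otherwise initialise variables.
    Returns the step to resume from (parsed from 'model.ckpt-<step>')."""
    ckpt = tf.train.get_checkpoint_state(checkpoints_dir)
    if ckpt and ckpt.model_checkpoint_path:
        restorer = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
        restorer.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
        return int(ckpt.model_checkpoint_path.split('-')[-1])
    sess.run(tf.global_variables_initializer())
    return 0
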
class CycleGanModel(BaseModel):
    def __init__(self, opt):
        super(CycleGanModel, self).__init__(opt)
        print('-------------- Networks initializing -------------')

        self.mode = None

        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.lossNames = [
            'loss{}'.format(i) for i in [
                'GenA', 'DisA', 'CycleA', 'IdtA', 'DisB', 'GenB', 'CycleB',
                'IdtB'
            ]
        ]
        self.lossGenA, self.lossDisA, self.lossCycleA, self.lossIdtA = 0, 0, 0, 0
        self.lossGenB, self.lossDisB, self.lossCycleB, self.lossIdtB = 0, 0, 0, 0

        # define loss functions
        self.criterionGAN = networks.GANLoss(use_lsgan=opt.lsgan).to(
            opt.device)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()

        # specify the training miou you want to print out. The program will call base_model.get_current_mious
        self.miouNames = []

        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        # only the image attributes have no name prefix
        imageNamesA = ['realA', 'fakeA', 'recA', 'idtA']
        imageNamesB = ['realB', 'fakeB', 'recB', 'idtB']
        self.imageNames = imageNamesA + imageNamesB

        self.realA, self.fakeA, self.recA, self.idtA = None, None, None, None
        self.realB, self.fakeB, self.recB, self.idtB = None, None, None, None

        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        # naming is by the input domain
        self.modelNames = [
            'net{}'.format(i) for i in ['GenA', 'DisA', 'GenB', 'DisB']
        ]

        # load/define networks
        # The naming convention is different from the one used in the paper
        # Code (paper): G_RGB (G), G_D (F), D_RGB (D_Y), D_D (D_X)
        self.netGenA = networks.define_G(opt.inputCh, opt.inputCh, opt.ngf,
                                         opt.which_model_netG, opt.norm,
                                         opt.dropout, opt.init_type,
                                         opt.init_gain, opt.gpuIds)
        self.netDisA = networks.define_D(opt.inputCh, opt.inputCh,
                                         opt.which_model_netD, opt.n_layers_D,
                                         opt.norm, not opt.lsgan,
                                         opt.init_type, opt.init_gain,
                                         opt.gpuIds)
        self.netGenB = networks.define_G(opt.inputCh, opt.inputCh, opt.ngf,
                                         opt.which_model_netG, opt.norm,
                                         opt.dropout, opt.init_type,
                                         opt.init_gain, opt.gpuIds)
        self.netDisB = networks.define_D(opt.inputCh, opt.inputCh,
                                         opt.which_model_netD, opt.n_layers_D,
                                         opt.norm, not opt.lsgan,
                                         opt.init_type, opt.init_gain,
                                         opt.gpuIds)

        self.set_requires_grad(
            [self.netGenA, self.netGenB, self.netDisA, self.netDisB], True)

        # define image pool
        self.fakeAPool = ImagePool(opt.pool_size)
        self.fakeBPool = ImagePool(opt.pool_size)

        # initialize optimizers
        self.optimizerG = getOptimizer(itertools.chain(
            self.netGenA.parameters(), self.netGenB.parameters()),
                                       opt=opt.opt,
                                       lr=opt.lr,
                                       beta1=opt.beta1,
                                       momentum=opt.momentum,
                                       weight_decay=opt.weight_decay)
        self.optimizerD = getOptimizer(itertools.chain(
            self.netDisA.parameters(), self.netDisB.parameters()),
                                       opt=opt.opt,
                                       lr=opt.lr,
                                       beta1=opt.beta1,
                                       momentum=opt.momentum,
                                       weight_decay=opt.weight_decay)
        self.optimizers = []
        self.optimizers.append(self.optimizerG)
        self.optimizers.append(self.optimizerD)
        print('--------------------------------------------------')

    def name(self):
        return 'CycleGanModel'

    def set_input(self, input):
        self.realA = input[0]['image'].to(self.opt.device)
        self.realB = input[1]['image'].to(self.opt.device)

    def forward(self):
        self.fakeA = self.netGenB(self.realB)
        self.fakeB = self.netGenA(self.realA)
        self.recA = self.netGenB(self.fakeB)
        self.recB = self.netGenA(self.fakeA)

    def backward_dis_basic(self, netDis, real, fake):
        # Real
        predReal = netDis(real)
        lossDisReal = self.criterionGAN(predReal, True)
        # Fake
        predFake = netDis(fake.detach())
        lossDisFake = self.criterionGAN(predFake, False)
        # Combined loss
        lossDis = (lossDisReal + lossDisFake) * 0.5
        # backward
        lossDis.backward()
        return float(lossDis)

    def backward_dis_A(self):
        fakeA = self.fakeAPool.query(self.fakeA)
        self.lossDisA = self.backward_dis_basic(self.netDisA, self.realA,
                                                fakeA)

    def backward_dis_B(self):
        fakeB = self.fakeBPool.query(self.fakeB)
        self.lossDisB = self.backward_dis_basic(self.netDisB, self.realB,
                                                fakeB)

    def backward_gen(self, retain_graph=False):
        lambdaIdt = self.opt.lambdaIdentity
        lambdaA = self.opt.lambdaA
        lambdaB = self.opt.lambdaB
        # Identity loss
        self.forward()
        if lambdaIdt > 0:
            # GenB should be identity if realA is fed.
            self.idtA = self.netGenB(self.realA)
            lossIdtA = self.criterionIdt(self.idtA,
                                         self.realA) * lambdaA * lambdaIdt
            # GenA should be identity if realB is fed.
            self.idtB = self.netGenA(self.realB)
            lossIdtB = self.criterionIdt(self.idtB,
                                         self.realB) * lambdaB * lambdaIdt
        else:
            lossIdtA = 0
            lossIdtB = 0

        # GAN loss for GenA: DisB should classify fakeB as real
        lossGenA = self.criterionGAN(self.netDisB(self.fakeB), True)
        # GAN loss for GenB: DisA should classify fakeA as real
        lossGenB = self.criterionGAN(self.netDisA(self.fakeA), True)
        # Forward cycle loss
        lossCycleA = self.criterionCycle(self.recA, self.realA) * lambdaA
        # Backward cycle loss
        lossCycleB = self.criterionCycle(self.recB, self.realB) * lambdaB
        # combined loss
        lossG = lossGenA + lossGenB + lossCycleA + lossCycleB + lossIdtA + lossIdtB
        lossG.backward(retain_graph=retain_graph)
        # convert losses to plain Python floats for logging
        self.lossGenA = float(lossGenA)
        self.lossGenB = float(lossGenB)
        self.lossCycleA = float(lossCycleA)
        self.lossCycleB = float(lossCycleB)
        self.lossIdtA = float(lossIdtA)
        self.lossIdtB = float(lossIdtB)

    def optimize_parameters(self):
        # GenA and GenB
        self.set_requires_grad([self.netDisA, self.netDisB], False)
        self.optimizerG.zero_grad()
        self.backward_gen()
        self.optimizerG.step()
        # DisA and DisB
        self.set_requires_grad([self.netDisA, self.netDisB], True)
        self.optimizerD.zero_grad()
        self.backward_dis_A()
        self.backward_dis_B()
        self.optimizerD.step()
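
The criterionGAN(prediction, target_is_real) calls used throughout these PyTorch examples wrap a small criterion that builds a constant target tensor for the discriminator output. A minimal LSGAN-flavoured sketch of such a loss (an assumption about its behaviour, not the networks.GANLoss actually imported above):

import torch
import torch.nn as nn

class SimpleGANLoss(nn.Module):
    """MSE-based (LSGAN) adversarial criterion: compares the discriminator's
    output map against a constant target of 1.0 (real) or 0.0 (fake)."""
    def __init__(self, real_label=1.0, fake_label=0.0):
        super().__init__()
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        self.loss = nn.MSELoss()

    def forward(self, prediction, target_is_real):
        target = self.real_label if target_is_real else self.fake_label
        return self.loss(prediction, target.expand_as(prediction))
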
Example #13
0
def train():
    if cfg.load_model is not None:
        checkpoints_dir = cfg.load_model

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN()
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
        G_optimizers, D_optimizers = cycle_gan.optimize(G_loss,
                                                        D_Y_loss,
                                                        F_loss,
                                                        D_X_loss,
                                                        gan=cfg.gan)
        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(cfg.tb_dir, graph)
        #for v in tf.global_variables():
        #    print(v.name)
        if cfg.new_pretrain is not None:
            var_to_restore = []
            for v in tf.global_variables():
                var_to_restore.append(v)
            saver = tf.train.Saver(var_to_restore)
            saver_dump = tf.train.Saver()
        else:
            saver = tf.train.Saver()
            saver_dump = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if cfg.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            step = int(meta_graph_path.split("-")[1].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0
            print(
                '--------------------------------------------------------------------------------'
            )
            if cfg.new_pretrain is not None:
                saver.restore(sess, cfg.new_pretrain)

        ## TODO dataset
        trainA = Dataset(cfg.trainA_dir)
        trainB = Dataset(cfg.trainB_dir)

        # train
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        D_times = 0
        G_train_times = 0

        try:
            fake_Y_pool = ImagePool(cfg.pool_size)
            fake_X_pool = ImagePool(cfg.pool_size)

            while not coord.should_stop():
                st_t = time.time()
                # generate data
                x_image = sess.run(trainA.data)[0]
                #x_image = x_image + tf.random_normal(shape=tf.shape(x_image), mean=0.0, stddev=0.1, dtype=tf.float32)
                y_image = sess.run(trainB.data)[0]
                # y_image = y_image + tf.random_normal(shape=tf.shape(y_image), mean=0.0, stddev=0.1, dtype=tf.float32)

                data_time = time.time() - st_t
                st_t = time.time()
                # generate fake_x, fake_y
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x],
                                                  feed_dict={
                                                      cycle_gan.x_image:
                                                      x_image,
                                                      cycle_gan.y_image:
                                                      y_image
                                                  })
                gen_fake_time = time.time() - st_t
                st_t = time.time()
                # train
                # Discriminator

                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = \
                        sess.run([D_optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                        summary_op], feed_dict={
                            cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val),
                            cycle_gan.x_image: x_image,
                            cycle_gan.y_image: y_image})

                if D_times > 0 and D_times % cfg.D_times == 0:
                    D_times = 0
                    G_train_times += 1
                    _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = \
                        sess.run([G_optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                        summary_op], feed_dict={
                            cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val),
                            cycle_gan.x_image: x_image,
                            cycle_gan.y_image: y_image})

                bp_time = time.time() - st_t
                train_writer.add_summary(summary, step)
                train_writer.flush()

                if step % 1 == 0:
                    logging.info(
                        'step {} | G_loss : {:.4f} | D_Y_loss : {:.4f} | F_loss : {:.4f} |'
                        'D_X_loss : {:.4f} | g_train_times: {} | data {:.3f}s | gen_fake {:.3f}s | bp {:.3f}s'
                        .format(step, G_loss_val, D_Y_loss_val, F_loss_val,
                                D_X_loss_val, G_train_times, data_time,
                                gen_fake_time, bp_time))

                if step % 100 == 0:
                    save_path = saver_dump.save(sess,
                                                cfg.model_dump_dir +
                                                '/model.ckpt',
                                                global_step=step)
                    logging.info('model saved in files: %s' % save_path)

                D_times += 1
                step += 1

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            save_path = saver_dump.save(sess,
                                        cfg.model_dump_dir + '/model.ckpt',
                                        global_step=step)
            logging.info('model saved in files: %s' % save_path)
            coord.request_stop()
            coord.join(threads)
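
The example above trains the discriminators on every iteration but the generators only once per cfg.D_times discriminator steps. Stripped of the TensorFlow plumbing, that schedule reduces to roughly the following sketch (all names and the step count are placeholders):

def discriminator_step():
    pass  # placeholder: one D optimisation step

def generator_step():
    pass  # placeholder: one G optimisation step

D_TIMES = 5                                  # stands in for cfg.D_times
d_counter = 0
for step in range(1000):                     # placeholder training loop
    discriminator_step()                     # D is updated on every iteration
    if d_counter > 0 and d_counter % D_TIMES == 0:
        d_counter = 0
        generator_step()                     # G is updated every D_TIMES iterations
    d_counter += 1
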
def train():

    # If FLAGS.load_model is set, resume training from that checkpoint path; otherwise create a new, time-stamped folder for the training results
    if FLAGS.load_model is not None:
        checkpoints_dir = "checkpoints/" + FLAGS.load_model
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        try:
            os.makedirs(checkpoints_dir)
            os.makedirs(FLAGS.res_im_path)
        except os.error:
            pass

    graph = tf.Graph()
    with graph.as_default():
        # initialize the CycleGAN class
        cycle_gan = CycleGAN(FLAGS)

        # build the graph
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x, real_y, real_x = cycle_gan.model(
        )
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

        # initialize the summaries
        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver(max_to_keep=10)

    with tf.Session(graph=graph) as sess:
        # if a checkpoint path is given, restore it and continue training
        if FLAGS.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            step = int(meta_graph_path.split("-")[2].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        # initialize the input sample queues
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            # initialize the online image pools
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # get previously generated images
                fake_y_val, fake_x_val, real_y_in, real_x_in = sess.run(
                    [fake_y, fake_x, real_y, real_x])

                # train
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [
                            optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                            summary_op
                        ],
                        feed_dict={
                            cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val)
                        }))

                train_writer.add_summary(summary, step)
                train_writer.flush()
                # log the current status
                if step % 1 == 0:
                    logging.info('-----------Step %d:-------------' % step)
                    logging.info('  G_loss   : {}'.format(G_loss_val))
                    logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info('  F_loss   : {}'.format(F_loss_val))
                    logging.info('  D_X_loss : {}'.format(D_X_loss_val))

                if step % 1000 == 0:
                    ops.save_img_result(fake_y_val, fake_x_val, real_y_in,
                                        real_x_in, FLAGS.res_im_path, step)

                if step % 1000 == 0:
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1
                if step == FLAGS.epho:
                    coord.request_stop()  # signal that training should stop

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess,
                                   checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            ops.save_img_result(fake_y_val, fake_x_val, real_y_in, real_x_in,
                                FLAGS.res_im_path, step)
            logging.info("Model saved in file: %s" % save_path)

            coord.request_stop()  # stop training
            coord.join(threads)
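
Every example queries an ImagePool of previously generated images before updating the discriminators, which stabilises training by showing the discriminator a history of fakes rather than only the newest one. A minimal sketch of such a buffer (modelled on the usual CycleGAN history pool; an assumption, not the exact class imported here):

import copy
import random

class SimpleImagePool:
    """Keeps up to pool_size previously generated images and, once full,
    returns either the incoming image or a randomly swapped-out older one."""
    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, image):
        if self.pool_size == 0:                  # pooling disabled
            return image
        if len(self.images) < self.pool_size:    # still filling the buffer
            self.images.append(copy.deepcopy(image))
            return image
        if random.random() > 0.5:                # 50%: hand back an older image
            idx = random.randrange(self.pool_size)
            old = self.images[idx]
            self.images[idx] = copy.deepcopy(image)
            return old
        return image                             # 50%: pass the new image through
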
Example #15
0
class CycleGANModel():
    def __init__(self,opt):
        self.opt = opt
        self.dynamic = opt.dynamic
        self.isTrain = opt.istrain
        self.Tensor = torch.cuda.FloatTensor

        # load/define networks
        # The naming convention is different from the one used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = state2img().cuda()
        self.netG_B = img2state().cuda()
        self.netF_A = Fmodel().cuda()
        self.dataF = CDFdata.get_loader(opt)
        self.train_forward()

        if self.isTrain:
            self.netD_A = imgDmodel().cuda()
            self.netD_B = stateDmodel().cuda()

        if self.isTrain:
            self.fake_A_pool = ImagePool(pool_size=128)
            self.fake_B_pool = ImagePool(pool_size=128)
            # define loss functions
            self.criterionGAN = GANLoss(tensor=self.Tensor).cuda()
            if opt.loss == 'l1':
                self.criterionCycle = torch.nn.L1Loss()
                self.criterionIdt = torch.nn.L1Loss()
            elif opt.loss == 'l2':
                self.criterionCycle = torch.nn.MSELoss()
                self.criterionIdt = torch.nn.MSELoss()
            # initialize optimizers
            # self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()))
            self.optimizer_G = torch.optim.Adam([{'params':self.netG_A.parameters(),'lr':1e-3},
                                                 {'params':self.netF_A.parameters(),'lr':0.0},
                                                 {'params':self.netG_B.parameters(),'lr':1e-3}])
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters())
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters())
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                # self.schedulers.append(networks.get_scheduler(optimizer, opt))
                self.schedulers.append(optimizer)

        print('---------- Networks initialized -------------')
        # networks.print_network(self.netG_A)
        # networks.print_network(self.netG_B)
        # if self.isTrain:
        #     networks.print_network(self.netD_A)
        #     networks.print_network(self.netD_B)
        print('-----------------------------------------------')


    def train_forward(self,pretrained=True):
        if pretrained:
            self.netF_A.load_state_dict(torch.load('./pred.pth'))
            return None
        optimizer = torch.optim.Adam(self.netF_A.parameters(),lr=1e-3)
        loss_fn = torch.nn.L1Loss()
        for epoch in range(100):
            epoch_loss = 0
            for i,item in enumerate(tqdm(self.dataF)):
                state, action, result = item[1]
                state = state.float().cuda()
                action = action.float().cuda()
                result = result.float().cuda()
                out = self.netF_A(state, action)
                loss = loss_fn(out, result)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
            print('epoch:{} loss:{:.7f}'.format(epoch,epoch_loss/len(self.dataF)))
        print('forward model has been trained!')

    def set_input(self, input):
        # AtoB = self.opt.which_direction == 'AtoB'
        # input_A = input['A' if AtoB else 'B']
        # input_B = input['B' if AtoB else 'A']
        # A is state
        self.input_A = input[1][0]

        # B is img
        self.input_Bt0 = input[0][0]
        self.input_Bt1 = input[0][2]
        self.action = input[0][1]
        self.gt0 = input[2][0].float().cuda()
        self.gt1 = input[2][1].float().cuda()


    def forward(self):
        self.real_A = Variable(self.input_A).float().cuda()
        self.real_Bt0 = Variable(self.input_Bt0).float().cuda()
        self.real_Bt1 = Variable(self.input_Bt1).float().cuda()
        self.action = Variable(self.action).float().cuda()


    def test(self):
        self.forward()
        real_A = Variable(self.input_A, volatile=True).float().cuda()
        fake_B = self.netG_A(real_A)
        self.rec_A = self.netG_B(fake_B).data
        self.fake_B = fake_B.data

        real_B = Variable(self.input_Bt0, volatile=True).float().cuda()
        fake_A = self.netG_B(real_B)
        self.rec_B = self.netG_A(fake_A).data
        self.fake_A = fake_A.data


    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_Bt0, fake_B)
        self.loss_D_A = loss_D_A.item()

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_At0)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.item()

    def backward_G(self):
        lambda_idt = -0.5
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        lambda_F = self.opt.lambda_F
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_Bt0)
            loss_idt_A = self.criterionIdt(idt_A, self.real_Bt0) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt

            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.item()
            self.loss_idt_B = loss_idt_B.item()
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        lambda_G_A = 1.0
        lambda_G_B = 1.0

        # --------first cycle-----------#
        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True) * lambda_G_A
        # Forward cycle loss
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A

        # ---------second cycle---------#
        # GAN loss D_B(G_B(B))
        fake_At0 = self.netG_B(self.real_Bt0)
        pred_fake = self.netD_B(fake_At0)
        loss_G_Bt0 = self.criterionGAN(pred_fake, True) * lambda_G_B
        # Backward cycle loss
        rec_Bt0 = self.netG_A(fake_At0)
        loss_cycle_Bt0 = self.criterionCycle(rec_Bt0, self.real_Bt0) * lambda_B

        # ---------third cycle---------#
        # GAN loss D_B(G_B(B))
        fake_At1 = self.netF_A(fake_At0,self.action)
        pred_fake = self.netD_B(fake_At1)
        loss_G_Bt1 = self.criterionGAN(pred_fake, True) * lambda_G_B
        # Backward cycle loss
        rec_Bt1 = self.netG_A(fake_At1)
        loss_cycle_Bt1 = self.criterionCycle(rec_Bt1, self.real_Bt1) * lambda_F


        # combined loss
        loss_G = loss_idt_A + loss_idt_B
        loss_G = loss_G + loss_G_A + loss_cycle_A
        loss_G = loss_G + loss_G_Bt0 + loss_cycle_Bt0
        if self.dynamic:
            loss_G = loss_G + loss_G_Bt1 + loss_cycle_Bt1
        loss_G.backward()

        self.fake_B = fake_B.data
        self.fake_At0 = fake_At0.data
        self.fake_At1 = fake_At1.data
        self.rec_A = rec_A.data
        self.rec_Bt0 = rec_Bt0.data
        self.rec_Bt1 = rec_Bt1.data

        self.loss_G_A = loss_G_A.item()
        self.loss_G_Bt0 = loss_G_Bt0.item()
        self.loss_G_Bt1 = loss_G_Bt1.item()
        self.loss_cycle_A = loss_cycle_A.item()
        self.loss_cycle_Bt0 = loss_cycle_Bt0.item()
        self.loss_cycle_Bt1 = loss_cycle_Bt1.item()

        self.loss_state_l1 = self.criterionCycle(self.fake_At0, self.gt0).item()


    def optimize_parameters(self):
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        ret_errors = OrderedDict([('L1',self.loss_state_l1), ('D_A', self.loss_D_A), ('G_A', self.loss_G_A),
                                  ('Cyc_A', self.loss_cycle_A), ('D_B', self.loss_D_B),
                                  ('G_B0', self.loss_G_Bt0), ('G_B1', self.loss_G_Bt1),
                                  ('Cyc_B0',  self.loss_cycle_Bt0), ('Cyc_B1',  self.loss_cycle_Bt1)])
        # if self.opt.identity > 0.0:
        #     ret_errors['idt_A'] = self.loss_idt_A
        #     ret_errors['idt_B'] = self.loss_idt_B
        return ret_errors

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, path):
        save_filename = 'model_{}.pth'.format(network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    def save(self, path):
        self.save_network(self.netG_A, 'G_A', path)
        self.save_network(self.netD_A, 'D_A', path)
        self.save_network(self.netG_B, 'G_B', path)
        self.save_network(self.netD_B, 'D_B', path)

    def load_network(self, network, network_label, path):
        weight_filename = 'model_{}.pth'.format(network_label)
        weight_path = os.path.join(path, weight_filename)
        network.load_state_dict(torch.load(weight_path))

    def load(self,path):
        self.load_network(self.netG_A, 'G_A', path)
        self.load_network(self.netG_B, 'G_B', path)

    def plot_points(self,item,label):
        item = item.cpu().data.numpy()
        plt.scatter(item[:,0],item[:,1],label=label)

    def visual(self,path):
        imgs = []
        for i in range(self.real_A.shape[0]):
            imgs_i = [self.real_Bt0[i], self.rec_Bt0[i]]
            imgs_i += [self.real_Bt1[i], self.rec_Bt1[i]]
            imgs_i += [self.fake_B[i]]
            imgs_i = torch.cat(imgs_i, 2).cpu()
            imgs.append(imgs_i)
        imgs = torch.cat(imgs, 1)
        imgs = (imgs + 1) / 2
        imgs = transforms.ToPILImage()(imgs)
        imgs.save(path)
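
The test() method above relies on the pre-0.4 Variable(..., volatile=True) idiom; on current PyTorch the same no-gradient inference is written with torch.no_grad(). A small equivalent sketch (function name and argument layout are assumptions):

import torch

def run_inference(netG_A, netG_B, input_A):
    with torch.no_grad():                   # replaces Variable(..., volatile=True)
        real_A = input_A.float().cuda()     # same dtype/device handling as above
        fake_B = netG_A(real_A)             # translate A -> B
        rec_A = netG_B(fake_B)              # reconstruct A from the translation
    return fake_B, rec_A
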
Example #16
0
def train():
    if FLAGS.load_model is not None:  # if this flag is set, derive checkpoints_dir from it
        checkpoints_dir = "checkpoints/" + FLAGS.load_model
    else:  # otherwise create a checkpoints_dir named after the current time
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        try:
            os.makedirs(checkpoints_dir)
        except os.error:
            pass

    graph = tf.Graph()  # create the computation graph
    with graph.as_default():
        cycle_gan = CycleGAN(X_train_file=FLAGS.X,
                             Y_train_file=FLAGS.Y,
                             batch_size=FLAGS.batch_size,
                             image_size=FLAGS.image_size,
                             use_lsgan=FLAGS.use_lsgan,
                             norm=FLAGS.norm,
                             lambda1=FLAGS.lambda1,
                             lambda2=FLAGS.lambda2,
                             learning_rate=FLAGS.learning_rate,
                             beta1=FLAGS.beta1,
                             ngf=FLAGS.ngf)  # instantiate the CycleGAN network
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model(
        )  # returns the generator/discriminator losses and the generated fake_y / fake_x
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss,
                                        D_X_loss)  # a combined training op for the four losses

        summary_op = tf.summary.merge_all()  # merge all summaries for logging
        train_writer = tf.summary.FileWriter(checkpoints_dir,
                                             graph)  # write the graph to checkpoints_dir
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:  # if a saved model exists, load it and continue training
            checkpoint = tf.train.get_checkpoint_state(
                checkpoints_dir)  # fetch the latest checkpoint state
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)  # load the graph structure
            restore.restore(
                sess,
                tf.train.latest_checkpoint(checkpoints_dir))  # restore the latest model parameters
            step = int(meta_graph_path.split("-")[2].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())  # initialize global variables
            step = 0

        coord = tf.train.Coordinator()  # thread coordination
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            fake_Y_pool = ImagePool(FLAGS.pool_size)  # set the image buffer size
            fake_X_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                fake_y_val, fake_x_val = sess.run(
                    [fake_y, fake_x])  # first fetch the previously generated images x, y

                #train
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [
                            optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                            summary_op
                        ],
                        feed_dict={
                            cycle_gan.fake_y: fake_Y_pool.query(
                                fake_y_val
                            ),  # feed the pooled fake_y / fake_x back in and optimize the losses above
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val)
                        }  # the pools buffer previously generated fakes for the discriminator updates
                    ))  # run one training step
                if step % 100 == 0:  # every 100 steps, write the summaries
                    train_writer.add_summary(summary, step)
                    train_writer.flush()

                if step % 100 == 0:
                    logging.info('----------step %d:--------------' % step)
                    logging.info(' G_loss : {}'.format(G_loss_val))
                    logging.info(' D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info(' F_loss : {}'.format(F_loss_val))
                    logging.info(' D_X_loss : {}'.format(D_X_loss_val))

                if step % 10000 == 0:
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            save_path = saver.save(
                sess, checkpoints_dir + "/model.ckpt",
                global_step=step)  # after training finishes, save the final .ckpt
            logging.info("Model saved in file: %s" % save_path)
            coord.request_stop()
            coord.join(threads)
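
These train() functions all assume a module-level FLAGS object and a standard TensorFlow 1.x entry point. A minimal sketch of that scaffolding (the exact flag set and defaults are assumptions):

import logging

import tensorflow as tf  # TensorFlow 1.x, matching the API used above

FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_integer('batch_size', 1, 'batch size')
tf.flags.DEFINE_integer('image_size', 256, 'image size')
tf.flags.DEFINE_integer('pool_size', 50, 'size of the image buffer')
tf.flags.DEFINE_string('load_model', None,
                       'folder of a saved model to resume from, e.g. 20170602-1936')
# ... the remaining flags referenced above (X, Y, use_lsgan, norm, lambda1,
# lambda2, learning_rate, beta1, ngf) would be declared the same way.

def main(unused_argv):
    train()  # the train() defined in the example above

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    tf.app.run()
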
Example #17
0
class CycleGAN(ModelBackbone):
    def __init__(self, p):

        super(CycleGAN, self).__init__(p)
        nb = p.batchSize
        size = p.cropSize

        # load/define models
        # The naming convention is different from the one used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = networks.define_G(p.input_nc, p.output_nc, p.ngf,
                                        p.which_model_netG, p.norm,
                                        not p.no_dropout, p.init_type,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(p.output_nc, p.input_nc, p.ngf,
                                        p.which_model_netG, p.norm,
                                        not p.no_dropout, p.init_type,
                                        self.gpu_ids)

        if self.isTrain:
            use_sigmoid = p.no_lsgan
            self.netD_A = networks.define_D(p.output_nc, p.ndf,
                                            p.which_model_netD, p.n_layers_D,
                                            p.norm, use_sigmoid, p.init_type,
                                            self.gpu_ids)
            self.netD_B = networks.define_D(p.input_nc, p.ndf,
                                            p.which_model_netD, p.n_layers_D,
                                            p.norm, use_sigmoid, p.init_type,
                                            self.gpu_ids)

        if not self.isTrain or p.continue_train:
            which_epoch = p.which_epoch
            self.load_model(self.netG_A, 'G_A', which_epoch)
            self.load_model(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_model(self.netD_A, 'D_A', which_epoch)
                self.load_model(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = p.lr
            self.fake_A_pool = ImagePool(p.pool_size)
            self.fake_B_pool = ImagePool(p.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not p.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=p.lr,
                                                betas=(p.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=p.lr,
                                                  betas=(p.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                  lr=p.lr,
                                                  betas=(p.beta1, 0.999))

            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, p))

        # print('---------- Networks initialized -------------')
        # networks.print_network(self.netG_A)
        # networks.print_network(self.netG_B)
        # if self.isTrain:
        #     networks.print_network(self.netD_A)
        #     networks.print_network(self.netD_B)
        # print('-----------------------------------------------')

    def name(self):
        return 'CycleGAN'

    def set_input(self, inp):
        AtoB = self.p.which_direction == 'AtoB'
        input_A = inp['A' if AtoB else 'B']
        input_B = inp['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            input_A = input_A.cuda(self.gpu_ids[0])
            input_B = input_B.cuda(self.gpu_ids[0])
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = inp['A_path' if AtoB else 'B_path']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        with torch.no_grad():
            real_A = Variable(self.input_A)
            fake_B = self.netG_A(real_A)
            self.rec_A = self.netG_B(fake_B).data
            self.fake_B = fake_B.data

            real_B = Variable(self.input_B)
            fake_A = self.netG_B(real_B)
            self.rec_B = self.netG_A(fake_A).data
            self.fake_A = fake_A.data

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
        self.loss_D_A = loss_D_A.item()

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.item()

    def backward_G(self):
        lambda_idt = self.p.identity
        lambda_A = self.p.lambda_A
        lambda_B = self.p.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_B)
            loss_idt_A = self.criterionIdt(idt_A,
                                           self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(idt_B,
                                           self.real_A) * lambda_A * lambda_idt
            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.item()
            self.loss_idt_B = loss_idt_B.item()

        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True)

        # GAN loss D_B(G_B(B))
        fake_A = self.netG_B(self.real_B)
        pred_fake = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake, True)

        # Forward cycle loss
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A
        # print("loss_cycle_A  ",loss_cycle_A.grad)

        # Backward cycle loss
        rec_B = self.netG_A(fake_A)
        loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B

        # combined loss
        loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
        loss_G.backward()

        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        self.rec_A = rec_A.data
        self.rec_B = rec_B.data

        self.loss_G_A = loss_G_A.item()
        self.loss_G_B = loss_G_B.item()

        self.loss_cycle_A = loss_cycle_A.item()
        self.loss_cycle_B = loss_cycle_B.item()

    # def get_gradient(self, img):
    #     # print("get_gradient ",img)
    #     # return np.gradient(np.array(img))
    #     return np.gradient(img)

    def optimize_parameters(self):
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        ret_errors = OrderedDict([('D_A', self.loss_D_A),
                                  ('G_A', self.loss_G_A),
                                  ('Cyc_A', self.loss_cycle_A),
                                  ('D_B', self.loss_D_B),
                                  ('G_B', self.loss_G_B),
                                  ('Cyc_B', self.loss_cycle_B)])

        if self.p.identity > 0.0:
            ret_errors['idt_A'] = self.loss_idt_A
            ret_errors['idt_B'] = self.loss_idt_B

        return ret_errors

    def get_current_visuals(self):
        real_A = tensor2im(self.input_A)
        fake_B = tensor2im(self.fake_B)
        rec_A = tensor2im(self.rec_A)
        real_B = tensor2im(self.input_B)
        fake_A = tensor2im(self.fake_A)
        rec_B = tensor2im(self.rec_B)
        ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                   ('rec_A', rec_A), ('real_B', real_B),
                                   ('fake_A', fake_A), ('rec_B', rec_B)])
        if self.isTrain and self.p.identity > 0.0:
            ret_visuals['idt_A'] = tensor2im(self.idt_A)
            ret_visuals['idt_B'] = tensor2im(self.idt_B)

        return ret_visuals

    def save(self, label):
        self.save_model(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_model(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_model(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_model(self.netD_B, 'D_B', label, self.gpu_ids)
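
get_current_visuals above converts tensors for display with a tensor2im helper; a hedged sketch of what such a helper typically does (assumed behaviour: take the first image of an NCHW batch scaled to [-1, 1] and return a uint8 HWC array):

import numpy as np

def tensor2im_sketch(image_tensor):
    """Assumed helper: convert the first image of an (N, C, H, W) tensor in
    [-1, 1] into a uint8 H x W x C numpy array for display or saving."""
    image = image_tensor[0].detach().cpu().float().numpy()
    image = (np.transpose(image, (1, 2, 0)) + 1) / 2.0 * 255.0
    return image.astype(np.uint8)
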
Example #18
0
def train():
    if FLAGS.load_model is not None:
        checkpoints_dir = "checkpoints/" + FLAGS.load_model.lstrip(
            "checkpoints/")
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        try:
            os.makedirs(checkpoints_dir)
        except os.error:
            pass

    graph = tf.Graph()
    variable_to_restore = []
    with graph.as_default():
        segmentation = SegmentationNN('combined_model')
        variable_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

        cycle_gan = CycleGAN(X_train_file=FLAGS.X,
                             Y_train_file=FLAGS.Y,
                             batch_size=FLAGS.batch_size,
                             image_size=FLAGS.image_size,
                             use_lsgan=FLAGS.use_lsgan,
                             norm=FLAGS.norm,
                             lambda1=FLAGS.lambda1,
                             lambda2=FLAGS.lambda2,
                             learning_rate=FLAGS.learning_rate,
                             beta1=FLAGS.beta1,
                             ngf=FLAGS.ngf,
                             segmentation=segmentation)
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            step = int(meta_graph_path.split("-")[2].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            print('variables', variable_to_restore)
            restore1 = tf.train.Saver(variable_to_restore)
            restore1.restore(sess, 'Segmentation/lib/real.ckpt')
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # get previously generated images
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

                R = 95 * np.ones([cycle_gan.batch_size, 256, 256])
                G = 40 * np.ones([cycle_gan.batch_size, 256, 256])
                B = 20 * np.ones([cycle_gan.batch_size, 256, 256])
                ones = np.ones(([cycle_gan.batch_size, 256, 256, 3]))

                RGB = np.stack([R, G, B], axis=3)

                _fake_x = fake_X_pool.query(fake_x_val)
                _fake_y = fake_Y_pool.query(fake_y_val)

                # train
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [
                            optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                            summary_op
                        ],
                        feed_dict={
                            cycle_gan.fake_y: _fake_y,
                            cycle_gan.fake_x: _fake_x,
                            cycle_gan.RGB: RGB,
                            cycle_gan.ones: ones,
                            cycle_gan.covered1: False,
                            cycle_gan.covered2: True
                        }))

                train_writer.add_summary(summary, step)
                train_writer.flush()

                if step % 100 == 0:
                    logging.info('-----------Step %d:-------------' % step)
                    logging.info('  G_loss   : {}'.format(G_loss_val))
                    logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info('  F_loss   : {}'.format(F_loss_val))
                    logging.info('  D_X_loss : {}'.format(D_X_loss_val))

                if step % 10000 == 0:
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess,
                                   checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            logging.info("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
                # Alternative training loop (fragment): pre-trains the
                # discriminator for the first 25 steps before running the
                # joint update and fetching extra tensors for inspection.
                x_last_test_predict_list = [0] * len(testimage_x_list)
                y_last_test_predict_list = [0] * len(testimage_x_list)

                while not coord.should_stop():

                    if step <= 25:
                        for i in range(FLAGS.dis_pretrain):
                            _, fake_y_val, fake_x_val = sess.run(
                                [D_optimizer, fake_y, fake_x],
                                feed_dict={
                                    cycle_gan.fake_y:
                                    fake_Y_pool.query(fake_y_val),
                                    cycle_gan.fake_x:
                                    fake_X_pool.query(fake_x_val)
                                })
                        (_, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val,
                         fake_y_val, fake_x_val, real_y, real_x,
                         reconstructed_y_val, reconstructed_x_val, summary,
                         D_Y_output_real_1_val, D_Y_output_fake_forG_1_val,
                         D_Y_output_real_2_val, D_Y_output_fake_forG_2_val,
                         D_X_output_real_1_val, D_X_output_fake_forG_1_val,
                         D_X_output_real_2_val, D_X_output_fake_forG_2_val) = sess.run(
                            [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                             fake_y, fake_x, y, x, reconstructed_y,
                             reconstructed_x, summary_op,
                             D_Y_output_real_1, D_Y_output_fake_forG_1,
                             D_Y_output_real_2, D_Y_output_fake_forG_2,
                             D_X_output_real_1, D_X_output_fake_forG_1,
                             D_X_output_real_2, D_X_output_fake_forG_2],
                            feed_dict={
                                cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                                cycle_gan.fake_x: fake_X_pool.query(fake_x_val)})
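Earlier in this snippet, a second Saver restores pre-trained segmentation weights from 'Segmentation/lib/real.ckpt' through variable_to_restore, which is defined outside the excerpt. A minimal sketch of how such a variable list is commonly built in TF1, assuming the segmentation variables live under a 'Segmentation' scope, could look like this:

import tensorflow as tf  # TF1-style API, as in the snippets above

# Hypothetical construction of variable_to_restore; the 'Segmentation' scope
# name is an assumption, not taken from the original code.
variable_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        scope='Segmentation')
seg_saver = tf.train.Saver(var_list=variable_to_restore)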
Example #20
0
class resgan(BaseModel):
    def init_architecture(self, opt):
        self.opt = opt
        self.netG = define_G(opt.in_nc,
                             opt.out_nc,
                             opt.nz,
                             opt.ngf,
                             which_model_netG=opt.G_model)
        if opt.use_gpu:
            self.netG.cuda()

        if opt.isTrain:
            self.netD = define_D(opt.in_nc, opt.ngf, 'basic_128')
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

            if opt.use_gpu:
                self.netD.cuda()

            self.optimizers = [self.optimizer_G, self.optimizer_D]

            self.fake_A_pool = ImagePool(opt.pool_size)

    def forward(self):
        self.real_A = self.input_A
        self.real_B = self.input_B

        self.G_fake_B = self.netG(self.real_A)

    def update_D(self, netD, real, fake, optim):
        D_fake = netD(self.fake_A_pool.query(fake.data))
        D_real = netD(real)

        D_fake_loss = self.critGAN(D_fake, False)
        D_real_loss = self.critGAN(D_real, True)
        D_loss = (D_fake_loss + D_real_loss) * 0.5

        optim.zero_grad()
        D_loss.backward()
        optim.step()

        return D_loss

    def update_G(self):
        loss_G = 0

        pred_fake = self.netD(self.G_fake_B)
        loss_GAN = self.critGAN(pred_fake, True)
        loss_G += loss_GAN

        noise_est = self.real_A - self.G_fake_B
        noisy_est = self.real_B + noise_est
        rec_B = self.netG(noisy_est)

        rec_loss = self.critL1(rec_B, self.real_B)
        loss_G += self.opt.l1_lambda * rec_loss

        self.optimizer_G.zero_grad()
        loss_G.backward()
        self.optimizer_G.step()

        return loss_G

    def optimize_parameters(self):
        self.forward()
        self.update_G()
        self.update_D(self.netD, self.real_B, self.G_fake_B.detach(),
                      self.optimizer_D)

    def save(self):
        self.save_network(self.netG, 'G')
        self.save_network(self.netD, 'D')
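A minimal driver for the resgan model above might look like the sketch below; the paired loader, the .cuda() placement, and assigning input_A/input_B directly (normally handled by a BaseModel.set_input helper that is not shown) are assumptions for illustration only.

# Hypothetical training loop for resgan; loader yields paired (degraded, clean) batches.
def train_resgan(model, loader, num_epochs=1):
    for epoch in range(num_epochs):
        for real_A, real_B in loader:
            model.input_A = real_A.cuda()    # tensors consumed by forward()
            model.input_B = real_B.cuda()
            model.optimize_parameters()      # G update, then D update on the detached fake
        model.save()                         # writes netG and netD via save_network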
Example #21
0
class Model:
    def initialize(self, cfg):
        self.cfg = cfg

        ## set devices
        if cfg['GPU_IDS']:
            assert (torch.cuda.is_available())
            self.device = torch.device('cuda:{}'.format(cfg['GPU_IDS'][0]))
            torch.backends.cudnn.benchmark = True
            print('Using %d GPUs' % len(cfg['GPU_IDS']))
        else:
            self.device = torch.device('cpu')

        # define network
        if cfg['ARCHI'] == 'alexnet':
            self.netB = networks.netB_alexnet()
            self.netH = networks.netH_alexnet()
            if self.cfg['USE_DA'] and self.cfg['TRAIN']:
                self.netD = networks.netD_alexnet(self.cfg['DA_LAYER'])
        elif cfg['ARCHI'] == 'vgg16':
            raise NotImplementedError
            self.netB = networks.netB_vgg16()
            self.netH = networks.netH_vgg16()
            if self.cfg['USE_DA'] and self.cfg['TRAIN']:
                self.netD = netD_vgg16(self.cfg['DA_LAYER'])
        elif 'resnet' in cfg['ARCHI']:
            self.netB = networks.netB_resnet34()
            self.netH = networks.netH_resnet34()
            if self.cfg['USE_DA'] and self.cfg['TRAIN']:
                self.netD = networks.netD_resnet(self.cfg['DA_LAYER'])
        else:
            raise ValueError('Un-supported network')

        ## initialize network param.
        self.netB = networks.init_net(self.netB, cfg['GPU_IDS'], 'xavier')
        self.netH = networks.init_net(self.netH, cfg['GPU_IDS'], 'xavier')

        if self.cfg['USE_DA'] and self.cfg['TRAIN']:
            self.netD = networks.init_net(self.netD, cfg['GPU_IDS'], 'xavier')

        # loss, optimizer, and scheduler
        if cfg['TRAIN']:
            self.total_steps = 0
            ## Output path
            self.save_dir = os.path.join(
                cfg['OUTPUT_PATH'], cfg['ARCHI'],
                datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
            if not os.path.isdir(self.save_dir):
                os.makedirs(self.save_dir)
            # self.logger = Logger(self.save_dir)

            ## model names
            self.model_names = ['netB', 'netH']
            ## loss
            self.criterionGAN = networks.GANLoss().to(self.device)
            self.criterionDepth1 = torch.nn.MSELoss().to(self.device)
            self.criterionNorm = torch.nn.CosineEmbeddingLoss().to(self.device)
            self.criterionEdge = torch.nn.BCELoss().to(self.device)

            ## optimizers
            self.lr = cfg['LR']
            self.optimizers = []
            self.optimizer_B = torch.optim.Adam(self.netB.parameters(),
                                                lr=cfg['LR'],
                                                betas=(cfg['BETA1'],
                                                       cfg['BETA2']))
            self.optimizer_H = torch.optim.Adam(self.netH.parameters(),
                                                lr=cfg['LR'],
                                                betas=(cfg['BETA1'],
                                                       cfg['BETA2']))
            self.optimizers.append(self.optimizer_B)
            self.optimizers.append(self.optimizer_H)
            if cfg['USE_DA']:
                self.real_pool = ImagePool(cfg['POOL_SIZE'])
                self.syn_pool = ImagePool(cfg['POOL_SIZE'])
                self.model_names.append('netD')
                ## use SGD for discriminator
                self.optimizer_D = torch.optim.SGD(
                    self.netD.parameters(),
                    lr=cfg['LR'],
                    momentum=cfg['MOMENTUM'],
                    weight_decay=cfg['WEIGHT_DECAY'])
                self.optimizers.append(self.optimizer_D)
            ## LR scheduler
            self.schedulers = [
                networks.get_scheduler(optimizer, cfg)
                for optimizer in self.optimizers
            ]
        else:
            ## testing
            self.model_names = ['netB', 'netH']
            self.criterionDepth1 = torch.nn.MSELoss().to(self.device)
            self.criterionNorm = torch.nn.CosineEmbeddingLoss(
                reduction='none').to(self.device)

        self.load_dir = os.path.join(cfg['CKPT_PATH'])
        self.criterionNorm_eval = torch.nn.CosineEmbeddingLoss(
            reduction='none').to(self.device)

        if cfg['TEST'] or cfg['RESUME']:
            self.load_networks(cfg['EPOCH_LOAD'])

    def set_input(self, inputs):
        if self.cfg['GRAY']:
            _ch = np.random.randint(3)
            _syn = inputs['syn']['color'][:, _ch, :, :]
            self.input_syn_color = torch.stack((_syn, _syn, _syn),
                                               dim=1).to(self.device)
        else:
            self.input_syn_color = inputs['syn']['color'].to(self.device)
        self.input_syn_dep = inputs['syn']['depth'].to(self.device)
        self.input_syn_edge = inputs['syn']['edge'].to(self.device)
        self.input_syn_edge_count = inputs['syn']['edge_pix'].to(self.device)
        self.input_syn_norm = inputs['syn']['normal'].to(self.device)
        if self.cfg['USE_DA']:
            if self.cfg['GRAY']:
                _ch = np.random.randint(3)
                _real = inputs['real'][0][:, _ch, :, :]
                self.input_real_color = torch.stack((_real, _real, _real),
                                                    dim=1).to(self.device)
            else:
                self.input_real_color = inputs['real'][0].to(self.device)

    def forward(self):
        self.feat_syn = self.netB(self.input_syn_color)
        # TODO: make it compatible with other networks in addition to ResNet
        self.head_pred = self.netH(self.feat_syn['layer1'],
                                   self.feat_syn['layer2'],
                                   self.feat_syn['layer3'],
                                   self.feat_syn['layer4'])
        if self.cfg['USE_DA'] and self.cfg['TRAIN']:
            self.feat_real = self.netB(self.input_real_color)
            self.pred_D_real = self.netD(self.feat_real[self.cfg['DA_LAYER']])
            self.pred_D_syn = self.netD(self.feat_syn[self.cfg['DA_LAYER']])
        return self.head_pred

    def backward_BH(self):
        ## forward to compute prediction
        # TODO: replace this with self.head_pred to avoid computation twice
        self.task_pred = self.netH(self.feat_syn['layer1'],
                                   self.feat_syn['layer2'],
                                   self.feat_syn['layer3'],
                                   self.feat_syn['layer4'])

        # depth
        depth_diff = self.task_pred['depth'] - self.input_syn_dep
        _n = self.task_pred['depth'].size(0) * self.task_pred['depth'].size(
            2) * self.task_pred['depth'].size(3)
        loss_depth2 = depth_diff.sum().pow_(2).div_(_n).div_(_n)
        loss_depth1 = self.criterionDepth1(self.task_pred['depth'],
                                           self.input_syn_dep)
        self.loss_dep = self.cfg['DEP_WEIGHT'] * (loss_depth1 -
                                                  loss_depth2) * 0.5

        # surface normal
        ch = self.task_pred['norm'].size(1)
        _pred = self.task_pred['norm'].permute(0, 2, 3,
                                               1).contiguous().view(-1, ch)
        _gt = self.input_syn_norm.permute(0, 2, 3, 1).contiguous().view(-1, ch)
        _gt = (_gt / 127.5) - 1
        _pred = torch.nn.functional.normalize(_pred, dim=1)
        self.task_pred['norm'] = _pred.view(self.task_pred['norm'].size(0),
                                            self.task_pred['norm'].size(2),
                                            self.task_pred['norm'].size(3),
                                            3).permute(0, 3, 1, 2)
        self.task_pred['norm'] = (self.task_pred['norm'] + 1) * 127.5
        cos_label = torch.ones(_gt.size(0)).to(self.device)
        self.loss_norm = self.cfg['NORM_WEIGHT'] * self.criterionNorm(
            _pred, _gt, cos_label)

        # edge
        self.loss_edge = self.cfg['EDGE_WEIGHT'] * self.criterionEdge(
            self.task_pred['edge'], self.input_syn_edge)

        ## combined loss
        loss = self.loss_dep + self.loss_norm + self.loss_edge

        if self.cfg['USE_DA']:
            pred_syn = self.netD(self.feat_syn[self.cfg['DA_LAYER']])
            self.loss_DA = self.criterionGAN(pred_syn, True)
            loss += self.loss_DA * self.cfg['DA_WEIGHT']

        loss.backward()

    def backward_D(self):
        ## Synthetic
        # stop backprop to netB by detaching
        _feat_s = self.syn_pool.query(
            self.feat_syn[self.cfg['DA_LAYER']].detach().cpu())
        pred_syn = self.netD(_feat_s.to(self.device))
        self.loss_D_syn = self.criterionGAN(pred_syn, False)

        ## Real
        _feat_r = self.real_pool.query(
            self.feat_real[self.cfg['DA_LAYER']].detach().cpu())
        pred_real = self.netD(_feat_r.to(self.device))
        self.loss_D_real = self.criterionGAN(pred_real, True)

        ## Combined
        self.loss_D = (self.loss_D_syn + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def optimize(self):
        self.total_steps += 1
        self.train_mode()
        self.forward()
        # if DA, update on real data
        if self.cfg['USE_DA']:
            self.set_requires_grad(self.netD, True)
            self.set_requires_grad([self.netB, self.netH], False)
            self.optimizer_D.zero_grad()
            self.backward_D()
            self.optimizer_D.step()

        # update on synthetic data
        if self.cfg['USE_DA']:
            self.set_requires_grad([self.netB, self.netH], True)
            self.set_requires_grad(self.netD, False)
        self.optimizer_B.zero_grad()
        self.optimizer_H.zero_grad()
        self.backward_BH()
        self.optimizer_B.step()
        self.optimizer_H.step()

    def train_mode(self):
        self.netB.train()
        self.netH.train()
        if self.cfg['USE_DA']:
            self.netD.train()

    # make models eval mode during test time
    def eval_mode(self):
        self.netB.eval()
        self.netH.eval()
        if self.cfg['USE_DA']:
            self.netD.eval()

    def angle_error_ratio(self, angle_degree, base_angle_degree):
        logic_map = torch.gt(
            base_angle_degree *
            torch.ones_like(angle_degree, device=self.device), angle_degree)
        if len(logic_map.size()) == 1:
            num_pixels = torch.sum(logic_map).float()
        else:
            num_pixels = torch.sum(logic_map, dim=1).float()
        ratio = torch.div(
            num_pixels,
            torch.tensor(angle_degree.nelement(),
                         device=self.device,
                         dtype=torch.float64))
        return ratio, logic_map

    def normal_angle(self):
        # surface normal
        ch = self.head_pred['norm'].size(1)
        _pred = self.head_pred['norm'].permute(0, 2, 3,
                                               1).contiguous().view(-1, ch)
        _gt = self.input_syn_norm.permute(0, 2, 3, 1).contiguous().view(-1, ch)
        _gt = (_gt / 127.5) - 1
        _pred = torch.nn.functional.normalize(_pred, dim=1)
        _gt = torch.nn.functional.normalize(_gt, dim=1)
        cos_label = torch.ones(_gt.size(0)).to(self.device)
        norm_diff = self.criterionNorm_eval(_pred, _gt, cos_label)

        cos_val = 1 - norm_diff
        cos_val = torch.max(cos_val,
                            -torch.ones_like(cos_val, device=self.device))
        cos_val = torch.min(cos_val,
                            torch.ones_like(cos_val, device=self.device))

        angle_rad = torch.acos(cos_val)
        angle_degree = angle_rad / 3.14 * 180
        return angle_degree

    def test(self):
        self.eval_mode()
        with torch.no_grad():
            self.forward()

            angle_degree = self.normal_angle()
            # ratio metrics
            ratio_11, _ = self.angle_error_ratio(angle_degree, 11.25)
            ratio_22, _ = self.angle_error_ratio(angle_degree, 22.5)
            ratio_30, _ = self.angle_error_ratio(angle_degree, 30.0)
            ratio_45, _ = self.angle_error_ratio(angle_degree, 45.0)
            # image-wise metrics
            batch_size = self.head_pred['norm'].size(0)
            # TODO double check if it's image-wise
            batch_angles = angle_degree.view(batch_size, -1)
            image_mean = torch.mean(batch_angles, dim=1)
            image_score, _ = self.angle_error_ratio(batch_angles, 45.0)

            return {
                'batch_size': batch_size,
                'pixel_error': angle_degree.cpu().detach().numpy(),
                'image_mean': image_mean.cpu().detach().numpy(),
                'image_score': image_score.cpu().detach().numpy(),
                'ratio_11': ratio_11.cpu().detach().numpy(),
                'ratio_22': ratio_22.cpu().detach().numpy(),
                'ratio_30': ratio_30.cpu().detach().numpy(),
                'ratio_45': ratio_45.cpu().detach().numpy()
            }

    def out_logic_map(self, epoch_num, img_num):
        self.eval_mode()
        with torch.no_grad():
            self.forward()

        # surface normal
        batch_size = self.head_pred['norm'].size(0)
        ch = self.head_pred['norm'].size(1)
        _pred = self.head_pred['norm'].permute(0, 2, 3,
                                               1).contiguous().view(-1, ch)
        _gt = self.input_syn_norm.permute(0, 2, 3, 1).contiguous().view(-1, ch)
        _gt = (_gt / 127.5) - 1
        _pred = torch.nn.functional.normalize(_pred, dim=1)
        _gt = torch.nn.functional.normalize(_gt, dim=1)
        cos_label = torch.ones(_gt.size(0)).to(self.device)
        norm_diff = self.criterionNorm_eval(_pred, _gt, cos_label)

        cos_val = 1 - norm_diff
        cos_val = torch.max(cos_val,
                            -torch.ones_like(cos_val, device=self.device))
        cos_val = torch.min(cos_val,
                            torch.ones_like(cos_val, device=self.device))

        angle_rad = torch.acos(cos_val)
        angle_degree = angle_rad / 3.14 * 180
        # ratio metrics
        ratio_11, c = self.angle_error_ratio(angle_degree, 11.25)

        good_pixel_img = torch.cat(
            (c.view(-1, 1), c.view(-1, 1), c.view(-1, 1)),
            1).view(self.head_pred['norm'].size(0),
                    self.head_pred['norm'].size(2),
                    self.head_pred['norm'].size(3), 3).permute(0, 3, 1, 2)

        self.head_pred['norm'] = _pred.view(self.head_pred['norm'].size(0),
                                            self.head_pred['norm'].size(2),
                                            self.head_pred['norm'].size(3),
                                            3).permute(0, 3, 1, 2)
        self.head_pred['norm'] = (self.head_pred['norm'] + 1) * 127.5
        vis_norm = torch.cat((self.input_syn_norm, self.head_pred['norm'],
                              good_pixel_img.float() * 255),
                             dim=0)
        map_path = '%s/ep%d' % (self.cfg['VIS_PATH'], epoch_num)
        if not os.path.isdir(map_path):
            os.makedirs(map_path)
        torchvision.utils.save_image(vis_norm.detach(),
                                     '%s/%d_norm.jpg' % (map_path, img_num),
                                     nrow=1,
                                     normalize=True)

    # update learning rate (called once every epoch)
    def update_learning_rate(self):
        for scheduler in self.schedulers:
            scheduler.step()
        self.lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % self.lr)

    #  return visualization images. train.py will save the images.
    def visualize_pred(self, ep=0):
        vis_dir = os.path.join(self.save_dir, 'vis')
        if not os.path.isdir(vis_dir):
            os.makedirs(vis_dir)
        if self.total_steps % self.cfg['VIS_FREQ'] == 0:
            num_pic = min(10, self.task_pred['norm'].size(0))
            torchvision.utils.save_image(self.input_syn_color[0:num_pic].cpu(),
                                         '%s/ep_%d_iter_%d_color.jpg' %
                                         (vis_dir, ep, self.total_steps),
                                         nrow=num_pic,
                                         normalize=True)
            vis_norm = torch.cat((self.input_syn_norm[0:num_pic],
                                  self.task_pred['norm'][0:num_pic]),
                                 dim=0)
            torchvision.utils.save_image(vis_norm.detach(),
                                         '%s/ep_%d_iter_%d_norm.jpg' %
                                         (vis_dir, ep, self.total_steps),
                                         nrow=num_pic,
                                         normalize=True)
            vis_depth = torch.cat((self.input_syn_dep[0:num_pic],
                                   self.task_pred['depth'][0:num_pic]),
                                  dim=0)
            torchvision.utils.save_image(vis_depth.detach(),
                                         '%s/ep_%d_iter_%d_depth.jpg' %
                                         (vis_dir, ep, self.total_steps),
                                         nrow=num_pic,
                                         normalize=True)
            # TODO Jason: visualization
            # edge_vis = torch.nn.functional.sigmoid(self.task_pred['edge'])
            # vis_edge = torch.cat((self.input_syn_edge[0:num_pic], edge_vis[0:num_pic]), dim=0)
            # torchvision.utils.save_image(vis_edge.detach(),
            #                             '%s/ep_%d_iter_%d_edge.jpg' % (vis_dir,ep,self.total_steps),
            #                             nrow=num_pic, normalize=False)
            if self.cfg['USE_DA']:
                torchvision.utils.save_image(
                    self.input_real_color[0:num_pic].cpu(),
                    '%s/ep_%d_iter_%d_real.jpg' %
                    (vis_dir, ep, self.total_steps),
                    nrow=num_pic,
                    normalize=True)
            print('==> Saved epoch %d total step %d visualization to %s' %
                  (ep, self.total_steps, vis_dir))

    # print on screen, log into tensorboard
    def print_n_log_losses(self, ep=0):
        if self.total_steps % self.cfg['PRINT_FREQ'] == 0:
            print('\nEpoch: %d  Total_step: %d  LR: %f' %
                  (ep, self.total_steps, self.lr))
            # print('Train on tasks: Loss_dep: %.4f   | Loss_edge: %.4f   | Loss_norm: %.4f'
            #       % (self.loss_dep, self.loss_edge, self.loss_norm))
            print('Train on tasks: Loss_dep: %.4f | Loss_norm: %.4f' %
                  (self.loss_dep, self.loss_norm))
            info = {
                'loss_dep': self.loss_dep,
                'loss_norm': self.loss_norm  #,
                # 'loss_edge': self.loss_edge
            }
            if self.cfg['USE_DA']:
                print(
                    'Train for DA:   Loss_D_syn: %.4f | Loss_D_real: %.4f | Loss_DA: %.4f'
                    % (self.loss_D_syn, self.loss_D_real, self.loss_DA))
                info['loss_D_syn'] = self.loss_D_syn
                info['loss_D_real'] = self.loss_D_real
                info['loss_DA'] = self.loss_DA

            # for tag, value in info.items():
            #     self.logger.scalar_summary(tag, value, self.total_steps)

    # save models to the disk
    def save_networks(self, which_epoch):
        for name in self.model_names:
            save_filename = '%s_ep%s.pth' % (name, which_epoch)
            save_path = os.path.join(self.save_dir, save_filename)
            net = getattr(self, name)
            if isinstance(net, torch.nn.DataParallel):
                torch.save(net.module.cpu().state_dict(), save_path)
            else:
                torch.save(net.cpu().state_dict(), save_path)
            print('==> Saved networks to %s' % save_path)
            if torch.cuda.is_available():
                net.cuda(self.device)

    # load models from the disk
    def load_networks(self, which_epoch):
        self.save_dir = self.load_dir
        print('loading networks...')
        if which_epoch == 'None':
            print('epoch is None')
            exit()
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_ep%s.pth' % (name, which_epoch)
                load_path = os.path.join(self.load_dir, load_filename)
                net = getattr(self, name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                    print('loading the model from %s' % load_path)
                    # if you are using PyTorch newer than 0.4 (e.g., built from
                    # GitHub source), you can remove str() on self.device
                    state_dict = torch.load(load_path,
                                            map_location=str(self.device))
                    net.load_state_dict(state_dict)

    # set requires_grad=False to avoid computation
    def set_requires_grad(self, nets, requires_grad=False):
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
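A sketch of an epoch loop for the Model class above, assuming a cfg dictionary with the keys the class reads plus a hypothetical 'EPOCHS' entry, and a train_loader that yields the nested 'syn'/'real' batches expected by set_input():

# Hypothetical driver; cfg, train_loader and the 'EPOCHS' key are placeholders.
model = Model()
model.initialize(cfg)
for epoch in range(cfg['EPOCHS']):
    for batch in train_loader:
        model.set_input(batch)           # moves syn (and real, if USE_DA) tensors to the device
        model.optimize()                 # D step when USE_DA, then backbone/head step
        model.print_n_log_losses(epoch)
        model.visualize_pred(epoch)
    model.update_learning_rate()         # steps every scheduler once per epoch
    model.save_networks(epoch)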
Example #22
0
class CycleGANModel():
    def __init__(self, opt):
        self.opt = opt
        self.dynamic = opt.dynamic
        self.isTrain = opt.istrain
        self.Tensor = torch.cuda.FloatTensor

        # load/define networks
        # The naming convention is different from the one used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_B = img2state().cuda()
        self.netF_A = Fmodel().cuda()
        self.dataF = CDFdata.get_loader(opt)
        self.train_forward(pretrained=True)

        self.gt_buffer = []
        self.pred_buffer = []

        # if self.isTrain:
        self.netD_B = stateDmodel().cuda()

        # if self.isTrain:
        self.fake_A_pool = ImagePool(pool_size=128)
        self.fake_B_pool = ImagePool(pool_size=128)
        # define loss functions
        self.criterionGAN = GANLoss(tensor=self.Tensor).cuda()
        if opt.loss == 'l1':
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
        elif opt.loss == 'l2':
            self.criterionCycle = torch.nn.MSELoss()
            self.criterionIdt = torch.nn.MSELoss()
        # initialize optimizers
        # self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()))
        self.optimizer_G = torch.optim.Adam([{
            'params': self.netF_A.parameters(),
            'lr': 0.0
        }, {
            'params': self.netG_B.parameters(),
            'lr': 1e-3
        }])
        self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters())

        print('---------- Networks initialized ---------------')
        print('-----------------------------------------------')

    def train_forward(self, pretrained=False):
        if pretrained:
            self.netF_A.load_state_dict(torch.load('./pred_large.pth'))
            return None
        optimizer = torch.optim.Adam(self.netF_A.parameters(), lr=1e-3)
        loss_fn = torch.nn.L1Loss()
        for epoch in range(10):
            epoch_loss = 0
            for i, item in enumerate(tqdm(self.dataF)):
                state, action, result = item[1]
                state = state.float().cuda()
                action = action.float().cuda()
                result = result.float().cuda()
                out = self.netF_A(state, action)
                loss = loss_fn(out, result)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
            print('epoch:{} loss:{:.7f}'.format(epoch,
                                                epoch_loss / len(self.dataF)))
            torch.save(self.netF_A.state_dict(), './pred_large.pth')
        print('forward model has been trained!')

    def set_input(self, input):
        # AtoB = self.opt.which_direction == 'AtoB'
        # input_A = input['A' if AtoB else 'B']
        # input_B = input['B' if AtoB else 'A']
        # A is state
        self.input_A = input[1][0]

        # B is img
        self.input_Bt0 = input[0][0]
        self.input_Bt1 = input[0][2]
        self.action = input[0][1]
        self.gt0 = input[2][0].float().cuda()
        self.gt1 = input[2][1].float().cuda()

    def forward(self):
        self.real_A = Variable(self.input_A).float().cuda()
        self.real_Bt0 = Variable(self.input_Bt0).float().cuda()
        self.real_Bt1 = Variable(self.input_Bt1).float().cuda()
        self.action = Variable(self.action).float().cuda()

    def test(self):
        # forward
        self.forward()
        # G_A and G_B
        self.backward_G()
        self.backward_D_B()

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        if self.isTrain:
            loss_D.backward()
        return loss_D

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_At0)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.item()

    def backward_G(self):
        lambda_G_B0 = 1.0
        lambda_G_B1 = 1.0
        lambda_F = 500.0

        # GAN loss D_B(G_B(B))
        fake_At0 = self.netG_B(self.real_Bt0)
        pred_fake = self.netD_B(fake_At0)
        loss_G_Bt0 = self.criterionGAN(pred_fake, True) * lambda_G_B0

        # GAN loss D_B(G_B(B))
        fake_At1 = self.netF_A(fake_At0, self.action)
        pred_fake = self.netD_B(fake_At1)
        loss_G_Bt1 = self.criterionGAN(pred_fake, True) * lambda_G_B1

        # cycle loss
        pred_At1 = self.netG_B(self.real_Bt1)
        cycle_label = torch.zeros_like(fake_At1).float().cuda()
        loss_cycle = self.criterionCycle(fake_At1 - pred_At1,
                                         cycle_label) * lambda_F

        # combined loss
        loss_G = loss_G_Bt0 + loss_G_Bt1 + loss_cycle
        if self.isTrain:
            loss_G.backward()

        self.fake_At0 = fake_At0.data
        self.fake_At1 = fake_At1.data

        self.loss_G_Bt0 = loss_G_Bt0.item()
        self.loss_G_Bt1 = loss_G_Bt1.item()
        self.loss_cycle = loss_cycle.item()

        self.loss_state_lt0 = self.criterionCycle(self.fake_At0,
                                                  self.gt0).item()
        self.loss_state_lt1 = self.criterionCycle(self.fake_At1,
                                                  self.gt1).item()
        self.gt_buffer.append(self.gt0.cpu().data.numpy())
        self.gt_buffer.append(self.gt1.cpu().data.numpy())
        self.pred_buffer.append(self.fake_At0.cpu().data.numpy())
        self.pred_buffer.append(self.fake_At1.cpu().data.numpy())

    def optimize_parameters(self):
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        ret_errors = OrderedDict([('L_t0', self.loss_state_lt0),
                                  ('L_t1', self.loss_state_lt1),
                                  ('D_B', self.loss_D_B),
                                  ('G_B0', self.loss_G_Bt0),
                                  ('G_B1', self.loss_G_Bt1),
                                  ('Cyc', self.loss_cycle)])
        # if self.opt.identity > 0.0:
        #     ret_errors['idt_A'] = self.loss_idt_A
        #     ret_errors['idt_B'] = self.loss_idt_B
        return ret_errors

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, path):
        save_filename = 'model_{}.pth'.format(network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    def save(self, path):
        self.save_network(self.netG_B, 'G_B2', path)
        self.save_network(self.netD_B, 'D_B2', path)

    def load_network(self, network, network_label, path):
        weight_filename = 'model_{}.pth'.format(network_label)
        weight_path = os.path.join(path, weight_filename)
        network.load_state_dict(torch.load(weight_path))

    def load(self, path):
        self.load_network(self.netG_B, 'G_B', path)

    def show_points(self):
        # num_images = min(imgs.shape[0],num_images)
        ncols = 1
        nrows = 3
        _, axes = plt.subplots(ncols, nrows, figsize=(nrows * 3, ncols * 3))
        axes = axes.flatten()
        gt_data = np.vstack(self.gt_buffer)
        pred_data = np.vstack(self.pred_buffer)
        print(abs(gt_data - pred_data).mean(0))

        for ax_i, ax in enumerate(axes):
            if ax_i < nrows:
                ax.scatter(gt_data[:, ax_i],
                           pred_data[:, ax_i],
                           s=3,
                           label='xyz_{}'.format(ax_i))
            else:
                ax.scatter(self.npdata(self.fake_At1[:, ax_i - nrows]),
                           self.npdata(self.gt1[:, ax_i - nrows]),
                           label='t1_{}'.format(ax_i - nrows))

    def npdata(self, item):
        return item.cpu().data.numpy()

    def visual(self, path):
        # plt.xlim(-4,4)
        # plt.ylim(-1.5,1.5)
        self.show_points()
        plt.legend()
        plt.savefig(path)
        plt.cla()
        plt.clf()
        self.gt_buffer = []
        self.pred_buffer = []
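The image-to-state CycleGANModel above could be exercised with a loop like the following sketch; opt, num_epochs and checkpoint_dir are placeholders, and the batch layout mirrors what set_input() indexes into.

# Hypothetical driver; the loader items follow the structure read by set_input().
model = CycleGANModel(opt)
loader = CDFdata.get_loader(opt)
for epoch in range(num_epochs):
    for item in loader:
        # item ~ ((img_t0, action, img_t1), (state, ...), (gt_state_t0, gt_state_t1))
        model.set_input(item)
        model.optimize_parameters()      # G_B update, then D_B update
    print(model.get_current_errors())
    model.visual('scatter_ep{}.png'.format(epoch))
    model.save(checkpoint_dir)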
Example #23
0
def train():
    if FLAGS.load_model is not None:
        checkpoints_dir = "checkpoints/" + FLAGS.load_model
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        try:
            os.makedirs(checkpoints_dir)
        except os.error:
            pass

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(
            X_train_file=FLAGS.X,
            Y_train_file=FLAGS.Y,
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            lambda2=FLAGS.lambda2,
            learning_rate=FLAGS.learning_rate,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf
        )
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()
    GPU_OPTIONS = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
    DEVICE_CONFIG_GPU = tf.ConfigProto(device_count={"CPU": 6},
                                       gpu_options=GPU_OPTIONS,
                                       intra_op_parallelism_threads=0,
                                       inter_op_parallelism_threads=0)
    with tf.Session(graph=graph, config=DEVICE_CONFIG_GPU) as sess:
        if FLAGS.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            step = int(meta_graph_path.split("-")[2].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # get previously generated images
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

                # train
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],
                        feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                                   cycle_gan.fake_x: fake_X_pool.query(fake_x_val)}
                    )
                )
                if step % 100 == 0:
                    train_writer.add_summary(summary, step)
                    train_writer.flush()
                    logging.info('-----------Step %d:-------------' % step)
                    logging.info('  G_loss   : {}'.format(G_loss_val))
                    logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info('  F_loss   : {}'.format(F_loss_val))
                    logging.info('  D_X_loss : {}'.format(D_X_loss_val))

                if step % 10000 == 0:
                    save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
            logging.info("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
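Every example in this collection queries an ImagePool, but its definition never appears. The sketch below is one plausible implementation of the generated-image history buffer from the CycleGAN paper, written to match how query() is called here; it is an assumption, not the original class.

import random

class ImagePool:
    # Assumed behaviour: keep up to pool_size previously generated batches and
    # sometimes return a stored one in place of the current batch.
    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, image):
        if self.pool_size == 0:                  # history disabled: pass through
            return image
        if len(self.images) < self.pool_size:
            self.images.append(image)            # still filling the buffer
            return image
        if random.random() > 0.5:                # 50%: swap in a stored batch
            idx = random.randrange(self.pool_size)
            old, self.images[idx] = self.images[idx], image
            return old
        return image                             # otherwise use the fresh batch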
Example #24
0
class CycleGANModel():
    def __init__(self,opt):
        self.opt = opt
        self.isTrain = opt.istrain
        self.Tensor = torch.cuda.FloatTensor

        # load/define networks
        # The naming convention is different from the one used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = GModel(opt).cuda()
        self.netG_B = GModel(opt).cuda()

        if self.isTrain:
            self.netD_A = DModel(opt).cuda()
            self.netD_B = DModel(opt).cuda()

        if self.isTrain:
            self.fake_A_pool = ImagePool(pool_size=128)
            self.fake_B_pool = ImagePool(pool_size=128)
            # define loss functions
            self.criterionGAN = GANLoss(tensor=self.Tensor).cuda()
            if opt.loss == 'l1':
                self.criterionCycle = torch.nn.L1Loss()
                self.criterionIdt = torch.nn.L1Loss()
            elif opt.loss == 'l2':
                self.criterionCycle = torch.nn.MSELoss()
                self.criterionIdt = torch.nn.MSELoss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters())
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters())
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                # self.schedulers.append(networks.get_scheduler(optimizer, opt))
                self.schedulers.append(optimizer)

        print('---------- Networks initialized -------------')
        # networks.print_network(self.netG_A)
        # networks.print_network(self.netG_B)
        # if self.isTrain:
        #     networks.print_network(self.netD_A)
        #     networks.print_network(self.netD_B)
        print('-----------------------------------------------')

    def set_input(self, input):
        # AtoB = self.opt.which_direction == 'AtoB'
        # input_A = input['A' if AtoB else 'B']
        # input_B = input['B' if AtoB else 'A']
        self.input_A = input[0]
        self.input_B = input[1]


    def forward(self):
        self.real_A = Variable(self.input_A).float().cuda()
        self.real_B = Variable(self.input_B).float().cuda()

    def test(self):
        self.forward()
        real_A = Variable(self.input_A, volatile=True).float().cuda()
        fake_B = self.netG_A(real_A)
        self.rec_A = self.netG_B(fake_B).data
        self.fake_B = fake_B.data

        real_B = Variable(self.input_B, volatile=True).float().cuda()
        fake_A = self.netG_B(real_B)
        self.rec_B = self.netG_A(fake_A).data
        self.fake_A = fake_A.data


    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
        self.loss_D_A = loss_D_A.item()

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.item()

    def backward_G(self):
        lambda_idt = 0.5
        lambda_A = self.opt.lambda_AB
        lambda_B = self.opt.lambda_AB
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_B)
            loss_idt_A = self.criterionIdt(idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt

            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.item()
            self.loss_idt_B = loss_idt_B.item()
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        lambda_G = 1.0

        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True) * lambda_G

        # GAN loss D_B(G_B(B))
        fake_A = self.netG_B(self.real_B)
        pred_fake = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake, True) * lambda_G

        # Forward cycle loss
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A

        # Backward cycle loss
        rec_B = self.netG_A(fake_A)
        loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B
        # combined loss
        loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
        loss_G.backward()

        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        self.rec_A = rec_A.data
        self.rec_B = rec_B.data

        self.loss_G_A = loss_G_A.item()
        self.loss_G_B = loss_G_B.item()
        self.loss_cycle_A = loss_cycle_A.item()
        self.loss_cycle_B = loss_cycle_B.item()

    def optimize_parameters(self):
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A), ('Cyc_A', self.loss_cycle_A),
                                 ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B',  self.loss_cycle_B)])
        # if self.opt.identity > 0.0:
        ret_errors['idt_A'] = self.loss_idt_A
        ret_errors['idt_B'] = self.loss_idt_B
        return ret_errors

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, path):
        save_filename = 'model_{}.pth'.format(network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    def save(self, path):
        self.save_network(self.netG_A, 'G_A', path)
        self.save_network(self.netD_A, 'D_A', path)
        self.save_network(self.netG_B, 'G_B', path)
        self.save_network(self.netD_B, 'D_B', path)

    def load_network(self, network, network_label, path):
        weight_filename = 'model_{}.pth'.format(network_label)
        weight_path = os.path.join(path, weight_filename)
        network.load_state_dict(torch.load(weight_path))

    def load(self,path):
        self.load_network(self.netG_A, 'G_A', path)
        self.load_network(self.netG_B, 'G_B', path)

    def visual(self,path):
        imgs = []
        for i in range(self.real_A.shape[0]):
            imgs_i = [self.real_A[i], self.fake_B[i]]
            imgs_i += [self.rec_A[i], self.real_B[i]]
            imgs_i = torch.cat(imgs_i, 2).cpu()
            imgs.append(imgs_i)
        imgs = torch.cat(imgs, 1)
        imgs = (imgs + 1) / 2
        imgs = transforms.ToPILImage()(imgs)
        imgs.save(path)
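A training driver for this unpaired CycleGANModel might look like the sketch below; loader_A, loader_B, num_epochs and checkpoint_dir are placeholders that do not appear in the snippet.

# Hypothetical unpaired training loop for the CycleGANModel above.
model = CycleGANModel(opt)
for epoch in range(num_epochs):
    for batch_A, batch_B in zip(loader_A, loader_B):   # unpaired domain batches
        model.set_input((batch_A, batch_B))
        model.optimize_parameters()      # G_A/G_B step, then D_A and D_B steps
    print(model.get_current_errors())
    model.visual('samples_ep{}.png'.format(epoch))
model.save(checkpoint_dir)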
Example #25
0
def train():
  if FLAGS.load_model is not None:
    checkpoints_dir = "checkpoints/" + FLAGS.load_model.lstrip("checkpoints/")
  else:
    #current_time = datetime.now().strftime("%Y%m%d-%H%M")
    checkpoints_dir = "checkpoints/{}".format(nameNet)
    try:
      os.makedirs(checkpoints_dir)
    except os.error:
      pass

  graph = tf.Graph()
  with graph.as_default():
    paired_gan = PairedGANDisenRevFullRep(
        XY_train_file=FLAGS.XY,
        batch_size=FLAGS.batch_size,
        image_size=FLAGS.image_size,
        use_lsgan=FLAGS.use_lsgan,
        norm=FLAGS.norm,
        lambdaRecon=FLAGS.lambdaRecon,
        lambdaAlign=FLAGS.lambdaAlign,
        lambdaRev=FLAGS.lambdaRev,
        learning_rate=FLAGS.learning_rate,
        beta1=FLAGS.beta1,
        nfS=FLAGS.nfS,
        nfE=FLAGS.nfE
    )
    G_loss, D_Y_loss, Dex_Y_loss, F_loss, D_X_loss, Dex_X_loss, A_loss, DC_loss, fake_y, fake_x, fake_ex_y, fake_ex_x = paired_gan.model()
    #G_loss, D_Y_loss, F_loss, D_X_loss, A_loss, fake_y, fake_x = paired_gan.model()
    optimizers = paired_gan.optimize(G_loss, D_Y_loss, Dex_Y_loss, F_loss, D_X_loss, Dex_X_loss, A_loss, DC_loss)

    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
    saver = tf.train.Saver()

  with tf.Session(graph=graph) as sess:
    if FLAGS.load_model is not None:
      checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
      meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
      restore = tf.train.import_meta_graph(meta_graph_path)
      restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
      step = int(meta_graph_path.split("-")[2].split(".")[0])
    else:
      sess.run(tf.global_variables_initializer())
      step = 0

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
      fake_Y_pool = ImagePool(FLAGS.pool_size)
      fake_X_pool = ImagePool(FLAGS.pool_size)

      fake_ex_Y_pool = ImagePool(FLAGS.pool_size)
      fake_ex_X_pool = ImagePool(FLAGS.pool_size)


      while not coord.should_stop():
        # get previously generated images
        fake_y_val, fake_x_val, fake_ex_y_val, fake_ex_x_val = sess.run([fake_y, fake_x, fake_ex_y, fake_ex_x])

        # train
        _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, A_loss_val, DC_loss_val, summary = (
              sess.run(
                  [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, A_loss, DC_loss, summary_op],
                  feed_dict={paired_gan.fake_y: fake_Y_pool.query(fake_y_val),
                             paired_gan.fake_x: fake_X_pool.query(fake_x_val),
                             paired_gan.fake_ex_y: fake_ex_Y_pool.query(fake_ex_y_val),
                             paired_gan.fake_ex_x: fake_ex_X_pool.query(fake_ex_x_val)}
              )
        )
        #_, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, A_loss_val, summary = (
              #sess.run(
                  #[optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, A_loss, summary_op],
                  #feed_dict={paired_gan.fake_y: fake_Y_pool.query(fake_y_val),
                             #paired_gan.fake_x: fake_X_pool.query(fake_x_val)}
              #)
        #)
        train_writer.add_summary(summary, step)
        train_writer.flush()

        if step % 100 == 0:
          logging.info('-----------Step %d:-------------' % step)
          logging.info('  G_loss   : {}'.format(G_loss_val))
          logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
          logging.info('  F_loss   : {}'.format(F_loss_val))
          logging.info('  D_X_loss : {}'.format(D_X_loss_val))
          logging.info('  A_loss : {}'.format(A_loss_val))

        if step % 10000 == 0:
          save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
          logging.info("Model saved in file: %s" % save_path)

        step += 1

    except KeyboardInterrupt:
      logging.info('Interrupted')
      coord.request_stop()
    except Exception as e:
      coord.request_stop(e)
    finally:
      save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
      logging.info("Model saved in file: %s" % save_path)
      # When done, ask the threads to stop.
      coord.request_stop()
      coord.join(threads)
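Several of these train() functions recover the global step with meta_graph_path.split("-")[2], which only works when the full checkpoint path contains exactly two hyphens. A slightly more defensive parse, offered only as a sketch, anchors on the trailing "-<step>.meta" suffix:

import re

def step_from_meta_path(meta_graph_path):
    # e.g. ".../model.ckpt-12345.meta" -> 12345; falls back to 0 if no step suffix.
    match = re.search(r"-(\d+)\.meta$", meta_graph_path)
    return int(match.group(1)) if match else 0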
Example #26
0
class Model():
    def initialize(self, cfg):
        self.cfg = cfg

        ## set devices
        if cfg['GPU_IDS']:
            assert(torch.cuda.is_available())
            self.device = torch.device('cuda:{}'.format(cfg['GPU_IDS'][0]))
            torch.backends.cudnn.benchmark = True
            print('Using %d GPUs'% len(cfg['GPU_IDS']))
        else:
            self.device = torch.device('cpu')

        # define network
        if cfg['ARCHI'] == 'alexnet':
            self.netB = networks.netB_alexnet()
            self.netH = networks.netH_alexnet()
            if self.cfg['USE_DA'] and self.cfg['TRAIN']:
                self.netD = networks.netD_alexnet(self.cfg['DA_LAYER'])
        elif cfg['ARCHI'] == 'vgg16':
            raise NotImplementedError
            self.netB = networks.netB_vgg16()
            self.netH = networks.netH_vgg16()
            if self.cfg['USE_DA'] and self.cfg['TRAIN']:
                self.netD = netD_vgg16(self.cfg['DA_LAYER'])
        elif 'resnet' in cfg['ARCHI']:
            raise NotImplementedError
            self.netB = networks.netB_resnet()
            self.netH = networks.netH_resnet()
            if self.cfg['USE_DA'] and self.cfg['TRAIN']:
                self.netD = networks.netD_resnet(self.cfg['DA_LAYER'])
        else:
            raise ValueError('Un-supported network')

        ## initialize network param.
        self.netB = networks.init_net(self.netB, cfg['GPU_IDS'], 'xavier')
        self.netH = networks.init_net(self.netH, cfg['GPU_IDS'], 'xavier')
        if self.cfg['USE_DA'] and self.cfg['TRAIN']:
            self.netD = networks.init_net(self.netD, cfg['GPU_IDS'], 'xavier')
        print(self.netB, self.netH, self.netD)

        # loss, optimizer, and scheduler
        if cfg['TRAIN']:
            self.total_steps = 0
            ## Output path
            self.save_dir = os.path.join(cfg['OUTPUT_PATH'], cfg['ARCHI'],
                    datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
            if not os.path.isdir(self.save_dir):
                os.makedirs(self.save_dir)
            self.logger = Logger(self.save_dir)

            ## model names
            self.model_names = ['netB', 'netH']
            ## loss
            self.criterionGAN = networks.GANLoss().to(self.device)
            self.criterionDepth1 = torch.nn.MSELoss().to(self.device)
            self.criterionNorm = torch.nn.CosineEmbeddingLoss().to(self.device)
            # define during running, rely on data weight
            self.criterionEdge = None

            ## optimizers
            self.lr = cfg['LR']
            self.optimizers = []
            self.optimizer_B = torch.optim.Adam(self.netB.parameters(),
                                    lr=cfg['LR'], betas=(cfg['BETA1'], cfg['BETA2']))
            self.optimizer_H = torch.optim.Adam(self.netH.parameters(),
                                    lr=cfg['LR'], betas=(cfg['BETA1'], cfg['BETA2']))
            self.optimizers.append(self.optimizer_B)
            self.optimizers.append(self.optimizer_H)
            if cfg['USE_DA']:
                self.real_pool = ImagePool(cfg['POOL_SIZE'])
                self.syn_pool = ImagePool(cfg['POOL_SIZE'])
                self.model_names.append('netD')
                ## use SGD for discriminator
                self.optimizer_D = torch.optim.SGD(self.netD.parameters(),
                                    lr=cfg['LR'], momentum=cfg['MOMENTUM'], weight_decay=cfg['WEIGHT_DECAY'])
                self.optimizers.append(self.optimizer_D)
            ## LR scheduler
            self.schedulers = [networks.get_scheduler(optimizer, cfg) for optimizer in self.optimizers]

        if cfg['TEST'] or cfg['RESUME']:
            self.load_networks(cfg['CKPT_PATH'])

    def set_input(self, inputs):
        if self.cfg['GRAY']:
            _ch = np.random.randint(3)
            _syn = inputs['syn']['color'][:, _ch, :, :]
            self.input_syn_color = torch.stack((_syn, _syn, _syn), dim=1).to(self.device)
        else:
            self.input_syn_color = inputs['syn']['color'].to(self.device)
        self.input_syn_dep = inputs['syn']['depth'].to(self.device)
        self.input_syn_edge = inputs['syn']['edge'].to(self.device)
        self.input_syn_edge_count = inputs['syn']['edge_pix'].to(self.device)
        self.input_syn_norm = inputs['syn']['normal'].to(self.device)
        if self.cfg['USE_DA']:
            if self.cfg['GRAY']:
                _ch = np.random.randint(3)
                _real = inputs['real'][0][:, _ch, :, :]
                self.input_real_color = torch.stack((_real, _real, _real), dim=1).to(self.device)
            else:
                self.input_real_color = inputs['real'][0].to(self.device)

    def forward(self):
        self.feat_syn = self.netB(self.input_syn_color)
        self.head_pred = self.netH(self.feat_syn['out'])
        if self.cfg['USE_DA'] and self.cfg['TRAIN']:
            self.feat_real = self.netB(self.input_real_color)
            self.pred_D_real = self.netD(self.feat_real[self.cfg['DA_LAYER']])
            self.pred_D_syn  = self.netD(self.feat_syn[self.cfg['DA_LAYER']])

    def backward_BH(self):
        ## forward to compute prediction
        self.task_pred = self.netH(self.feat_syn['out'])

        # depth
        depth_diff = self.task_pred['depth'] - self.input_syn_dep
        _n = self.task_pred['depth'].size(0) * self.task_pred['depth'].size(2) * self.task_pred['depth'].size(3)
        loss_depth2 = depth_diff.sum().div_(_n).pow(2).mul_(0.5)
        loss_depth1 = self.criterionDepth1(self.task_pred['depth'], self.input_syn_dep)
        self.loss_dep  = self.cfg['DEP_WEIGHT'] * (loss_depth1 + loss_depth2) * 0.5

        # surface normal
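        # ground-truth normals are stored in [0, 255]; map them to [-1, 1], L2-normalize
        # the prediction, and use a cosine embedding loss with target +1 so predicted and
        # ground-truth normal vectors align; the prediction is mapped back to [0, 255]
        # afterwards for visualization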
        ch = self.task_pred['norm'].size(1)
        _pred = self.task_pred['norm'].permute(0, 2, 3, 1).contiguous().view(-1,ch)
        _gt = self.input_syn_norm.permute(0, 2, 3, 1).contiguous().view(-1,ch)
        _gt = (_gt / 127.5) - 1
        _pred = torch.nn.functional.normalize(_pred, dim=1)
        self.task_pred['norm'] = _pred.view(self.task_pred['norm'].size(0), self.task_pred['norm'].size(2), self.task_pred['norm'].size(3),3).permute(0, 3, 1, 2)
        self.task_pred['norm'] = (self.task_pred['norm'] + 1) * 127.5
        cos_label = torch.ones(_gt.size(0)).to(self.device)
        self.loss_norm = self.cfg['NORM_WEIGHT'] * self.criterionNorm(_pred, _gt, cos_label)

        # edge
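        # class-balancing weight: each sample's BCE loss is rescaled by its
        # non-edge / edge pixel ratio to compensate for the sparsity of edge pixels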
        weight_e = (self.task_pred['edge'].size(2) * self.task_pred['edge'].size(3) - self.input_syn_edge_count ) / self.input_syn_edge_count
        self.criterionEdge = torch.nn.BCEWithLogitsLoss(weight=weight_e.float().view(-1,1,1,1)).to(self.device)
        self.loss_edge = self.cfg['EDGE_WEIGHT'] * self.criterionEdge(self.task_pred['edge'], self.input_syn_edge)

        ## combined loss
        loss = self.loss_edge + self.loss_norm + self.loss_dep

        if self.cfg['USE_DA']:
            # adversarial term for the backbone: fool the (frozen) discriminator into
            # labeling synthetic features as real; the features are not detached here,
            # otherwise no gradient would reach netB
            pred_syn = self.netD(self.feat_syn[self.cfg['DA_LAYER']])
            self.loss_DA = self.criterionGAN(pred_syn, True)
            loss += self.loss_DA * self.cfg['DA_WEIGHT']

        loss.backward()

    def backward_D(self):
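        ## the discriminator is trained on backbone feature maps drawn from history
        ## pools: synthetic features are labeled fake, real-domain features real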
        ## Synthetic
        # stop backprop to netB by detaching
        _feat_s = self.syn_pool.query(self.feat_syn[self.cfg['DA_LAYER']].detach().cpu())
        pred_syn = self.netD(_feat_s.to(self.device))
        self.loss_D_syn = self.criterionGAN(pred_syn, False)

        ## Real
        _feat_r = self.real_pool.query(self.feat_real[self.cfg['DA_LAYER']].detach().cpu())
        pred_real = self.netD(_feat_r.to(self.device))
        self.loss_D_real = self.criterionGAN(pred_real, True)

        ## Combined
        self.loss_D = (self.loss_D_syn + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def optimize(self):
        self.total_steps += 1
        self.forward()
        # if DA is enabled, update the discriminator first (on pooled synthetic/real features)
        if self.cfg['USE_DA']:
            self.set_requires_grad(self.netD, True)
            self.set_requires_grad([self.netB, self.netH], False)
            self.optimizer_D.zero_grad()
            self.backward_D()
            self.optimizer_D.step()

        # update on synthetic data
        self.set_requires_grad([self.netB, self.netH], True)
        self.set_requires_grad(self.netD, False)
        self.optimizer_B.zero_grad()
        self.optimizer_H.zero_grad()
        self.backward_BH()
        self.optimizer_B.step()
        self.optimizer_H.step()

    # make models eval mode during test time
    def eval(self):
        self.netB.eval()
        self.netH.eval()
        self.netD.eval()

    # used in test time, wrapping `forward` in no_grad() so we don't save
    # intermediate steps for backprop
    def test(self):
        with torch.no_grad():
            self.forward()

    # update learning rate (called once every epoch)
    def update_learning_rate(self):
        for scheduler in self.schedulers:
            scheduler.step()
        self.lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % self.lr)

    # save visualization images under save_dir/vis every VIS_FREQ steps; called from train.py
    def visualize_pred(self, ep=0):
        vis_dir = os.path.join(self.save_dir, 'vis')
        if not os.path.isdir(vis_dir):
            os.makedirs(vis_dir)
        if self.total_steps % self.cfg['VIS_FREQ'] == 0:
            num_pic = min(8, self.task_pred['norm'].size(0))
            torchvision.utils.save_image(self.input_syn_color[0:num_pic].cpu(),
                                        '%s/ep_%d_iter_%d_color.jpg' % (vis_dir,ep,self.total_steps),
                                        nrow=num_pic, normalize=True)
            vis_norm = torch.cat((self.input_syn_norm[0:num_pic], self.task_pred['norm'][0:num_pic]), dim=0)
            torchvision.utils.save_image(vis_norm.detach(),
                                        '%s/ep_%d_iter_%d_norm.jpg' % (vis_dir,ep,self.total_steps),
                                        nrow=num_pic, normalize=True)
            vis_depth = torch.cat((self.input_syn_dep[0:num_pic], self.task_pred['depth'][0:num_pic]), dim=0)
            torchvision.utils.save_image(vis_depth.detach(),
                                        '%s/ep_%d_iter_%d_depth.jpg' % (vis_dir,ep,self.total_steps),
                                        nrow=num_pic, normalize=True)
            # TODO: visualization
            edge_vis = torch.sigmoid(self.task_pred['edge'])
            vis_edge = torch.cat((self.input_syn_edge[0:num_pic], edge_vis[0:num_pic]), dim=0)
            torchvision.utils.save_image(vis_edge.detach(),
                                        '%s/ep_%d_iter_%d_edge.jpg' % (vis_dir,ep,self.total_steps),
                                        nrow=num_pic, normalize=False)
            if self.cfg['USE_DA']:
                torchvision.utils.save_image(self.input_real_color[0:num_pic].cpu(),
                                            '%s/ep_%d_iter_%d_real.jpg' % (vis_dir,ep,self.total_steps),
                                            nrow=num_pic, normalize=True)
            print('==> Saved epoch %d total step %d visualization to %s' % (ep, self.total_steps, vis_dir))

    # print on screen, log into tensorboard
    def print_n_log_losses(self, ep=0):
        if self.total_steps % self.cfg['PRINT_FREQ'] == 0:
            print('\nEpoch: %d  Total_step: %d  LR: %f' % (ep, self.total_steps, self.lr))
            print('Train on tasks: Loss_dep: %.4f   | Loss_edge: %.4f   | Loss_norm: %.4f'
                  % (self.loss_dep, self.loss_edge, self.loss_norm))
            info = {
                'loss_dep': self.loss_dep,
                'loss_norm': self.loss_norm,
                'loss_edge': self.loss_edge
                }
            if self.cfg['USE_DA']:
                print('Train for DA:   Loss_D_syn: %.4f | Loss_D_real: %.4f | Loss_DA: %.4f'
                      % (self.loss_D_syn, self.loss_D_real, self.loss_DA))
                info['loss_D_syn'] = self.loss_D_syn
                info['loss_D_real'] = self.loss_D_real
                info['loss_DA'] = self.loss_DA

            for tag, value in info.items():
                self.logger.scalar_summary(tag, value, self.total_steps)

    # save models to the disk
    def save_networks(self, which_epoch):
        for name in self.model_names:
            save_filename = '%s_ep%s.pth' % (name, which_epoch)
            save_path = os.path.join(self.save_dir, save_filename)
            net = getattr(self, name)
            if isinstance(net, torch.nn.DataParallel):
                torch.save(net.module.cpu().state_dict(), save_path)
            else:
                torch.save(net.cpu().state_dict(), save_path)
            print('==> Saved to %s' % save_path)
            if torch.cuda.is_available():
                net.cuda(self.device)

    # load models from the disk
    def load_networks(self, which_epoch):
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_%s.pth' % (which_epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))
                net.load_state_dict(state_dict)

    # set requires_grad=False to avoid unnecessary computation
    def set_requires_grad(self, nets, requires_grad=False):
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
Example #27
0
def train():
  if FLAGS.load_model is not None:
    checkpoints_dir = "checkpoints/" + FLAGS.load_model
  else:
    current_time = datetime.now().strftime("%Y%m%d-%H%M")
    checkpoints_dir = "checkpoints/{}".format(current_time)
    try:
      os.makedirs(checkpoints_dir)
    except os.error:
      pass

  graph = tf.Graph()
  with graph.as_default():
    cycle_gan = CycleGAN(
        X_train_file=FLAGS.X,
        Y_train_file=FLAGS.Y,
        batch_size=FLAGS.batch_size,
        image_size=FLAGS.image_size,
        use_lsgan=FLAGS.use_lsgan,
        norm=FLAGS.norm,
        lambda1=FLAGS.lambda1,
        lambda2=FLAGS.lambda2,
        learning_rate=FLAGS.learning_rate,
        beta1=FLAGS.beta1,
        ngf=FLAGS.ngf
    )
    G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
    optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
    saver = tf.train.Saver()

  with tf.Session(graph=graph) as sess:
    if FLAGS.load_model is not None:
      checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
      meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
      restore = tf.train.import_meta_graph(meta_graph_path)
      restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
      step = int(meta_graph_path.split("-")[2].split(".")[0])
    else:
      sess.run(tf.global_variables_initializer())
      step = 0

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
      fake_Y_pool = ImagePool(FLAGS.pool_size)
      fake_X_pool = ImagePool(FLAGS.pool_size)

      while not coord.should_stop():
        # get previously generated images
        fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

        # train
        _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
              sess.run(
                  [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],
                  feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                             cycle_gan.fake_x: fake_X_pool.query(fake_x_val)}
              )
        )
        if step % 100 == 0:
          train_writer.add_summary(summary, step)
          train_writer.flush()

        if step % 100 == 0:
          logging.info('-----------Step %d:-------------' % step)
          logging.info('  G_loss   : {}'.format(G_loss_val))
          logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
          logging.info('  F_loss   : {}'.format(F_loss_val))
          logging.info('  D_X_loss : {}'.format(D_X_loss_val))

        if step % 10000 == 0:
          save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
          logging.info("Model saved in file: %s" % save_path)

        step += 1

    except KeyboardInterrupt:
      logging.info('Interrupted')
      coord.request_stop()
    except Exception as e:
      coord.request_stop(e)
    finally:
      save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
      logging.info("Model saved in file: %s" % save_path)
      # When done, ask the threads to stop.
      coord.request_stop()
      coord.join(threads)
Example #28
0
def train():
    current_time = datetime.now().strftime("%Y%m%d-%H%M")
    checkpoints_dir = "checkpoints/{}".format(current_time)
    os.makedirs(checkpoints_dir, exist_ok=True)

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(X_train_file=FLAGS.X_train_file,
                             Y_train_file=FLAGS.Y_train_file,
                             batch_size=FLAGS.batch_size,
                             image_size=FLAGS.image_size,
                             use_lsgan=FLAGS.use_lsgan,
                             norm=FLAGS.norm,
                             lambda1=FLAGS.lambda1,
                             lambda2=FLAGS.lambda2,
                             learning_rate=FLAGS.learning_rate,
                             beta1=FLAGS.beta1)
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            step = 0
            # create the image pools once, before the training loop
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # get previously generated images
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

                # train
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [
                            optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                            summary_op
                        ],
                        feed_dict={
                            cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val)
                        }))

                train_writer.add_summary(summary, step)
                train_writer.flush()

                if step % 100 == 0:
                    logging.info('-----------Step %d:-------------' % step)
                    logging.info('  G_loss   : {}'.format(G_loss_val))
                    logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info('  F_loss   : {}'.format(F_loss_val))
                    logging.info('  D_X_loss : {}'.format(D_X_loss_val))

                if step % 10000 == 0:
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess,
                                   checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            logging.info("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
Example #29
0
def train():
    # train the model
    if FLAGS.load_model is not None:
        checkpoints_dir = "checkpoints/" + FLAGS.load_model.lstrip(
            "checkpoints/")
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        try:
            os.makedirs(checkpoints_dir)
        except os.error:
            pass

    graph = tf.Graph()
    # add segmentation work-flow yw3025
    segementation = Segementation(None)
    with graph.as_default():
        cycle_gan = CycleGAN(
            X_train_file=FLAGS.X,
            Y_train_file=FLAGS.Y,
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            lambda2=FLAGS.lambda2,
            learning_rate=FLAGS.learning_rate,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf,
        )
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x, y, x = cycle_gan.model(
        )
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            step = int(meta_graph_path.split("-")[2].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # get previously generated images
                fake_y_val, fake_x_val, y_val, x_val = sess.run(
                    [fake_y, fake_x, y, x])
                # add segmentation work-flow yw3025
                mask_y = segementation.get_result_railway(y_val)
                mask_x = segementation.get_result_railway(x_val)
                mask_fake_x = segementation.get_result_railway(fake_x_val)
                mask_fake_y = segementation.get_result_railway(fake_y_val)
                # train yw3025
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [
                            optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                            summary_op
                        ],
                        feed_dict={
                            cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val),
                            cycle_gan.g_mask_x: mask_x,
                            cycle_gan.g_mask_y: mask_y,
                            cycle_gan.g_mask_fake_y: mask_fake_y,
                            cycle_gan.g_mask_fake_x: mask_fake_x,
                            cycle_gan.mask_x: mask_x,
                            cycle_gan.mask_y: mask_y
                        }))

                train_writer.add_summary(summary, step)
                train_writer.flush()

                if step % 100 == 0:
                    logging.info('-----------Step %d:-------------' % step)
                    logging.info('  G_loss   : {}'.format(G_loss_val))
                    logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info('  F_loss   : {}'.format(F_loss_val))
                    logging.info('  D_X_loss : {}'.format(D_X_loss_val))

                if step % 3000 == 0:
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess,
                                   checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            logging.info("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
Example #30
0
class CycleMcdModel(BaseModel):
    def __init__(self, opt):
        super(CycleMcdModel, self).__init__(opt)
        print('-------------- Networks initializing -------------')

        self.mode = None

        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.lossNames = [
            'loss{}'.format(i) for i in [
                'GenA', 'DisA', 'CycleA', 'IdtA', 'DisB', 'GenB', 'CycleB',
                'IdtB', 'Supervised', 'UnsupervisedClassifier',
                'UnsupervisedFeature'
            ]
        ]
        self.lossGenA, self.lossDisA, self.lossCycleA, self.lossIdtA = 0, 0, 0, 0
        self.lossGenB, self.lossDisB, self.lossCycleB, self.lossIdtB = 0, 0, 0, 0
        self.lossSupervised, self.lossUnsupervisedClassifier, self.lossUnsupervisedFeature = 0, 0, 0

        # define loss functions
        self.criterionGAN = networks.GANLoss(use_lsgan=opt.lsgan).to(
            opt.device)  # lsgan=True uses MSE loss; lsgan=False uses BCE loss
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()
        self.criterionSeg = CrossEntropyLoss2d(opt)  # 2D cross-entropy, computed per pixel
        self.criterionDis = Distance(opt)

        # specify the training miou you want to print out. The program will call base_model.get_current_mious
        self.miouNames = [
            'miou{}'.format(i) for i in
            ['SupervisedA', 'UnsupervisedA', 'SupervisedB', 'UnsupervisedB']
        ]
        self.miouSupervisedA = IouEval(opt.nClass)
        self.miouUnsupervisedA = IouEval(opt.nClass)
        self.miouSupervisedB = IouEval(opt.nClass)
        self.miouUnsupervisedB = IouEval(opt.nClass)

        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        # names without a pred/gnd prefix are plain images; the rest are segmentation maps
        imageNamesA = [
            'realA', 'fakeA', 'recA', 'idtA', 'supervisedA', 'predSupervisedA',
            'gndSupervisedA', 'unsupervisedA', 'predUnsupervisedA',
            'gndUnsupervisedA'
        ]
        imageNamesB = [
            'realB', 'fakeB', 'recB', 'idtB', 'supervisedB', 'predSupervisedB',
            'gndSupervisedB', 'unsupervisedB', 'predUnsupervisedB',
            'gndUnsupervisedB'
        ]
        self.imageNames = imageNamesA + imageNamesB
        self.realA, self.fakeA, self.recA, self.idtA = None, None, None, None
        self.supervisedA, self.predSupervisedA, self.gndSupervisedA = None, None, None
        self.unsupervisedA, self.predUnsupervisedA, self.gndUnsupervisedA = None, None, None
        self.realB, self.fakeB, self.recB, self.idtB = None, None, None, None
        self.supervisedB, self.predSupervisedB, self.gndSupervisedB = None, None, None
        self.unsupervisedB, self.predUnsupervisedB, self.gndUnsupervisedB = None, None, None

        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        # naming is by the input domain
        # Cycle gan model: 'GenA', 'DisA', 'GenB', 'DisB'
        # Mcd model : 'Features', 'Classifier1', 'Classifier2'
        self.modelNames = [
            'net{}'.format(i) for i in [
                'GenA', 'DisA', 'GenB', 'DisB', 'Features', 'Classifier1',
                'Classifier2'
            ]
        ]

        # load/define networks
        # The naming convention differs from the one used in the paper
        # Code (paper): G_RGB (G), G_D (F), D_RGB (D_Y), D_D (D_X)
        self.netGenA = networks.define_G(opt.inputCh, opt.inputCh, opt.ngf,
                                         opt.which_model_netG, opt.norm,
                                         opt.dropout, opt.init_type,
                                         opt.init_gain, opt.gpuIds)
        self.netDisA = networks.define_D(opt.inputCh, opt.inputCh,
                                         opt.which_model_netD, opt.n_layers_D,
                                         opt.norm, not opt.lsgan,
                                         opt.init_type, opt.init_gain,
                                         opt.gpuIds)
        self.netGenB = networks.define_G(opt.inputCh, opt.inputCh, opt.ngf,
                                         opt.which_model_netG, opt.norm,
                                         opt.dropout, opt.init_type,
                                         opt.init_gain, opt.gpuIds)
        self.netDisB = networks.define_D(opt.inputCh, opt.inputCh,
                                         opt.which_model_netD, opt.n_layers_D,
                                         opt.norm, not opt.lsgan,
                                         opt.init_type, opt.init_gain,
                                         opt.gpuIds)

        self.netFeatures = self.initNet(
            DRNSegBase(model_name=opt.segNet,
                       n_class=opt.nClass,
                       input_ch=opt.inputCh))
        self.netClassifier1 = self.initNet(
            DRNSegPixelClassifier(n_class=opt.nClass))
        self.netClassifier2 = self.initNet(
            DRNSegPixelClassifier(n_class=opt.nClass))

        self.set_requires_grad([
            self.netGenA, self.netGenB, self.netDisA, self.netDisB,
            self.netFeatures, self.netClassifier1, self.netClassifier2
        ], True)

        # define image pool
        self.fakeAPool = ImagePool(opt.pool_size)
        self.fakeBPool = ImagePool(opt.pool_size)

        # initialize optimizers
        self.optimizerG = getOptimizer(itertools.chain(
            self.netGenA.parameters(), self.netGenB.parameters()),
                                       opt=opt.cycleOpt,
                                       lr=opt.lr,
                                       beta1=opt.beta1,
                                       momentum=opt.momentum,
                                       weight_decay=opt.weight_decay)
        self.optimizerD = getOptimizer(itertools.chain(
            self.netDisA.parameters(), self.netDisB.parameters()),
                                       opt=opt.cycleOpt,
                                       lr=opt.lr,
                                       beta1=opt.beta1,
                                       momentum=opt.momentum,
                                       weight_decay=opt.weight_decay)
        self.optimizerF = getOptimizer(itertools.chain(
            self.netFeatures.parameters()),
                                       opt=opt.mcdOpt,
                                       lr=opt.lr,
                                       beta1=opt.beta1,
                                       momentum=opt.momentum,
                                       weight_decay=opt.weight_decay)
        self.optimizerC = getOptimizer(itertools.chain(
            self.netClassifier1.parameters(),
            self.netClassifier2.parameters()),
                                       opt=opt.mcdOpt,
                                       lr=opt.lr,
                                       beta1=opt.beta1,
                                       momentum=opt.momentum,
                                       weight_decay=opt.weight_decay)
        self.optimizers = []
        self.optimizers.append(self.optimizerG)
        self.optimizers.append(self.optimizerD)
        self.optimizers.append(self.optimizerF)
        self.optimizers.append(self.optimizerC)

        self.colorize = Colorize()
        print('--------------------------------------------------')

    def name(self):
        return 'CycleMcdModel'

    def current_images(self):
        imageNames = [
            'realA', 'fakeA', 'recA', 'idtA', 'realB', 'fakeB', 'recB', 'idtB',
            'supervisedA', 'supervisedB', 'unsupervisedA', 'unsupervisedB'
        ]
        segmentationMapNames = [
            'predSupervisedA', 'gndSupervisedA', 'predUnsupervisedA',
            'gndUnsupervisedA', 'predSupervisedB', 'gndSupervisedB',
            'predUnsupervisedB', 'gndUnsupervisedB'
        ]
        visual_ret = OrderedDict()
        for name in self.imageNames:
            if name in imageNames:
                visual_ret[name] = self.invTransform(getattr(self, name)[0])
            elif name in segmentationMapNames:
                visual_ret[name] = \
                    self.colorize(getattr(self,name)[0]).permute(2,0,1).float()/255
            else:
                raise NotImplementedError
        return visual_ret

    def set_input(self, input):
        self.supervisedA = input['supervisedA']['image'].to(self.opt.device)
        self.gndSupervisedA = input['supervisedA']['label'].to(self.opt.device)
        self.unsupervisedA = input['unsupervisedA']['image'].to(
            self.opt.device)
        self.gndUnsupervisedA = input['unsupervisedA']['label'].to(
            self.opt.device)
        self.supervisedB = input['supervisedB']['image'].to(self.opt.device)
        self.gndSupervisedB = input['supervisedB']['label'].to(self.opt.device)
        self.unsupervisedB = input['unsupervisedB']['image'].to(
            self.opt.device)
        self.gndUnsupervisedB = input['unsupervisedB']['label'].to(
            self.opt.device)

    def forward(self):
        '''
        self.predSupervisedA = self.forwardSegmentation(self.supervisedA)
        self.predUnsupervisedA = self.forwardSegmentation(self.unsupervisedA)
        self.predSupervisedB = self.forwardSegmentation(self.supervisedB)
        self.predUnsupervisedB = self.forwardSegmentation(self.unsupervisedB)
        '''

    def backward_dis_basic(self, netDis, real, fake):
        # Real
        predReal = netDis(real)
        lossDisReal = self.criterionGAN(predReal, True)
        # Fake
        predFake = netDis(fake.detach())
        lossDisFake = self.criterionGAN(predFake, False)
        # Combined loss
        lossDis = (lossDisReal + lossDisFake) * 0.5
        # backward
        lossDis.backward()
        return float(lossDis)

    def backward_dis_A(self):
        fakeA = self.fakeAPool.query(self.fakeA)
        self.lossDisA = self.backward_dis_basic(self.netDisA, self.realA,
                                                fakeA)
        self.fakeA = self.fakeA.to('cpu')

    def backward_dis_B(self):
        fakeB = self.fakeBPool.query(self.fakeB)
        self.lossDisB = self.backward_dis_basic(self.netDisB, self.realB,
                                                fakeB)
        self.fakeB = self.fakeB.to('cpu')

    def backward_gen(self, retain_graph=False):
        lambdaIdt = self.opt.lambdaIdentity
        lambdaA = self.opt.lambdaA
        lambdaB = self.opt.lambdaB
        # build combined real batches (supervised + unsupervised stacked along the batch dim)
        self.realA = torch.cat([self.supervisedA, self.unsupervisedA], 0)
        self.realB = torch.cat([self.supervisedB, self.unsupervisedB], 0)
        self.fakeA = self.netGenB(self.realB)
        self.fakeB = self.netGenA(self.realA)
        self.recA = self.netGenB(self.fakeB)
        self.recB = self.netGenA(self.fakeA)
        if lambdaIdt > 0:
            # GenB should be identity if realA is fed.
            self.idtA = self.netGenB(self.realA)
            lossIdtA = self.criterionIdt(self.idtA,
                                         self.realA) * lambdaA * lambdaIdt
            # GenA should be identity if realB is fed.
            self.idtB = self.netGenA(self.realB)
            lossIdtB = self.criterionIdt(self.idtB,
                                         self.realB) * lambdaB * lambdaIdt
        else:
            lossIdtA = 0
            lossIdtB = 0

        # GAN loss: DisB should classify the translated fakeB as real
        lossGenA = self.criterionGAN(self.netDisB(self.fakeB), True)
        # GAN loss: DisA should classify the translated fakeA as real
        lossGenB = self.criterionGAN(self.netDisA(self.fakeA), True)
        # Forward cycle loss
        lossCycleA = self.criterionCycle(self.recA, self.realA) * lambdaA
        # Backward cycle loss
        lossCycleB = self.criterionCycle(self.recB, self.realB) * lambdaB
        # combined loss
        lossG = lossGenA + lossGenB + lossCycleA + lossCycleB + lossIdtA + lossIdtB
        lossG.backward(retain_graph=retain_graph)
        # move images back to the cpu
        self.realA = self.realA.to('cpu')
        self.realB = self.realB.to('cpu')
        self.recA = self.recA.to('cpu')
        self.recB = self.recB.to('cpu')
        self.lossGenA = float(lossGenA)
        self.lossGenB = float(lossGenB)
        self.lossCycleA = float(lossCycleA)
        self.lossCycleB = float(lossCycleB)
        self.lossIdtA = float(lossIdtA)
        self.lossIdtB = float(lossIdtB)

    def optimize_parameters_cyclegan(self):
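        ## standard CycleGAN update: generators first (with both discriminators frozen),
        ## then the two discriminators on pooled fake images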
        # GenA and GenB
        self.set_requires_grad([self.netDisA, self.netDisB], False)
        self.optimizerG.zero_grad()
        self.backward_gen()
        self.optimizerG.step()
        # DisA and DisB
        self.set_requires_grad([self.netDisA, self.netDisB], True)
        self.optimizerD.zero_grad()
        self.backward_dis_A()
        self.backward_dis_B()
        self.optimizerD.step()

    def forward_mcd(self, data):
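        # shared feature extractor followed by the two pixel-level classifiers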
        feature = self.netFeatures(data)
        pred1 = self.netClassifier1(feature)
        pred2 = self.netClassifier2(feature)
        return pred1, pred2

    def backward_supervised(self, retain_graph=False):
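        ## each labeled batch is concatenated with its CycleGAN translation
        ## (concate_from_A / concate_from_B), so the ground truth is repeated twice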
        supervised = self.concate_from_A(self.supervisedA)
        gnd = self.gndSupervisedA.repeat(2, 1, 1)

        feature = self.netFeatures(supervised)
        supervisedPred1 = self.netClassifier1(feature)
        supervisedPred2 = self.netClassifier2(feature)
        lossSupervisedA = self.criterionSeg(supervisedPred1, gnd) \
            + self.criterionSeg(supervisedPred2, gnd)
        lossSupervisedA.backward(retain_graph=retain_graph)

        self.predSupervisedA = (supervisedPred1 +
                                supervisedPred2).argmax(1).to('cpu')
        self.miouSupervisedA.update(self.predSupervisedA, gnd)

        supervised = self.concate_from_B(self.supervisedB)
        gnd = self.gndSupervisedB.repeat(2, 1, 1)

        feature = self.netFeatures(supervised)
        supervisedPred1 = self.netClassifier1(feature)
        supervisedPred2 = self.netClassifier2(feature)
        lossSupervisedB = self.criterionSeg(supervisedPred1, gnd) \
            + self.criterionSeg(supervisedPred2, gnd)
        lossSupervisedB.backward(retain_graph=retain_graph)

        self.predSupervisedB = (supervisedPred1 +
                                supervisedPred2).argmax(1).to('cpu')
        self.miouSupervisedB.update(self.predSupervisedB, gnd)

        self.lossSupervised = float(lossSupervisedA) + float(lossSupervisedB)

    def backward_unsupervised_classifier(self, retain_graph=False):
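        ## classifier step of MCD: keep segmentation accuracy on labeled data while
        ## maximizing (note the minus sign) the discrepancy between the two classifiers
        ## on unlabeled data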
        # A domain
        supervised = self.concate_from_A(self.supervisedA)
        supervisedGnd = self.gndSupervisedA.repeat(2, 1, 1)
        unsupervised = self.concate_from_A(self.unsupervisedA)
        unsupervisedGnd = self.gndUnsupervisedA.repeat(2, 1, 1)

        # forward supervised
        supervisedPred1, supervisedPred2 = self.forward_mcd(supervised)
        # forward unsupervised
        unsupervisedPred1, unsupervisedPred2 = self.forward_mcd(unsupervised)

        lossUnsupervisedClassifierA = self.criterionSeg(supervisedPred1, supervisedGnd) \
            + self.criterionSeg(supervisedPred2, supervisedGnd) \
            - self.criterionDis(unsupervisedPred1, unsupervisedPred2)
        lossUnsupervisedClassifierA.backward(retain_graph=retain_graph)

        self.predUnsupervisedA = (unsupervisedPred1 +
                                  unsupervisedPred2).argmax(1).to('cpu')
        self.miouUnsupervisedA.update(self.predUnsupervisedA, unsupervisedGnd)
        # B domain
        supervised = self.concate_from_B(self.supervisedB)
        supervisedGnd = self.gndSupervisedB.repeat(2, 1, 1)
        unsupervised = self.concate_from_B(self.unsupervisedB)
        unsupervisedGnd = self.gndUnsupervisedB.repeat(2, 1, 1)

        # forward supervised
        supervisedPred1, supervisedPred2 = self.forward_mcd(supervised)
        # forward unsupervised
        unsupervisedPred1, unsupervisedPred2 = self.forward_mcd(unsupervised)

        lossUnsupervisedClassifierB = self.criterionSeg(supervisedPred1, supervisedGnd) \
            + self.criterionSeg(supervisedPred2, supervisedGnd) \
            - self.criterionDis(unsupervisedPred1, unsupervisedPred2)
        lossUnsupervisedClassifierB.backward(retain_graph=retain_graph)

        self.predUnsupervisedB = (unsupervisedPred1 +
                                  unsupervisedPred2).argmax(1).to('cpu')
        self.miouUnsupervisedB.update(self.predUnsupervisedB, unsupervisedGnd)

        self.lossUnsupervisedClassifier = float(lossUnsupervisedClassifierA) + \
                float(lossUnsupervisedClassifierB)

    def backward_unsupervised_feature(self, retain_graph=False):
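        ## feature step of MCD: update the feature extractor so that the two classifiers
        ## agree on unlabeled data (minimize their discrepancy)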
        # A domain
        unsupervised = self.concate_from_A(self.unsupervisedA)
        # forward unsupervised
        unsupervisedPred1, unsupervisedPred2 = self.forward_mcd(unsupervised)

        lossUnsupervisedFeatureA = self.criterionDis(unsupervisedPred1, unsupervisedPred2) \
                * self.opt.nTimesDLoss
        lossUnsupervisedFeatureA.backward(retain_graph=retain_graph)

        # B domain
        unsupervised = self.concate_from_B(self.unsupervisedB)
        # forward unsupervised
        unsupervisedPred1, unsupervisedPred2 = self.forward_mcd(unsupervised)

        lossUnsupervisedFeatureB = self.criterionDis(unsupervisedPred1, unsupervisedPred2) \
                * self.opt.nTimesDLoss
        lossUnsupervisedFeatureB.backward(retain_graph=retain_graph)

        self.lossUnsupervisedFeature = float(lossUnsupervisedFeatureA) + \
                float(lossUnsupervisedFeatureB)

    def concate_from_A(self, A):
        B = self.netGenA(A)
        return torch.cat([A, B], 0)

    def concate_from_B(self, B):
        A = self.netGenB(B)
        return torch.cat([A, B], 0)

    def optimize_parameters_mcd(self):
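        ## Maximum Classifier Discrepancy (MCD) schedule:
        ##   1) train features F and classifiers C on labeled data
        ##   2) freeze F, train C to maximize the discrepancy on unlabeled data
        ##   3) freeze C, update F (and the generators) k times to minimize that discrepancy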
        # update F and C for Source
        self.set_requires_grad([self.netClassifier1, self.netClassifier2],
                               True)
        self.optimizerF.zero_grad()
        self.optimizerC.zero_grad()
        self.backward_supervised(retain_graph=False)
        self.optimizerF.step()
        self.optimizerC.step()
        # update C for Target
        self.set_requires_grad([self.netFeatures], False)
        self.optimizerC.zero_grad()
        self.backward_unsupervised_classifier()
        self.optimizerC.step()
        # update F for Target
        self.set_requires_grad([self.netFeatures], True)
        self.set_requires_grad([self.netClassifier1, self.netClassifier2],
                               False)
        for i in range(self.opt.k):
            self.optimizerG.zero_grad()
            self.optimizerF.zero_grad()
            self.backward_unsupervised_feature()
            self.optimizerG.step()
            self.optimizerF.step()

    def optimize_parameters(self):
        self.optimize_parameters_cyclegan()
        self.optimize_parameters_mcd()