コード例 #1
0
def test():
  """Build the PairedGANDisen graph, restore a checkpoint if requested, and
  pump generated (fake_y, fake_x) batches from the input queues.

  Relies on module-level names defined elsewhere in this file:
  ``FLAGS``, ``nameNet``, ``PairedGANDisen``, ``tf``, ``os`` and ``logging``.
  """
  checkpoints_dir = "checkpoints/{}".format(nameNet)
  if FLAGS.load_model is None:
    # Fresh run: make sure the checkpoint/summary directory exists.
    try:
      os.makedirs(checkpoints_dir)
    except os.error:
      pass

  graph = tf.Graph()
  with graph.as_default():
    paired_gan = PairedGANDisen(
        XY_train_file=FLAGS.XY,
        batch_size=FLAGS.batch_size,
        image_size=FLAGS.image_size,
        use_lsgan=FLAGS.use_lsgan,
        norm=FLAGS.norm,
        lambdaRecon=FLAGS.lambdaRecon,
        lambdaAlign=FLAGS.lambdaAlign,
        lambdaRev=FLAGS.lambdaRev,
        learning_rate=FLAGS.learning_rate,
        beta1=FLAGS.beta1,
        nfS=FLAGS.nfS,
        nfE=FLAGS.nfE
    )
    G_loss, D_Y_loss, F_loss, D_X_loss, A_loss, DC_loss, fake_y, fake_x = paired_gan.model()
    optimizers = paired_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss, A_loss, DC_loss)

    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
    saver = tf.train.Saver()

  with tf.Session(graph=graph) as sess:
    if FLAGS.load_model is not None:
      # Restore all weights from the newest checkpoint in checkpoints_dir.
      checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
      meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
      restore = tf.train.import_meta_graph(meta_graph_path)
      restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
    else:
      # BUG FIX: the original never initialized variables when no checkpoint
      # was loaded, so the first sess.run crashed with
      # "Attempting to use uninitialized value".
      sess.run(tf.global_variables_initializer())

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
      while not coord.should_stop():
        # Pull one batch of generated images from each direction.
        # (Removed a leftover pdb.set_trace() debugging breakpoint and the
        # two unused ImagePool locals that the original created here.)
        fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

    except KeyboardInterrupt:
      logging.info('Interrupted')
      coord.request_stop()
    except Exception as e:
      # Forward the error to the coordinator so all queue threads stop.
      coord.request_stop(e)
    finally:
      # When done, ask the threads to stop.
      coord.request_stop()
      coord.join(threads)
コード例 #2
0
    def init_architecture(self, opt) :
        """Instantiate the generators/discriminators and their Adam optimizers.

        Args:
            opt: parsed option namespace; fields read here include
                in_nc/out_nc/nz/ngf (network sizes), G_model, use_gpu,
                isTrain, lr, beta1 and pool_size.

        Side effects:
            Stores networks, optimizers and image pools on ``self``.
            NOTE: network construction order is preserved deliberately —
            weight initialization consumes the RNG, so reordering would
            change initial weights.
        """
        self.opt = opt 
        # Generator A is always built (needed for inference as well).
        self.netG_A = define_G(opt.in_nc, opt.out_nc, opt.nz, opt.ngf, which_model_netG=opt.G_model)
        if opt.use_gpu :
            self.netG_A.cuda()

        # Discriminators, the second generator and all optimizers are only
        # needed while training.
        if opt.isTrain :
            self.netD_A = define_D(opt.in_nc, opt.ngf, 'basic_128')
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                        lr=opt.lr, betas=(opt.beta1, 0.999))

            self.netG_B = define_G(opt.in_nc, opt.out_nc, opt.nz, opt.ngf, which_model_netG=opt.G_model)
            # One optimizer updates both generators jointly.
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                            lr=opt.lr, betas=(opt.beta1, 0.999))

            self.netD_B = define_D(opt.in_nc, opt.ngf, 'basic_128')
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                        lr=opt.lr, betas=(opt.beta1, 0.999))

            if opt.use_gpu :
                self.netD_A.cuda()
                self.netG_B.cuda()
                self.netD_B.cuda()

            self.optimizers = [self.optimizer_G, self.optimizer_D_A, self.optimizer_D_B]

            # History buffers of previously generated images used to update
            # the discriminators (CycleGAN replay-pool trick).
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
コード例 #3
0
ファイル: cyclechain.py プロジェクト: kiminh/Cycle_Dynamics
    def __init__(self,opt):
        """Build the state<->image cycle model plus (optionally) its
        discriminators, losses and optimizers.

        Args:
            opt: option namespace; reads dynamic, istrain, loss ('l1'/'l2')
                and whatever CDFdata.get_loader consumes.

        NOTE: assumes CUDA is available (torch.cuda.FloatTensor, .cuda()).
        """
        self.opt = opt
        self.dynamic = opt.dynamic
        self.isTrain = opt.istrain
        # Default tensor type handed to GANLoss below.
        self.Tensor = torch.cuda.FloatTensor

        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = state2img().cuda()
        self.netG_B = img2state().cuda()
        self.netF_A = Fmodel().cuda()
        self.dataF = CDFdata.get_loader(opt)
        # Pre-trains/warms the forward (dynamics) model before GAN training.
        self.train_forward()

        if self.isTrain:
            self.netD_A = imgDmodel().cuda()
            self.netD_B = stateDmodel().cuda()

        if self.isTrain:
            self.fake_A_pool = ImagePool(pool_size=128)
            self.fake_B_pool = ImagePool(pool_size=128)
            # define loss functions
            self.criterionGAN = GANLoss(tensor=self.Tensor).cuda()
            if opt.loss == 'l1':
                self.criterionCycle = torch.nn.L1Loss()
                self.criterionIdt = torch.nn.L1Loss()
            elif opt.loss == 'l2':
                self.criterionCycle = torch.nn.MSELoss()
                self.criterionIdt = torch.nn.MSELoss()
            # initialize optimizers
            # self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()))
            # netF_A participates with lr=0.0, i.e. it is effectively frozen
            # while still being part of the optimizer's parameter groups.
            self.optimizer_G = torch.optim.Adam([{'params':self.netG_A.parameters(),'lr':1e-3},
                                                 {'params':self.netF_A.parameters(),'lr':0.0},
                                                 {'params':self.netG_B.parameters(),'lr':1e-3}])
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters())
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters())
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                # self.schedulers.append(networks.get_scheduler(optimizer, opt))
                # NOTE(review): with the real scheduler commented out, the
                # "schedulers" list actually stores the optimizers themselves —
                # any code that calls scheduler.step() on these will step the
                # optimizer instead. Confirm this is intentional.
                self.schedulers.append(optimizer)

        print('---------- Networks initialized -------------')
        # networks.print_network(self.netG_A)
        # networks.print_network(self.netG_B)
        # if self.isTrain:
        #     networks.print_network(self.netD_A)
        #     networks.print_network(self.netD_B)
        print('-----------------------------------------------')
コード例 #4
0
ファイル: model.py プロジェクト: eric-kong/Cycle-GAN
    def __init__(self,
                 sess,
                 name="cylegan",
                 dataset="horse2zebra",
                 image_size=256,
                 batch_size=1,
                 ngf=64,
                 ndf=64,
                 lambda1=10,
                 lambda2=10,
                 beta1=0.5):
        """Assemble a CycleGAN: two generators, two discriminators, two
        image-replay pools, then build the graph and its optimizers.

        Args:
            sess: TensorFlow session used to run the graph.
            name: model name. NOTE(review): default "cylegan" looks like a
                typo for "cyclegan", and the argument is not stored on self
                in this constructor — left untouched in case it is consumed
                elsewhere (e.g. checkpoint paths).
            dataset: dataset identifier.
            image_size: square image side length.
            batch_size: training batch size.
            ngf/ndf: base filter counts for generators/discriminators.
            lambda1/lambda2: cycle-consistency loss weights.
            beta1: Adam beta1 parameter.
        """
        self.image_size = image_size
        self.ngf = ngf
        self.ndf = ndf

        self.lambda1 = lambda1

        self.lambda2 = lambda2

        # adam
        self.beta1 = beta1

        self.batch_size = batch_size

        self.sess = sess

        self.dataset = dataset

        # Replay pools of previously generated fakes for the discriminators.
        self.pool_fake_X = ImagePool()

        self.pool_fake_Y = ImagePool()

        # generate images in the domain X
        self.generator_X = Generator(name="generator_X",
                                     ngf=self.ngf,
                                     image_size=image_size)
        # generate images in the domain Y
        self.generator_Y = Generator(name="generator_Y",
                                     ngf=self.ngf,
                                     image_size=image_size)
        # discriminate images in the domain X
        self.discriminator_X = Discriminator(name="discriminator_X",
                                             ndf=self.ndf)
        # discriminate images in the domain Y
        self.discriminator_Y = Discriminator(name="discriminator_Y",
                                             ndf=self.ndf)

        # Graph construction must precede optimizer creation, which captures
        # the variables built by _build_model().
        self._build_model()
        print("[SUCCEED] build up model")
        self._init_optimizer()
        print("[SUCCEED] initialize optimizers")
コード例 #5
0
    def __init__(self, opt):
        """Build the image->state cycle model with a (pretrained) forward
        dynamics model and an action model, plus losses and optimizers.

        Args:
            opt: option namespace; reads dynamic, istrain, loss ('l1'/'l2'),
                F_lr, G_lr and whatever CDFdata.get_loader consumes.

        NOTE: assumes CUDA is available.
        """
        self.opt = opt
        self.dynamic = opt.dynamic
        self.isTrain = opt.istrain
        self.Tensor = torch.cuda.FloatTensor

        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_B = img2state().cuda()
        self.netF_A = Fmodel().cuda()
        self.netAction = Amodel().cuda()
        self.dataF = CDFdata.get_loader(opt)
        # Forward model is loaded pretrained rather than trained from scratch.
        self.train_forward(pretrained=True)

        # Buffers for comparing ground-truth vs. predicted trajectories.
        self.gt_buffer = []
        self.pred_buffer = []

        # NOTE(review): the isTrain guards below are commented out, so the
        # discriminator, pools, losses and optimizers are built even in
        # evaluation mode. Confirm this is intentional.
        # if self.isTrain:
        self.netD_B = stateDmodel().cuda()

        # if self.isTrain:
        self.fake_A_pool = ImagePool(pool_size=128)
        self.fake_B_pool = ImagePool(pool_size=128)
        # define loss functions
        self.criterionGAN = GANLoss(tensor=self.Tensor).cuda()
        if opt.loss == 'l1':
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
        elif opt.loss == 'l2':
            self.criterionCycle = torch.nn.MSELoss()
            self.criterionIdt = torch.nn.MSELoss()
        # initialize optimizers
        # self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()))
        # Per-network learning rates: forward model F uses opt.F_lr, the
        # generator and action model share opt.G_lr.
        self.optimizer_G = torch.optim.Adam([{
            'params': self.netF_A.parameters(),
            'lr': self.opt.F_lr
        }, {
            'params': self.netG_B.parameters(),
            'lr': self.opt.G_lr
        }, {
            'params':
            self.netAction.parameters(),
            'lr':
            self.opt.G_lr
        }])
        self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters())

        print('---------- Networks initialized ---------------')
        print('-----------------------------------------------')
コード例 #6
0
ファイル: model.py プロジェクト: gogobd/CycleGAN-tensorflow
    def __init__(self, sess, args):
        """Configure the CycleGAN trainer.

        Copies hyper-parameters from ``args``, selects the generator /
        discriminator builders and GAN criterion, builds the graph, a
        checkpoint saver and the fake-image replay pool.

        Args:
            sess: TensorFlow session used for all graph execution.
            args: parsed CLI namespace (fine_size, input_nc, output_nc,
                L1_lambda, dataset_dir, stdscr, use_resnet, use_lsgan,
                ngf, ndf, phase, max_size).
        """
        self.sess = sess

        # Hyper-parameters lifted straight from the CLI namespace.
        self.image_size = args.fine_size
        self.input_c_dim = args.input_nc
        self.output_c_dim = args.output_nc
        self.L1_lambda = args.L1_lambda
        self.dataset_dir = args.dataset_dir
        self.stdscr = args.stdscr

        # Architecture and criterion selection.
        self.discriminator = discriminator
        self.generator = generator_resnet if args.use_resnet else generator_unet
        # LSGAN uses a least-squares criterion, vanilla GAN uses sigmoid CE.
        self.criterionGAN = mae_criterion if args.use_lsgan else sce_criterion

        OPTIONS = namedtuple('OPTIONS', 'image_size \
                              gf_dim df_dim output_c_dim is_training')
        self.options = OPTIONS(args.fine_size,
                               args.ngf, args.ndf, args.output_nc,
                               args.phase == 'train')

        # Graph first, then the saver (it captures the variables just built).
        self._build_model()
        self.saver = tf.train.Saver()
        self.pool = ImagePool(args.max_size)
コード例 #7
0
ファイル: model.py プロジェクト: fhfonsecaa/SG-GAN-TF2
    def __init__(self, args):
        """Configure the SG-GAN model: hyper-parameters, network builders,
        GAN criterion, replay pool, optimizers and checkpoint managers.

        Args:
            args: parsed CLI namespace (batch_size, image_width/height,
                input_nc, output_nc, L1_lambda, Lg_lambda, dataset_dir,
                segment_class, ratio_gan2seg, use_pix2pix, use_resnet,
                use_lsgan, ngf, ndf, phase, max_size, beta1).
        """
        self.batch_size = args.batch_size
        self.image_width = args.image_width
        self.image_height = args.image_height
        self.input_c_dim = args.input_nc
        self.output_c_dim = args.output_nc
        self.L1_lambda = args.L1_lambda
        self.Lg_lambda = args.Lg_lambda
        self.dataset_dir = args.dataset_dir
        self.segment_class = args.segment_class
        # Reciprocal weighting between GAN and segmentation losses;
        # 0 disables the term when ratio_gan2seg <= 0.
        self.alpha_recip = 1. / args.ratio_gan2seg if args.ratio_gan2seg > 0 else 0

        self.use_pix2pix = args.use_pix2pix

        # Default discriminator; may be replaced by the pix2pix variant below.
        self.discriminator = discriminator()
        if args.use_resnet:
            self.generator = generator_resnet()
        else:
            if args.use_pix2pix:
                self.generator = generator_pix2pix()
                self.discriminator = discriminator_pix2pix()
            else:
                self.generator = generator_unet()

        if args.use_lsgan:
            self.criterionGAN = mae_criterion
        else:
            self.criterionGAN = sce_criterion

        # tf.keras.utils.plot_model(self.discriminator, 'multi_input_and_output_model.png', show_shapes=True)
        # input("")

        OPTIONS = namedtuple(
            'OPTIONS', 'batch_size image_height image_width \
                              gf_dim df_dim output_c_dim is_training segment_class'
        )
        self.options = OPTIONS._make(
            (args.batch_size, args.image_height, args.image_width, args.ngf,
             args.ndf, args.output_nc, args.phase == 'train',
             args.segment_class))

        self._build_model(args)
        self.pool = ImagePool(args.max_size)

        #### [ADDED] CHECKPOINT MANAGER
        # NOTE(review): learning rate is hard-coded here (0.001) and does not
        # come from args — confirm this override is intentional.
        self.lr = 0.001
        self.d_optim = tf.keras.optimizers.Adam(learning_rate=self.lr,
                                                beta_1=args.beta1)
        self.g_optim = tf.keras.optimizers.Adam(learning_rate=self.lr,
                                                beta_1=args.beta1)

        # NOTE(review): checkpoint directories are hard-coded to
        # './checkpoint/gta/...' regardless of dataset_dir.
        self.gen_ckpt = tf.train.Checkpoint(optimizer=self.g_optim,
                                            net=self.generator)
        self.disc_ckpt = tf.train.Checkpoint(optimizer=self.d_optim,
                                             net=self.discriminator)
        self.gen_ckpt_manager = tf.train.CheckpointManager(
            self.gen_ckpt, './checkpoint/gta/gen_ckpts', max_to_keep=3)
        self.disc_ckpt_manager = tf.train.CheckpointManager(
            self.disc_ckpt, './checkpoint/gta/disc_ckpts', max_to_keep=3)
コード例 #8
0
def test():
    """Build the multi-domain MmNet graph, restore generator weights and run
    the visualization loop over generated images.

    Relies on module-level names defined elsewhere: ``FLAGS``, ``MmNet``,
    ``ImagePool``, ``visualization``, ``tf`` and ``logging``.
    """
    graph = tf.Graph()
    with graph.as_default():
        MmNet_model = MmNet(batch_size=FLAGS.batch_size,
                            image_size=FLAGS.image_size,
                            use_lsgan=FLAGS.use_lsgan,
                            norm=FLAGS.norm,
                            lambda1=FLAGS.lambda1,
                            lambda2=FLAGS.lambda1,
                            learning_rate=FLAGS.learning_rate,
                            beta1=FLAGS.beta1,
                            ngf=FLAGS.ngf,
                            number_domain=FLAGS.number_domain,
                            train_file=FLAGS.test_file)
        G_loss, D_Y_loss, F_loss, D_X_loss = MmNet_model.model()
        # It is better to import trainable variables instead of the whole
        # graph: restoring everything needs lots of memory/time and may fail.
        # BUG FIX: the original referenced an undefined name `cycle_gan`.
        saver = tf.train.Saver(MmNet_model.F_train_var + MmNet_model.G_train_var)

    # BUG FIX: the original called sess.run(...) without ever creating a
    # session; all execution now happens inside this context manager.
    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, FLAGS.saved_model)
        acount = 1
        domain_idx = 0
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            fake_pool = [
                ImagePool(FLAGS.pool_size) for i in range(FLAGS.number_domain)
            ]

            while not coord.should_stop():
                # get previously generated images
                # (the original indexed the class `MmNet` instead of the
                # instance `MmNet_model` — fixed throughout)
                fake_ = [MmNet_model.loss[domain_idx][-1]] + [
                    MmNet_model.loss[i][-2] for i in range(FLAGS.number_domain - 1)
                ]
                fake_gene = sess.run(fake_)
                feed_dict = {
                    MmNet_model.fake_set[i]: fake_pool[i].query(fake_gene[i])
                    for i in range(FLAGS.number_domain)
                }
                raw_image_generated_images = sess.run(
                    MmNet_model.raw_image_generated_images, feed_dict)
                # NOTE(review): `labels` is not defined in this function;
                # confirm it exists at module level before running.
                visualization(raw_image_generated_images, acount, labels)
                acount += 1
                # BUG FIX: domain_idx grew without bound and would run off the
                # end of MmNet_model.loss; cycle through the domains instead.
                domain_idx = (domain_idx + 1) % FLAGS.number_domain
                print(acount)
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
コード例 #9
0
ファイル: cycleGAN.py プロジェクト: roshandhakal/Bokeh-Effect
 def __init__(self, opt):
     """Build the CycleGAN networks, replay pools, losses and optimizers.

     Args:
         opt: option namespace (input_nc, output_nc, ngf, ndf, norm,
             no_dropout, n_blocks, padding_type, n_layers_D, pool_size,
             lr, beta1).

     NOTE: construction order of the four networks is significant — weight
     initialization consumes the global RNG state.
     """
     super(cycleGan, self).__init__()
     self.opt = opt
     # NOTE(review): "RensetGenerator" is presumably a project-wide typo for
     # "ResnetGenerator"; the class name cannot be changed from here.
     self.G_XtoY = RensetGenerator(opt.input_nc, opt.output_nc, opt.ngf,
                                   opt.norm, not opt.no_dropout,
                                   opt.n_blocks,
                                   opt.padding_type).to(device)
     # self.G_YtoX = CycleGenerator(d_conv_dim, res_blocks)
     self.G_YtoX = RensetGenerator(opt.output_nc, opt.input_nc, opt.ngf,
                                   opt.norm, not opt.no_dropout,
                                   opt.n_blocks,
                                   opt.padding_type).to(device)
     # self.D_X = CycleDiscriminator(d_conv_dim)
     self.D_X = PatchDiscriminator(opt.output_nc, opt.ndf, opt.n_layers_D,
                                   opt.norm).to(device)
     # self.D_Y = CycleDiscriminator(d_conv_dim)
     self.D_Y = PatchDiscriminator(opt.input_nc, opt.ndf, opt.n_layers_D,
                                   opt.norm).to(device)
     # self.G_XtoY.apply(init_weights)
     # self.G_YtoX.apply(init_weights)
     # self.D_X.apply(init_weights)
     # self.D_Y.apply(init_weights)
     # Print architectures and parameter-tensor counts for debugging.
     print(self.G_XtoY)
     print("Parameters: ", len(list(self.G_XtoY.parameters())))
     print(self.G_YtoX)
     print("Parameters: ", len(list(self.G_YtoX.parameters())))
     print(self.D_X)
     print("Parameters: ", len(list(self.D_X.parameters())))
     print(self.D_Y)
     print("Parameters: ", len(list(self.D_Y.parameters())))
     # Replay pools of previously generated fakes for discriminator updates.
     self.fake_X_pool = ImagePool(opt.pool_size)
     self.fake_Y_pool = ImagePool(opt.pool_size)
     self.criterionCycle = torch.nn.L1Loss()
     self.criterionIdt = torch.nn.L1Loss()  #For implementing Identity Loss
     # Both generators share one optimizer; each discriminator has its own.
     self.optimizer_G = torch.optim.Adam(itertools.chain(
         self.G_XtoY.parameters(), self.G_YtoX.parameters()),
                                         lr=opt.lr,
                                         betas=[opt.beta1, 0.999])
     self.optimizer_DX = torch.optim.Adam(self.D_X.parameters(),
                                          lr=opt.lr,
                                          betas=[opt.beta1, 0.999])
     self.optimizer_DY = torch.optim.Adam(self.D_Y.parameters(),
                                          lr=opt.lr,
                                          betas=[opt.beta1, 0.999])
コード例 #10
0
    def __init__(self, args):
        """Configure the music-domain CycleGAN: hyper-parameters, network
        builders, GAN criterion, replay pool, then build the graph.

        Args:
            args: parsed CLI namespace (batch_size, time_step, pitch_range,
                input_nc, output_nc, lr, L1_lambda, gamma, sigma_d, dataset
                and loss paths, sample_dir, model, ngf, ndf, phase, max_size).
        """
        self.batch_size = args.batch_size
        self.time_step = args.time_step  # number of time steps
        self.pitch_range = args.pitch_range  # number of pitches
        self.input_c_dim = args.input_nc  # number of input image channels
        self.output_c_dim = args.output_nc  # number of output image channels
        self.lr = args.lr
        self.L1_lambda = args.L1_lambda
        self.gamma = args.gamma
        self.sigma_d = args.sigma_d
        self.dataset_A_dir = args.dataset_A_dir
        self.dataset_B_dir = args.dataset_B_dir
        self.d_loss_path = args.d_loss_path
        self.g_loss_path = args.g_loss_path
        self.cycle_loss_path = args.cycle_loss_path
        self.sample_dir = args.sample_dir

        self.model = args.model
        self.discriminator = build_discriminator
        self.generator = build_generator
        self.criterionGAN = mae_criterion

        OPTIONS = namedtuple(
            "OPTIONS",
            "batch_size "
            "time_step "
            "input_nc "
            "output_nc "
            "pitch_range "
            "gf_dim "
            "df_dim "
            "is_training",
        )
        # BUG FIX: values must be listed in the same order as the field names
        # above. The original passed (batch_size, time_step, pitch_range,
        # input_nc, output_nc, ...), which silently assigned pitch_range to
        # options.input_nc, input_nc to options.output_nc and output_nc to
        # options.pitch_range.
        self.options = OPTIONS._make(
            (
                args.batch_size,
                args.time_step,
                args.input_nc,
                args.output_nc,
                args.pitch_range,
                args.ngf,
                args.ndf,
                args.phase == "train",
            )
        )

        self.now_datetime = get_now_datetime()
        # Replay pool of previously generated samples for the discriminator.
        self.pool = ImagePool(args.max_size)

        self._build_model(args)

        print("Initialized model.")
コード例 #11
0
    def __init__(self, parsed_args, parsed_groups):
        """Set up an LSGAN trainer: U-Net generator, N-layer discriminator,
        weight init, loss, replay pool and visualization helpers.

        Args:
            parsed_args: flat argparse namespace (pool_size, replay_prob,
                n_vis, ...).
            parsed_groups: dict mapping argument-group names to kwargs dicts
                for the trainer, generator and discriminator constructors.
        """
        super().__init__(**parsed_groups['trainer arguments'])

        self.gen = UNetGenerator(**parsed_groups['generator arguments'])
        self.disc = NLayerDiscriminator(
            **parsed_groups['discriminator arguments'])

        # NOTE: init order matters — weight init consumes the RNG state.
        init_weights(self.gen)
        init_weights(self.disc)

        # Target labels for the GAN criterion (LSGAN uses MSE against these).
        self.real_label = torch.tensor(1.0)
        self.fake_label = torch.tensor(0.0)

        self.crit = torch.nn.MSELoss()  # LSGAN
        self.image_pool = ImagePool(parsed_args.pool_size,
                                    parsed_args.replay_prob)

        self.sel_ind = 0
        # Maps a [-1, 1] tensor back to [0, 255] pixel values for display.
        self.un_normalize = lambda x: 255. * (1 + x.clamp(min=-1, max=1)) / 2.

        self.parsed_args = parsed_args
        self.n_vis = parsed_args.n_vis
        self.vis = Visualizer()
コード例 #12
0
    def __init__(self, sess, args):
        """Configure the model: copy hyper-parameters, pick network builders
        and losses, build the graph and a checkpoint saver.

        Args:
            sess: TensorFlow session used for graph execution.
            args: parsed CLI namespace (im_size, input_nc, output_nc,
                lambda1, dataset_dir, ngf, ndf, max_pool).
        """
        self.sess = sess

        # Hyper-parameters straight from the CLI namespace.
        self.im_size = args.im_size
        self.input_nc = args.input_nc
        self.output_nc = args.output_nc
        self.lambda1 = args.lambda1
        self.dataset_dir = args.dataset_dir

        # Network builders and loss functions.
        self.generator, self.discriminator = generator, discriminator
        self.gan_loss, self.cyc_loss = gan_loss, cyc_loss

        OPTIONS = namedtuple('OPTIONS', 'gf_dim df_dim output_nc im_size')
        self.options = OPTIONS(args.ngf, args.ndf, args.output_nc,
                               args.im_size)

        # Replay pool of previously generated fakes for the discriminator.
        self.pool = ImagePool(args.max_pool)

        # Graph first, then the saver (it captures the variables just built).
        self._build_model()
        self.saver = tf.train.Saver()
コード例 #13
0
def train():
  """Train PairedGANDisenRevFullRep, logging every 100 steps and saving a
  checkpoint every 10000 steps (and once more on shutdown).

  Relies on module-level names defined elsewhere in this file:
  ``FLAGS``, ``nameNet``, ``PairedGANDisenRevFullRep``, ``ImagePool``,
  ``tf``, ``os`` and ``logging``.
  """
  if FLAGS.load_model is not None:
    # BUG FIX: the original used FLAGS.load_model.lstrip("checkpoints/"),
    # but str.lstrip strips *characters* (any of c,h,e,k,p,o,i,n,t,s,/),
    # not a prefix — model names starting with those characters were
    # silently mangled. Strip the literal prefix instead.
    if FLAGS.load_model.startswith("checkpoints/"):
      checkpoints_dir = FLAGS.load_model
    else:
      checkpoints_dir = "checkpoints/" + FLAGS.load_model
  else:
    #current_time = datetime.now().strftime("%Y%m%d-%H%M")
    checkpoints_dir = "checkpoints/{}".format(nameNet)
    try:
      os.makedirs(checkpoints_dir)
    except os.error:
      pass

  graph = tf.Graph()
  with graph.as_default():
    paired_gan = PairedGANDisenRevFullRep(
        XY_train_file=FLAGS.XY,
        batch_size=FLAGS.batch_size,
        image_size=FLAGS.image_size,
        use_lsgan=FLAGS.use_lsgan,
        norm=FLAGS.norm,
        lambdaRecon=FLAGS.lambdaRecon,
        lambdaAlign=FLAGS.lambdaAlign,
        lambdaRev=FLAGS.lambdaRev,
        learning_rate=FLAGS.learning_rate,
        beta1=FLAGS.beta1,
        nfS=FLAGS.nfS,
        nfE=FLAGS.nfE
    )
    G_loss, D_Y_loss, Dex_Y_loss, F_loss, D_X_loss, Dex_X_loss, A_loss, DC_loss, fake_y, fake_x, fake_ex_y, fake_ex_x = paired_gan.model()
    optimizers = paired_gan.optimize(G_loss, D_Y_loss, Dex_Y_loss, F_loss, D_X_loss, Dex_X_loss, A_loss, DC_loss)

    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
    saver = tf.train.Saver()

  with tf.Session(graph=graph) as sess:
    if FLAGS.load_model is not None:
      # Resume: restore weights and recover the step counter from the
      # checkpoint name ("model.ckpt-<step>").
      checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
      meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
      restore = tf.train.import_meta_graph(meta_graph_path)
      restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
      # More robust than the original split("-")[2]: take the suffix after
      # the last "-", which works regardless of dashes in the directory name.
      step = int(checkpoint.model_checkpoint_path.split("-")[-1])
    else:
      sess.run(tf.global_variables_initializer())
      step = 0

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
      # Replay pools of previously generated images for each discriminator.
      fake_Y_pool = ImagePool(FLAGS.pool_size)
      fake_X_pool = ImagePool(FLAGS.pool_size)

      fake_ex_Y_pool = ImagePool(FLAGS.pool_size)
      fake_ex_X_pool = ImagePool(FLAGS.pool_size)

      while not coord.should_stop():
        # get previously generated images
        fake_y_val, fake_x_val, fake_ex_y_val, fake_ex_x_val = sess.run([fake_y, fake_x, fake_ex_y, fake_ex_x])

        # One optimization step.  (The original had a bare no-op expression
        # statement `train` here — removed.)
        # BUG FIX: fake_ex_Y_pool was queried with fake_y_val instead of
        # fake_ex_y_val, feeding the wrong images to the Dex_Y pathway.
        _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, A_loss_val, DC_loss_val, summary = (
              sess.run(
                  [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, A_loss, DC_loss, summary_op],
                  feed_dict={paired_gan.fake_y: fake_Y_pool.query(fake_y_val),
                             paired_gan.fake_x: fake_X_pool.query(fake_x_val),
                             paired_gan.fake_ex_y: fake_ex_Y_pool.query(fake_ex_y_val),
                             paired_gan.fake_ex_x: fake_ex_X_pool.query(fake_ex_x_val)}
              )
        )
        train_writer.add_summary(summary, step)
        train_writer.flush()

        if step % 100 == 0:
          logging.info('-----------Step %d:-------------' % step)
          logging.info('  G_loss   : {}'.format(G_loss_val))
          logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
          logging.info('  F_loss   : {}'.format(F_loss_val))
          logging.info('  D_X_loss : {}'.format(D_X_loss_val))
          logging.info('  A_loss : {}'.format(A_loss_val))

        if step % 10000 == 0:
          save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
          logging.info("Model saved in file: %s" % save_path)

        step += 1

    except KeyboardInterrupt:
      logging.info('Interrupted')
      coord.request_stop()
    except Exception as e:
      coord.request_stop(e)
    finally:
      # Always save a final checkpoint before shutting the queues down.
      save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
      logging.info("Model saved in file: %s" % save_path)
      # When done, ask the threads to stop.
      coord.request_stop()
      coord.join(threads)
コード例 #14
0
# device
# Prefer the first CUDA GPU when available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Networks
# Two generators (A->B and B->A) and two discriminators, all 3-channel.
# NOTE: construction/init order matters — init_weights consumes the RNG.
G_A2B = Generator(input_dim=3, n_blocks=N_BLOCKS).to(device)
G_B2A = Generator(input_dim=3, n_blocks=N_BLOCKS).to(device)
D_A = Discriminator(input_dim=3).to(device)
D_B = Discriminator(input_dim=3).to(device)
G_A2B.apply(init_weights)
G_B2A.apply(init_weights)
D_A.apply(init_weights)
D_B.apply(init_weights)

# ImagePool
# Replay buffers of previously generated fakes for the discriminators
# (standard CycleGAN trick, pool size 50 as in the paper).
fake_A_pool = ImagePool(size=50)
fake_B_pool = ImagePool(size=50)

# loss
Loss_GAN = nn.MSELoss()  # LSGAN adversarial criterion
Loss_cyc = nn.L1Loss()   # cycle-consistency criterion

# optimizer , betas=(0.5, 0.999)
# Both generators share one optimizer; each discriminator has its own.
optimizer_G = optim.Adam(itertools.chain(G_A2B.parameters(),
                                         G_B2A.parameters()),
                         lr=LR,
                         betas=(0.5, 0.999))
optimizer_D_A = optim.Adam(D_A.parameters(), lr=LR, betas=(0.5, 0.999))
optimizer_D_B = optim.Adam(D_B.parameters(), lr=LR, betas=(0.5, 0.999))

# NOTE(review): only the generator optimizer gets an LR scheduler here;
# discriminator schedulers may be defined later in the file — confirm.
scheduler_G = optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LR_LAMBDA)
コード例 #15
0
def train(args):
    """Train a GAN deblur/restoration model: discriminator on real vs. pooled
    fakes, generator on L1 reconstruction + VGG perceptual + adversarial loss.

    Args:
        args: parsed CLI namespace (glr, dlr, lr_step_size, lr_gamma,
            save_dir, exp, data_dir, batch_size, n_threads, patch_gan,
            pool_size, epochs, p_factor, g_factor, period).

    Relies on module-level names defined elsewhere: Generator, Discriminator,
    SaveData, MyDataset, ImagePool, Vgg16.  Assumes CUDA is available.
    """
    print(args)

    # net
    netG = Generator()
    netG = netG.cuda()
    netD = Discriminator()
    netD = netD.cuda()

    # loss
    l1_loss = nn.L1Loss().cuda()
    l2_loss = nn.MSELoss().cuda()
    bce_loss = nn.BCELoss().cuda()

    # opt
    optimizerG = optim.Adam(netG.parameters(), lr=args.glr)
    optimizerD = optim.Adam(netD.parameters(), lr=args.dlr)

    # lr
    schedulerG = lr_scheduler.StepLR(optimizerG, args.lr_step_size,
                                     args.lr_gamma)
    schedulerD = lr_scheduler.StepLR(optimizerD, args.lr_step_size,
                                     args.lr_gamma)

    # utility for saving models, parameters and logs
    save = SaveData(args.save_dir, args.exp, True)
    save.save_params(args)

    # netG, _ = save.load_model(netG)

    dataset = MyDataset(args.data_dir, is_train=True)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=int(args.n_threads))

    # PatchGAN target maps: all-ones for real, all-zeros for fake.
    real_label = Variable(
        torch.ones([1, 1, args.patch_gan, args.patch_gan],
                   dtype=torch.float)).cuda()
    fake_label = Variable(
        torch.zeros([1, 1, args.patch_gan, args.patch_gan],
                    dtype=torch.float)).cuda()

    # Replay pool of previously generated fakes for the discriminator.
    image_pool = ImagePool(args.pool_size)

    # Frozen VGG16 features for the perceptual loss.
    vgg = Vgg16(requires_grad=False)
    vgg.cuda()

    for epoch in range(args.epochs):
        print("* Epoch {}/{}".format(epoch + 1, args.epochs))

        d_total_real_loss = 0
        d_total_fake_loss = 0
        d_total_loss = 0

        g_total_res_loss = 0
        g_total_per_loss = 0
        g_total_gan_loss = 0
        g_total_loss = 0

        netG.train()
        netD.train()

        for batch, images in tqdm(enumerate(dataloader)):
            input_image, target_image = images
            input_image = Variable(input_image.cuda())
            target_image = Variable(target_image.cuda())
            output_image = netG(input_image)

            # Update D
            # NOTE(review): nn.Module has no callable `requires_grad`
            # (the in-place method is `requires_grad_`); this presumably
            # relies on a custom method on Discriminator — confirm.
            netD.requires_grad(True)
            netD.zero_grad()

            ## real image
            real_output = netD(target_image)
            d_real_loss = bce_loss(real_output, real_label)
            d_real_loss.backward()
            d_real_loss = d_real_loss.data.cpu().numpy()
            d_total_real_loss += d_real_loss

            ## fake image: detach so G gets no gradient, then sample from
            ## the replay pool to stabilize D training.
            fake_image = output_image.detach()
            fake_image = Variable(image_pool.query(fake_image.data))
            fake_output = netD(fake_image)
            d_fake_loss = bce_loss(fake_output, fake_label)
            d_fake_loss.backward()
            d_fake_loss = d_fake_loss.data.cpu().numpy()
            d_total_fake_loss += d_fake_loss

            ## loss
            d_total_loss += d_real_loss + d_fake_loss

            optimizerD.step()

            # Update G
            netD.requires_grad(False)
            netG.zero_grad()

            ## reconstruction loss
            g_res_loss = l1_loss(output_image, target_image)
            g_res_loss.backward(retain_graph=True)
            g_res_loss = g_res_loss.data.cpu().numpy()
            g_total_res_loss += g_res_loss

            ## perceptual loss
            g_per_loss = args.p_factor * l2_loss(vgg(output_image),
                                                 vgg(target_image))
            g_per_loss.backward(retain_graph=True)
            g_per_loss = g_per_loss.data.cpu().numpy()
            g_total_per_loss += g_per_loss

            ## gan loss
            output = netD(output_image)
            g_gan_loss = args.g_factor * bce_loss(output, real_label)
            g_gan_loss.backward()
            g_gan_loss = g_gan_loss.data.cpu().numpy()
            g_total_gan_loss += g_gan_loss

            ## loss
            g_total_loss += g_res_loss + g_per_loss + g_gan_loss

            optimizerG.step()

        # BUG FIX: the original called scheduler.step() at the *start* of
        # each epoch, before any optimizer.step(); since PyTorch 1.1 the
        # scheduler must be stepped after the epoch's optimizer updates or
        # the LR schedule is shifted/skipped.
        schedulerG.step()
        schedulerD.step()

        d_total_real_loss = d_total_real_loss / (batch + 1)
        d_total_fake_loss = d_total_fake_loss / (batch + 1)
        d_total_loss = d_total_loss / (batch + 1)
        save.add_scalar('D/real', d_total_real_loss, epoch)
        save.add_scalar('D/fake', d_total_fake_loss, epoch)
        save.add_scalar('D/total', d_total_loss, epoch)

        g_total_res_loss = g_total_res_loss / (batch + 1)
        g_total_per_loss = g_total_per_loss / (batch + 1)
        g_total_gan_loss = g_total_gan_loss / (batch + 1)
        g_total_loss = g_total_loss / (batch + 1)
        save.add_scalar('G/res', g_total_res_loss, epoch)
        save.add_scalar('G/per', g_total_per_loss, epoch)
        save.add_scalar('G/gan', g_total_gan_loss, epoch)
        save.add_scalar('G/total', g_total_loss, epoch)

        if epoch % args.period == 0:
            log = "Train d_loss: {:.5f} \t g_loss: {:.5f}".format(
                d_total_loss, g_total_loss)
            print(log)
            save.save_log(log)
            save.save_model(netG, epoch)
コード例 #16
0
ファイル: models.py プロジェクト: AneesKazi/TrainingOnFly
    def _build_net(self):
        """
        Build the computational graph according to the specified architectures.

        Creates the two generators (HQ->LQ and LQ->HQ) and the two
        discriminators on the real inputs. In training mode it additionally
        reuses the same variables to wire the cycle-reconstruction branches,
        discriminator branches fed from image pools, and the pool write ops.

        :return: None
        """

        with tf.variable_scope('cycleGAN') as scope:

            # prep input tensors
            # DATASET mode supplies tf.data iterators that must be advanced
            # with get_next(); otherwise the inputs are assumed to already be
            # tensors/placeholders.
            if self.mode == InputMode.DATASET:
                self.real_hq = self.inputs['hq'].get_next()
                self.real_lq = self.inputs['lq'].get_next()
            else:
                self.real_hq = self.inputs['hq']
                self.real_lq = self.inputs['lq']

            # build branches that operate on inputs
            # (generators translate across domains; discriminators score
            # the real images of their own domain)
            with tf.name_scope('LQ_Gen') as lq_gen_scope:
                self.lq.gen.layers = self.build_arch(self.real_hq,
                                                     self.gen_arch, 'LQ_Gen',
                                                     self.graph)
            with tf.name_scope('HQ_Gen') as hq_gen_scope:
                self.hq.gen.layers = self.build_arch(self.real_lq,
                                                     self.gen_arch, 'HQ_Gen',
                                                     self.graph)
            with tf.name_scope('LQ_Dis') as lq_dis_scope:
                self.lq.dis.layers = self.build_arch(self.real_lq,
                                                     self.dis_arch, 'LQ_Dis',
                                                     self.graph)
            with tf.name_scope('HQ_Dis') as hq_dis_scope:
                self.hq.dis.layers = self.build_arch(self.real_hq,
                                                     self.dis_arch, 'HQ_Dis',
                                                     self.graph)

            # store outputs of these branches (last layer = network output)
            self.lq.gen.from_real = self.lq.gen.layers[-1]
            self.hq.gen.from_real = self.hq.gen.layers[-1]
            self.lq.dis.real = self.lq.dis.layers[-1]
            self.hq.dis.real = self.hq.dis.layers[-1]

            # only build these parts in training mode
            if not self.inference_mode:

                # build image pools (replay buffers of historical fakes used
                # to stabilise discriminator training)
                self.lq.pool = ImagePool(self.pool_size,
                                         self.input_size,
                                         name='LQ-Pool')
                self.hq.pool = ImagePool(self.pool_size,
                                         self.input_size,
                                         name='HQ-Pool')

                # re-build components so that generated images are also taken into account
                # reuse_variables() makes the branches below share weights with
                # the ones built above instead of creating new variables.
                scope.reuse_variables()
                with tf.name_scope(lq_gen_scope):
                    self.lq.gen.cyc = self.build_arch(self.hq.gen.from_real,
                                                      self.gen_arch, 'LQ_Gen',
                                                      self.graph)[-1]
                with tf.name_scope(hq_gen_scope):
                    self.hq.gen.cyc = self.build_arch(self.lq.gen.from_real,
                                                      self.gen_arch, 'HQ_Gen',
                                                      self.graph)[-1]
                with tf.name_scope(lq_dis_scope):
                    # score pooled fakes and freshly generated fakes
                    self.lq.dis.pool_fake = self.build_arch(
                        self.lq.pool.read(), self.dis_arch, 'LQ_Dis',
                        self.graph)[-1]
                    scope.reuse_variables()
                    self.lq.dis.fake = self.build_arch(self.lq.gen.from_real,
                                                       self.dis_arch, 'LQ_Dis',
                                                       self.graph)[-1]
                with tf.name_scope(hq_dis_scope):
                    self.hq.dis.pool_fake = self.build_arch(
                        self.hq.pool.read(), self.dis_arch, 'HQ_Dis',
                        self.graph)[-1]
                    scope.reuse_variables()
                    self.hq.dis.fake = self.build_arch(self.hq.gen.from_real,
                                                       self.dis_arch, 'HQ_Dis',
                                                       self.graph)[-1]

                # build write ops to image pools
                self.lq.pool_write = self.lq.pool.write(self.lq.gen.from_real)
                self.hq.pool_write = self.hq.pool.write(self.hq.gen.from_real)
コード例 #17
0
ファイル: train.py プロジェクト: lzhua/Learning_resource
def train():
    """Train the CycleGAN model.

    Creates (or reuses, when FLAGS.load_model is set) a checkpoint
    directory, builds the training graph, and runs the optimisation loop
    until the input queues are exhausted or the user interrupts. Fake
    images produced by the generators are recycled through ImagePools to
    stabilise discriminator training. Summaries/losses are logged every
    100 steps; a checkpoint is written every 10000 steps and once more on
    shutdown.
    """
    if FLAGS.load_model is not None:
        checkpoints_dir = "checkpoints/" + FLAGS.load_model
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        try:
            os.makedirs(checkpoints_dir)
        except os.error:
            pass

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(
            X_train_file=FLAGS.X,
            Y_train_file=FLAGS.Y,
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            # BUG FIX: this previously passed FLAGS.lambda1, silently
            # ignoring the lambda2 flag for the backward cycle loss.
            lambda2=FLAGS.lambda2,
            learning_rate=FLAGS.learning_rate,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf
        )
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    # Cap GPU memory so other jobs can share the device.
    GPU_OPTIONS = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
    DEVICE_CONFIG_GPU = tf.ConfigProto(device_count={"CPU": 6},
                                       gpu_options=GPU_OPTIONS,
                                       intra_op_parallelism_threads=0,
                                       inter_op_parallelism_threads=0)
    with tf.Session(graph=graph, config=DEVICE_CONFIG_GPU) as sess:
        if FLAGS.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            # Resume the step counter from the checkpoint file name
            # (".../model.ckpt-<step>.meta"). ROBUSTNESS FIX: parse from the
            # right so dashes elsewhere in the path (e.g. a timestamped
            # directory) cannot break or corrupt the parse.
            step = int(meta_graph_path.split("-")[-1].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # get previously generated images
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

                # train; discriminators see a mix of fresh and pooled fakes
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],
                        feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                                   cycle_gan.fake_x: fake_X_pool.query(fake_x_val)}
                    )
                )
                # merged the two duplicated `step % 100` checks into one
                if step % 100 == 0:
                    train_writer.add_summary(summary, step)
                    train_writer.flush()
                    logging.info('-----------Step %d:-------------' % step)
                    logging.info('  G_loss   : {}'.format(G_loss_val))
                    logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info('  F_loss   : {}'.format(F_loss_val))
                    logging.info('  D_X_loss : {}'.format(D_X_loss_val))

                if step % 10000 == 0:
                    save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            # Always persist progress and shut the queue threads down cleanly.
            save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
            logging.info("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
コード例 #18
0
class CycleGAN(ModelBackbone):
    """CycleGAN: unpaired image-to-image translation.

    Two generators (netG_A: A->B, netG_B: B->A) and two discriminators
    (netD_A scores domain-B images, netD_B scores domain-A images) trained
    with adversarial, cycle-consistency and optional identity losses.
    Discriminators are fed from ImagePools of historical fakes.
    """

    def __init__(self, p):
        """Build networks, losses, optimizers and schedulers from options ``p``."""

        super(CycleGAN, self).__init__(p)
        # NOTE(review): nb and size are currently unused below.
        nb = p.batchSize
        size = p.cropSize

        # load/define models
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = networks.define_G(p.input_nc, p.output_nc, p.ngf,
                                        p.which_model_netG, p.norm,
                                        not p.no_dropout, p.init_type,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(p.output_nc, p.input_nc, p.ngf,
                                        p.which_model_netG, p.norm,
                                        not p.no_dropout, p.init_type,
                                        self.gpu_ids)

        if self.isTrain:
            # LSGAN works on raw outputs; vanilla GAN needs a sigmoid head.
            use_sigmoid = p.no_lsgan
            self.netD_A = networks.define_D(p.output_nc, p.ndf,
                                            p.which_model_netD, p.n_layers_D,
                                            p.norm, use_sigmoid, p.init_type,
                                            self.gpu_ids)
            self.netD_B = networks.define_D(p.input_nc, p.ndf,
                                            p.which_model_netD, p.n_layers_D,
                                            p.norm, use_sigmoid, p.init_type,
                                            self.gpu_ids)

        if not self.isTrain or p.continue_train:
            which_epoch = p.which_epoch
            self.load_model(self.netG_A, 'G_A', which_epoch)
            self.load_model(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_model(self.netD_A, 'D_A', which_epoch)
                self.load_model(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = p.lr
            # pools of previously generated images for discriminator replay
            self.fake_A_pool = ImagePool(p.pool_size)
            self.fake_B_pool = ImagePool(p.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not p.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=p.lr,
                                                betas=(p.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=p.lr,
                                                  betas=(p.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                  lr=p.lr,
                                                  betas=(p.beta1, 0.999))

            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, p))

        # print('---------- Networks initialized -------------')
        # networks.print_network(self.netG_A)
        # networks.print_network(self.netG_B)
        # if self.isTrain:
        #     networks.print_network(self.netD_A)
        #     networks.print_network(self.netD_B)
        # print('-----------------------------------------------')

    def name(self):
        """Return the model's display name."""
        return 'CycleGAN'

    def set_input(self, inp):
        """Unpack a data-loader batch and move the images to the GPU.

        Honours ``p.which_direction``: for 'AtoB' the batch's 'A' images
        are the source domain; otherwise the roles are swapped.
        """
        AtoB = self.p.which_direction == 'AtoB'
        input_A = inp['A' if AtoB else 'B']
        input_B = inp['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            input_A = input_A.cuda(self.gpu_ids[0])
            input_B = input_B.cuda(self.gpu_ids[0])
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = inp['A_path' if AtoB else 'B_path']

    def forward(self):
        """Wrap the stored inputs as Variables for this iteration."""
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        """Run both translation directions without gradients and cache the
        fakes and cycle reconstructions for visualisation."""
        with torch.no_grad():
            real_A = Variable(self.input_A)
            fake_B = self.netG_A(real_A)
            self.rec_A = self.netG_B(fake_B).data
            self.fake_B = fake_B.data

            real_B = Variable(self.input_B)
            fake_A = self.netG_B(real_B)
            self.rec_B = self.netG_A(fake_A).data
            self.fake_A = fake_A.data

    # get image paths
    def get_image_paths(self):
        """Return the file paths of the current source-domain batch."""
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        """Shared discriminator update: real -> True, fake -> False.

        ``fake`` is detached so no gradient reaches the generator.
        Returns the combined (halved) discriminator loss.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Update D_A on real_B vs a pooled fake_B sample."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
        self.loss_D_A = loss_D_A.item()

    def backward_D_B(self):
        """Update D_B on real_A vs a pooled fake_A sample."""
        fake_A = self.fake_A_pool.query(self.fake_A)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.item()

    def backward_G(self):
        """Compute and backpropagate the full generator objective:
        GAN terms + forward/backward cycle losses + optional identity losses."""
        lambda_idt = self.p.identity
        lambda_A = self.p.lambda_A
        lambda_B = self.p.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_B)
            loss_idt_A = self.criterionIdt(idt_A,
                                           self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(idt_B,
                                           self.real_A) * lambda_A * lambda_idt
            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.item()
            self.loss_idt_B = loss_idt_B.item()

        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True)

        # GAN loss D_B(G_B(B))
        fake_A = self.netG_B(self.real_B)
        pred_fake = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake, True)

        # Forward cycle loss
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A
        # print("loss_cycle_A  ",loss_cycle_A.grad)

        # Backward cycle loss
        rec_B = self.netG_A(fake_A)
        loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B

        # combined loss
        loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
        loss_G.backward()

        # cache tensors for the discriminator steps and for visualisation
        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        self.rec_A = rec_A.data
        self.rec_B = rec_B.data

        self.loss_G_A = loss_G_A.item()
        self.loss_G_B = loss_G_B.item()

        self.loss_cycle_A = loss_cycle_A.item()
        self.loss_cycle_B = loss_cycle_B.item()

    # def get_gradient(self, img):
    #     # print("get_gradient ",img)
    #     # return np.gradient(np.array(img))
    #     return np.gradient(img)

    def optimize_parameters(self):
        """One training iteration: update G_A+G_B, then D_A, then D_B."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        """Return an OrderedDict of the latest scalar losses for logging."""
        ret_errors = OrderedDict([('D_A', self.loss_D_A),
                                  ('G_A', self.loss_G_A),
                                  ('Cyc_A', self.loss_cycle_A),
                                  ('D_B', self.loss_D_B),
                                  ('G_B', self.loss_G_B),
                                  ('Cyc_B', self.loss_cycle_B)])

        if self.p.identity > 0.0:
            ret_errors['idt_A'] = self.loss_idt_A
            ret_errors['idt_B'] = self.loss_idt_B

        return ret_errors

    def get_current_visuals(self):
        """Return an OrderedDict of displayable images for the current batch
        (reals, fakes, reconstructions, and identities when enabled)."""
        real_A = tensor2im(self.input_A)
        fake_B = tensor2im(self.fake_B)
        rec_A = tensor2im(self.rec_A)
        real_B = tensor2im(self.input_B)
        fake_A = tensor2im(self.fake_A)
        rec_B = tensor2im(self.rec_B)
        ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                   ('rec_A', rec_A), ('real_B', real_B),
                                   ('fake_A', fake_A), ('rec_B', rec_B)])
        if self.isTrain and self.p.identity > 0.0:
            ret_visuals['idt_A'] = tensor2im(self.idt_A)
            ret_visuals['idt_B'] = tensor2im(self.idt_B)

        return ret_visuals

    def save(self, label):
        """Persist all four networks under the given checkpoint label."""
        self.save_model(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_model(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_model(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_model(self.netD_B, 'D_B', label, self.gpu_ids)
コード例 #19
0
class SimpleGAN(Trainer):
    """Conditional LSGAN trainer: a UNet generator maps X to Y, an N-layer
    discriminator separates real from generated Y.

    Uses MSELoss as the LSGAN criterion, an image replay pool for the
    discriminator, and a Visualizer for loss/image plotting.
    """

    def __init__(self, parsed_args, parsed_groups):
        super().__init__(**parsed_groups['trainer arguments'])

        self.gen = UNetGenerator(**parsed_groups['generator arguments'])
        self.disc = NLayerDiscriminator(
            **parsed_groups['discriminator arguments'])

        init_weights(self.gen)
        init_weights(self.disc)

        # target labels for the GAN criterion
        self.real_label = torch.tensor(1.0)
        self.fake_label = torch.tensor(0.0)

        self.crit = torch.nn.MSELoss()  # LSGAN
        self.image_pool = ImagePool(parsed_args.pool_size,
                                    parsed_args.replay_prob)

        # validation batch index used to pick which images to visualise
        self.sel_ind = 0
        # map the network's [-1, 1] range back to [0, 255] pixel values
        self.un_normalize = lambda x: 255. * (1 + x.clamp(min=-1, max=1)) / 2.

        self.parsed_args = parsed_args
        self.n_vis = parsed_args.n_vis
        self.vis = Visualizer()

    def configure_optimizers(self):
        """Create Adam optimizers for D and G plus linearly-decaying LR
        schedulers (flat until ``frac_decay_start`` of n_epochs, then
        decaying to zero)."""
        opt1 = torch.optim.Adam(self.disc.parameters(), lr=self.parsed_args.lr)
        opt2 = torch.optim.Adam(self.gen.parameters(), lr=self.parsed_args.lr)

        N = self.parsed_args.n_epochs
        N_start = int(N * self.parsed_args.frac_decay_start)
        sched_lamb = lambda x: 1.0 - max(0, x - N_start) / (N - N_start)

        sched1 = torch.optim.lr_scheduler.LambdaLR(opt1, lr_lambda=sched_lamb)
        sched2 = torch.optim.lr_scheduler.LambdaLR(opt2, lr_lambda=sched_lamb)

        return opt1, opt2, sched1, sched2

    def on_fit_start(self):
        """Open the visualiser when training begins."""
        self.vis.start()

    def on_fit_end(self):
        """Close the visualiser when training ends."""
        self.vis.stop()

    def _shared_step(self, batch, save_img, is_train):
        """Core logic shared by training and validation steps.

        Generates Y_fake from X, computes the discriminator loss on real
        vs (pooled) fake predictions and the generator loss on the fresh
        fake, stepping through ``res.step`` when training. In validation a
        reconstruction error is also recorded. Optionally stashes
        un-normalised images for visualisation.
        """
        res = Result()

        X, Y_real = batch
        Y_fake = self.gen(X)

        if is_train:
            # replay pool decouples D updates from the current G output
            Y_pool = self.image_pool.query(Y_fake.detach())
        else:
            Y_pool = Y_fake
            res.recon_error = self.crit(Y_real, Y_fake)

        real_predict = self.disc(Y_real)
        fake_predict = self.disc(Y_pool)

        real_label = self.real_label.expand_as(real_predict)
        fake_label = self.fake_label.expand_as(fake_predict)

        disc_loss = 0.5 * (self.crit(real_predict, real_label) + \
                           self.crit(fake_predict, fake_label))

        if is_train:
            res.step(disc_loss)
        res.disc_loss = disc_loss

        # generator wants the discriminator to label its output as real
        gen_predict = self.disc(Y_fake)
        gen_loss = self.crit(gen_predict, real_label)

        if is_train:
            res.step(gen_loss)
        res.gen_loss = gen_loss

        if save_img:
            res.img = [
                self.un_normalize(X[:self.n_vis]),
                self.un_normalize(Y_fake[:self.n_vis]),
                self.un_normalize(Y_real[:self.n_vis])
            ]
        return res

    def training_step(self, batch, batch_idx):
        """One optimisation step; saves images from the first batch."""
        res = self._shared_step(batch,
                                save_img=(batch_idx == 0),
                                is_train=True)
        return res

    def validation_step(self, batch, batch_idx):
        """One validation step; saves images from the selected batch."""
        res = self._shared_step(batch,
                                save_img=(batch_idx == self.sel_ind),
                                is_train=False)
        return res

    def _shared_end(self, result_outputs, is_train):
        """Plot the epoch's mean G/D losses and a collated image strip."""
        phase = 'Train' if is_train else 'Valid'

        self.vis.plot('Gen. Loss', phase + ' Loss', self.current_epoch,
                      torch.mean(torch.stack(result_outputs.gen_loss)))
        self.vis.plot('Disc. Loss', phase + ' Loss', self.current_epoch,
                      torch.mean(torch.stack(result_outputs.disc_loss)))

        # NOTE(review): only img[0] is collated here — confirm that showing
        # a single saved image set per epoch is the intended behaviour.
        collated_imgs = torch.cat([*torch.cat(result_outputs.img[0], dim=3)],
                                  dim=1)
        self.vis.show_image(phase + ' Images', collated_imgs)

    def training_epoch_end(self, training_outputs):
        """Epoch-end hook for training plots."""
        self._shared_end(training_outputs, is_train=True)

    def validation_epoch_end(self, validation_outputs):
        """Epoch-end hook for validation plots; re-draws the batch index
        used for the next epoch's visualisation and returns the mean
        reconstruction error as the validation metric."""
        self._shared_end(validation_outputs, is_train=False)
        self.sel_ind = random.randint(0, len(self.validation_loader) - 1)

        return torch.mean(torch.stack(validation_outputs.recon_error))
コード例 #20
0
class Model:
    def initialize(self, cfg):
        """Set up the model from a configuration dict.

        Chooses the compute device, builds the backbone (netB), head (netH)
        and — when training with domain adaptation — a discriminator (netD),
        initialises their weights, and in training mode also creates losses,
        optimizers, LR schedulers, feature pools and the output directory.
        Loads checkpoints when testing or resuming.

        Expected cfg keys (subset): GPU_IDS, ARCHI, TRAIN, TEST, RESUME,
        USE_DA, DA_LAYER, LR, BETA1, BETA2, MOMENTUM, WEIGHT_DECAY,
        POOL_SIZE, OUTPUT_PATH, CKPT_PATH, EPOCH_LOAD.
        """
        self.cfg = cfg

        ## set devices
        if cfg['GPU_IDS']:
            assert (torch.cuda.is_available())
            self.device = torch.device('cuda:{}'.format(cfg['GPU_IDS'][0]))
            torch.backends.cudnn.benchmark = True
            print('Using %d GPUs' % len(cfg['GPU_IDS']))
        else:
            self.device = torch.device('cpu')

        # define network
        if cfg['ARCHI'] == 'alexnet':
            self.netB = networks.netB_alexnet()
            self.netH = networks.netH_alexnet()
            if self.cfg['USE_DA'] and self.cfg['TRAIN']:
                self.netD = networks.netD_alexnet(self.cfg['DA_LAYER'])
        elif cfg['ARCHI'] == 'vgg16':
            # vgg16 is not supported yet; the lines below are unreachable.
            raise NotImplementedError
            self.netB = networks.netB_vgg16()
            self.netH = networks.netH_vgg16()
            if self.cfg['USE_DA'] and self.cfg['TRAIN']:
                self.netD = netD_vgg16(self.cfg['DA_LAYER'])
        elif 'resnet' in cfg['ARCHI']:
            self.netB = networks.netB_resnet34()
            self.netH = networks.netH_resnet34()
            if self.cfg['USE_DA'] and self.cfg['TRAIN']:
                self.netD = networks.netD_resnet(self.cfg['DA_LAYER'])
        else:
            raise ValueError('Un-supported network')

        ## initialize network param.
        self.netB = networks.init_net(self.netB, cfg['GPU_IDS'], 'xavier')
        self.netH = networks.init_net(self.netH, cfg['GPU_IDS'], 'xavier')

        if self.cfg['USE_DA'] and self.cfg['TRAIN']:
            self.netD = networks.init_net(self.netD, cfg['GPU_IDS'], 'xavier')

        # loss, optimizer, and scherduler
        if cfg['TRAIN']:
            self.total_steps = 0
            ## Output path (timestamped run directory)
            self.save_dir = os.path.join(
                cfg['OUTPUT_PATH'], cfg['ARCHI'],
                datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
            if not os.path.isdir(self.save_dir):
                os.makedirs(self.save_dir)
            # self.logger = Logger(self.save_dir)

            ## model names
            self.model_names = ['netB', 'netH']
            ## loss
            self.criterionGAN = networks.GANLoss().to(self.device)
            self.criterionDepth1 = torch.nn.MSELoss().to(self.device)
            self.criterionNorm = torch.nn.CosineEmbeddingLoss().to(self.device)
            self.criterionEdge = torch.nn.BCELoss().to(self.device)

            ## optimizers
            self.lr = cfg['LR']
            self.optimizers = []
            self.optimizer_B = torch.optim.Adam(self.netB.parameters(),
                                                lr=cfg['LR'],
                                                betas=(cfg['BETA1'],
                                                       cfg['BETA2']))
            self.optimizer_H = torch.optim.Adam(self.netH.parameters(),
                                                lr=cfg['LR'],
                                                betas=(cfg['BETA1'],
                                                       cfg['BETA2']))
            self.optimizers.append(self.optimizer_B)
            self.optimizers.append(self.optimizer_H)
            if cfg['USE_DA']:
                # feature replay pools for adversarial domain adaptation
                self.real_pool = ImagePool(cfg['POOL_SIZE'])
                self.syn_pool = ImagePool(cfg['POOL_SIZE'])
                self.model_names.append('netD')
                ## use SGD for discriminator
                self.optimizer_D = torch.optim.SGD(
                    self.netD.parameters(),
                    lr=cfg['LR'],
                    momentum=cfg['MOMENTUM'],
                    weight_decay=cfg['WEIGHT_DECAY'])
                self.optimizers.append(self.optimizer_D)
            ## LR scheduler
            self.schedulers = [
                networks.get_scheduler(optimizer, cfg)
                for optimizer in self.optimizers
            ]
        else:
            ## testing
            self.model_names = ['netB', 'netH']
            self.criterionDepth1 = torch.nn.MSELoss().to(self.device)
            self.criterionNorm = torch.nn.CosineEmbeddingLoss(
                reduction='none').to(self.device)

        self.load_dir = os.path.join(cfg['CKPT_PATH'])
        self.criterionNorm_eval = torch.nn.CosineEmbeddingLoss(
            reduction='none').to(self.device)

        if cfg['TEST'] or cfg['RESUME']:
            self.load_networks(cfg['EPOCH_LOAD'])

    def set_input(self, inputs):
        if self.cfg['GRAY']:
            _ch = np.random.randint(3)
            _syn = inputs['syn']['color'][:, _ch, :, :]
            self.input_syn_color = torch.stack((_syn, _syn, _syn),
                                               dim=1).to(self.device)
        else:
            self.input_syn_color = inputs['syn']['color'].to(self.device)
        self.input_syn_dep = inputs['syn']['depth'].to(self.device)
        self.input_syn_edge = inputs['syn']['edge'].to(self.device)
        self.input_syn_edge_count = inputs['syn']['edge_pix'].to(self.device)
        self.input_syn_norm = inputs['syn']['normal'].to(self.device)
        if self.cfg['USE_DA']:
            if self.cfg['GRAY']:
                _ch = np.random.randint(3)
                _real = inputs['real'][0][:, _ch, :, :]
                self.input_real_color = torch.stack((_real, _real, _real),
                                                    dim=1).to(self.device)
            else:
                self.input_real_color = inputs['real'][0].to(self.device)

    def forward(self):
        self.feat_syn = self.netB(self.input_syn_color)
        # TODO: make it compatible with other networks in addition to ResNet
        self.head_pred = self.netH(self.feat_syn['layer1'],
                                   self.feat_syn['layer2'],
                                   self.feat_syn['layer3'],
                                   self.feat_syn['layer4'])
        if self.cfg['USE_DA'] and self.cfg['TRAIN']:
            self.feat_real = self.netB(self.input_real_color)
            self.pred_D_real = self.netD(self.feat_real[self.cfg['DA_LAYER']])
            self.pred_D_syn = self.netD(self.feat_syn[self.cfg['DA_LAYER']])
        return self.head_pred

    def backward_BH(self):
        """Compute the multi-task loss (depth + surface normal + edge, plus
        an adversarial term under domain adaptation) on the synthetic batch
        and backpropagate it through the backbone and head."""
        ## forward to compute prediction
        # TODO: replace this with self.head_pred to avoid computation twice
        self.task_pred = self.netH(self.feat_syn['layer1'],
                                   self.feat_syn['layer2'],
                                   self.feat_syn['layer3'],
                                   self.feat_syn['layer4'])

        # depth
        # _n = batch * H * W (the channel dim is excluded)
        depth_diff = self.task_pred['depth'] - self.input_syn_dep
        _n = self.task_pred['depth'].size(0) * self.task_pred['depth'].size(
            2) * self.task_pred['depth'].size(3)
        # loss_depth2 = (sum of diffs)^2 / n^2, built with in-place ops.
        # NOTE(review): MSE minus this squared-mean term resembles the
        # scale-invariant depth loss of Eigen et al. — confirm intent.
        loss_depth2 = depth_diff.sum().pow_(2).div_(_n).div_(_n)
        loss_depth1 = self.criterionDepth1(self.task_pred['depth'],
                                           self.input_syn_dep)
        self.loss_dep = self.cfg['DEP_WEIGHT'] * (loss_depth1 -
                                                  loss_depth2) * 0.5

        # surface normal
        # flatten to (N*H*W, C); ground truth is mapped from [0, 255] to
        # [-1, 1] and compared against unit-normalised predictions with a
        # cosine embedding loss (target label 1 == "similar").
        ch = self.task_pred['norm'].size(1)
        _pred = self.task_pred['norm'].permute(0, 2, 3,
                                               1).contiguous().view(-1, ch)
        _gt = self.input_syn_norm.permute(0, 2, 3, 1).contiguous().view(-1, ch)
        _gt = (_gt / 127.5) - 1
        _pred = torch.nn.functional.normalize(_pred, dim=1)
        # store the normalised prediction back in image layout, rescaled to
        # the [0, 255] range used elsewhere
        self.task_pred['norm'] = _pred.view(self.task_pred['norm'].size(0),
                                            self.task_pred['norm'].size(2),
                                            self.task_pred['norm'].size(3),
                                            3).permute(0, 3, 1, 2)
        self.task_pred['norm'] = (self.task_pred['norm'] + 1) * 127.5
        cos_label = torch.ones(_gt.size(0)).to(self.device)
        self.loss_norm = self.cfg['NORM_WEIGHT'] * self.criterionNorm(
            _pred, _gt, cos_label)

        # # edge
        self.loss_edge = self.cfg['EDGE_WEIGHT'] * self.criterionEdge(
            self.task_pred['edge'], self.input_syn_edge)

        ## combined loss
        loss = self.loss_dep + self.loss_norm + self.loss_edge

        if self.cfg['USE_DA']:
            # generator-side adversarial term: push netD to label the
            # synthetic features as real
            pred_syn = self.netD(self.feat_syn[self.cfg['DA_LAYER']])
            self.loss_DA = self.criterionGAN(pred_syn, True)
            loss += self.loss_DA * self.cfg['DA_WEIGHT']

        loss.backward()

    def backward_D(self):
        ## Synthetic
        # stop backprop to netB by detaching
        _feat_s = self.syn_pool.query(
            self.feat_syn[self.cfg['DA_LAYER']].detach().cpu())
        pred_syn = self.netD(_feat_s.to(self.device))
        self.loss_D_syn = self.criterionGAN(pred_syn, False)

        ## Real
        _feat_r = self.real_pool.query(
            self.feat_real[self.cfg['DA_LAYER']].detach().cpu())
        pred_real = self.netD(_feat_r.to(self.device))
        self.loss_D_real = self.criterionGAN(pred_real, True)

        ## Combined
        self.loss_D = (self.loss_D_syn + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def optimize(self):
        """One training iteration: optional D step (domain adaptation), then B+H."""
        self.total_steps += 1
        self.train_mode()
        self.forward()

        if self.cfg['USE_DA']:
            # Discriminator step on real/synthetic features: freeze the task
            # networks so only netD receives gradients.
            self.set_requires_grad(self.netD, True)
            self.set_requires_grad([self.netB, self.netH], False)
            self.optimizer_D.zero_grad()
            self.backward_D()
            self.optimizer_D.step()
            # Restore gradients for the task networks and freeze D again
            # before the synthetic-data update below.
            self.set_requires_grad([self.netB, self.netH], True)
            self.set_requires_grad(self.netD, False)

        self.optimizer_B.zero_grad()
        self.optimizer_H.zero_grad()
        self.backward_BH()
        self.optimizer_B.step()
        self.optimizer_H.step()

    def train_mode(self):
        """Put every trainable network (netD too, when DA is on) into train mode."""
        nets = [self.netB, self.netH]
        if self.cfg['USE_DA']:
            nets.append(self.netD)
        for net in nets:
            net.train()

    # make models eval mode during test time
    def eval_mode(self):
        """Switch every network (netD too, when DA is on) to evaluation mode."""
        nets = [self.netB, self.netH]
        if self.cfg['USE_DA']:
            nets.append(self.netD)
        for net in nets:
            net.eval()

    def angle_error_ratio(self, angle_degree, base_angle_degree):
        """Fraction of angular errors strictly below ``base_angle_degree``.

        Args:
            angle_degree: tensor of per-pixel angle errors in degrees;
                either 1-D (all pixels) or 2-D (batch, pixels_per_image).
            base_angle_degree: scalar threshold in degrees.

        Returns:
            (ratio, logic_map): the under-threshold ratio (scalar for 1-D
            input, per-image tensor for 2-D input) and the boolean mask.
        """
        # True where threshold > error, i.e. the error is under the threshold.
        logic_map = torch.gt(
            base_angle_degree *
            torch.ones_like(angle_degree, device=self.device), angle_degree)
        if len(logic_map.size()) == 1:
            num_pixels = torch.sum(logic_map).float()
        else:
            # Per-image counts when the input is (batch, pixels).
            num_pixels = torch.sum(logic_map, dim=1).float()
        # NOTE(review): in the 2-D branch the denominator is still the TOTAL
        # element count (batch * pixels), not the per-image pixel count, so
        # each per-image ratio is scaled down by the batch size — confirm
        # this is intended (see the image_score use in test()).
        ratio = torch.div(
            num_pixels,
            torch.tensor(angle_degree.nelement(),
                         device=self.device,
                         dtype=torch.float64))
        return ratio, logic_map

    def normal_angle(self):
        """Per-pixel angular error (degrees) between predicted and GT normals.

        Flattens the (N, C, H, W) prediction and ground truth into per-pixel
        vectors, maps the ground truth from its [0, 255] encoding back to
        [-1, 1], normalizes both, and measures the angle between them.

        Returns:
            1-D tensor of angles in degrees, one entry per pixel.
        """
        import math  # local import: exact pi for the radian->degree conversion

        ch = self.head_pred['norm'].size(1)
        _pred = self.head_pred['norm'].permute(0, 2, 3,
                                               1).contiguous().view(-1, ch)
        _gt = self.input_syn_norm.permute(0, 2, 3, 1).contiguous().view(-1, ch)
        # Undo the [0, 255] encoding of the stored normal map.
        _gt = (_gt / 127.5) - 1
        _pred = torch.nn.functional.normalize(_pred, dim=1)
        _gt = torch.nn.functional.normalize(_gt, dim=1)
        cos_label = torch.ones(_gt.size(0)).to(self.device)
        # criterionNorm_eval is expected to return 1 - cos(pred, gt) per pixel
        # (cosine-embedding style) -- TODO confirm against its definition.
        norm_diff = self.criterionNorm_eval(_pred, _gt, cos_label)

        # Recover the cosine and clamp into acos's valid domain [-1, 1].
        cos_val = 1 - norm_diff
        cos_val = torch.max(cos_val,
                            -torch.ones_like(cos_val, device=self.device))
        cos_val = torch.min(cos_val,
                            torch.ones_like(cos_val, device=self.device))

        angle_rad = torch.acos(cos_val)
        # Bug fix: use math.pi instead of the 3.14 approximation (the old
        # constant inflated every angle by ~0.05%).
        angle_degree = angle_rad / math.pi * 180
        return angle_degree

    def test(self):
        """Run one evaluation pass and return surface-normal angle metrics.

        Returns:
            dict with the batch size, raw per-pixel angle errors, per-image
            mean errors, a per-image score, and the fraction of pixels under
            11.25 / 22.5 / 30 / 45 degrees — all tensors as numpy arrays.
        """
        self.eval_mode()
        with torch.no_grad():
            self.forward()

            angle_degree = self.normal_angle()
            # ratio metrics
            ratio_11, _ = self.angle_error_ratio(angle_degree, 11.25)
            ratio_22, _ = self.angle_error_ratio(angle_degree, 22.5)
            ratio_30, _ = self.angle_error_ratio(angle_degree, 30.0)
            ratio_45, _ = self.angle_error_ratio(angle_degree, 45.0)
            # image-wise metrics
            batch_size = self.head_pred['norm'].size(0)
            # TODO double check if it's image-wise
            batch_angles = angle_degree.view(batch_size, -1)
            image_mean = torch.mean(batch_angles, dim=1)
            # NOTE(review): angle_error_ratio divides per-image counts by the
            # TOTAL element count, so image_score is scaled by 1/batch_size —
            # confirm this is intended.
            image_score, _ = self.angle_error_ratio(batch_angles, 45.0)

            return {
                'batch_size': batch_size,
                'pixel_error': angle_degree.cpu().detach().numpy(),
                'image_mean': image_mean.cpu().detach().numpy(),
                'image_score': image_score.cpu().detach().numpy(),
                'ratio_11': ratio_11.cpu().detach().numpy(),
                'ratio_22': ratio_22.cpu().detach().numpy(),
                'ratio_30': ratio_30.cpu().detach().numpy(),
                'ratio_45': ratio_45.cpu().detach().numpy()
            }

    def out_logic_map(self, epoch_num, img_num):
        """Save a side-by-side image of GT normals, predicted normals, and the
        per-pixel mask of predictions within 11.25 degrees of the GT.

        Args:
            epoch_num: epoch index, used for the output sub-directory name.
            img_num: image index, used for the output filename.
        """
        import math  # local import: exact pi for the radian->degree conversion

        self.eval_mode()
        with torch.no_grad():
            self.forward()

        # Flatten prediction and GT to per-pixel unit vectors (same pipeline
        # as normal_angle()).
        ch = self.head_pred['norm'].size(1)
        _pred = self.head_pred['norm'].permute(0, 2, 3,
                                               1).contiguous().view(-1, ch)
        _gt = self.input_syn_norm.permute(0, 2, 3, 1).contiguous().view(-1, ch)
        _gt = (_gt / 127.5) - 1
        _pred = torch.nn.functional.normalize(_pred, dim=1)
        _gt = torch.nn.functional.normalize(_gt, dim=1)
        cos_label = torch.ones(_gt.size(0)).to(self.device)
        norm_diff = self.criterionNorm_eval(_pred, _gt, cos_label)

        # Recover the cosine and clamp into acos's valid domain [-1, 1].
        cos_val = 1 - norm_diff
        cos_val = torch.max(cos_val,
                            -torch.ones_like(cos_val, device=self.device))
        cos_val = torch.min(cos_val,
                            torch.ones_like(cos_val, device=self.device))

        angle_rad = torch.acos(cos_val)
        # Bug fix: use math.pi instead of the 3.14 approximation.
        angle_degree = angle_rad / math.pi * 180
        # Boolean mask of pixels within 11.25 degrees.
        ratio_11, c = self.angle_error_ratio(angle_degree, 11.25)

        # Replicate the boolean mask into a 3-channel image for display.
        good_pixel_img = torch.cat(
            (c.view(-1, 1), c.view(-1, 1), c.view(-1, 1)),
            1).view(self.head_pred['norm'].size(0),
                    self.head_pred['norm'].size(2),
                    self.head_pred['norm'].size(3), 3).permute(0, 3, 1, 2)

        # Re-encode the normalized prediction back to [0, 255] for display.
        self.head_pred['norm'] = _pred.view(self.head_pred['norm'].size(0),
                                            self.head_pred['norm'].size(2),
                                            self.head_pred['norm'].size(3),
                                            3).permute(0, 3, 1, 2)
        self.head_pred['norm'] = (self.head_pred['norm'] + 1) * 127.5
        vis_norm = torch.cat((self.input_syn_norm, self.head_pred['norm'],
                              good_pixel_img.float() * 255),
                             dim=0)
        map_path = '%s/ep%d' % (self.cfg['VIS_PATH'], epoch_num)
        if not os.path.isdir(map_path):
            os.makedirs(map_path)
        torchvision.utils.save_image(vis_norm.detach(),
                                     '%s/%d_norm.jpg' % (map_path, img_num),
                                     nrow=1,
                                     normalize=True)

    # update learning rate (called once every epoch)
    def update_learning_rate(self):
        """Step every LR scheduler, then cache and print the current LR."""
        for scheduler in self.schedulers:
            scheduler.step()
        # Bug fix: `self.cfgimizers` was a typo for `self.optimizers` and
        # raised AttributeError at runtime.
        self.lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % self.lr)

    #  return visualization images. train.py will save the images.
    def visualize_pred(self, ep=0):
        """Every VIS_FREQ steps, dump color / normal / depth (and, with DA,
        real-domain color) image grids under ``<save_dir>/vis``.

        Args:
            ep: epoch index, embedded in the output filenames.
        """
        vis_dir = os.path.join(self.save_dir, 'vis')
        if not os.path.isdir(vis_dir):
            os.makedirs(vis_dir)
        if self.total_steps % self.cfg['VIS_FREQ'] == 0:
            # Show at most 10 examples per grid.
            num_pic = min(10, self.task_pred['norm'].size(0))
            torchvision.utils.save_image(self.input_syn_color[0:num_pic].cpu(),
                                         '%s/ep_%d_iter_%d_color.jpg' %
                                         (vis_dir, ep, self.total_steps),
                                         nrow=num_pic,
                                         normalize=True)
            # Ground-truth normals stacked above the predictions.
            vis_norm = torch.cat((self.input_syn_norm[0:num_pic],
                                  self.task_pred['norm'][0:num_pic]),
                                 dim=0)
            torchvision.utils.save_image(vis_norm.detach(),
                                         '%s/ep_%d_iter_%d_norm.jpg' %
                                         (vis_dir, ep, self.total_steps),
                                         nrow=num_pic,
                                         normalize=True)
            # Ground-truth depth stacked above the predictions.
            vis_depth = torch.cat((self.input_syn_dep[0:num_pic],
                                   self.task_pred['depth'][0:num_pic]),
                                  dim=0)
            torchvision.utils.save_image(vis_depth.detach(),
                                         '%s/ep_%d_iter_%d_depth.jpg' %
                                         (vis_dir, ep, self.total_steps),
                                         nrow=num_pic,
                                         normalize=True)
            # TODO Jason: visualization
            # edge_vis = torch.nn.functional.sigmoid(self.task_pred['edge'])
            # vis_edge = torch.cat((self.input_syn_edge[0:num_pic], edge_vis[0:num_pic]), dim=0)
            # torchvision.utils.save_image(vis_edge.detach(),
            #                             '%s/ep_%d_iter_%d_edge.jpg' % (vis_dir,ep,self.total_steps),
            #                             nrow=num_pic, normalize=False)
            if self.cfg['USE_DA']:
                torchvision.utils.save_image(
                    self.input_real_color[0:num_pic].cpu(),
                    '%s/ep_%d_iter_%d_real.jpg' %
                    (vis_dir, ep, self.total_steps),
                    nrow=num_pic,
                    normalize=True)
            print('==> Saved epoch %d total step %d visualization to %s' %
                  (ep, self.total_steps, vis_dir))

    # print on screen, log into tensorboard
    def print_n_log_losses(self, ep=0):
        """Every PRINT_FREQ steps, print the current losses to stdout."""
        if self.total_steps % self.cfg['PRINT_FREQ'] != 0:
            return
        print('\nEpoch: %d  Total_step: %d  LR: %f' %
              (ep, self.total_steps, self.lr))
        # print('Train on tasks: Loss_dep: %.4f   | Loss_edge: %.4f   | Loss_norm: %.4f'
        #       % (self.loss_dep, self.loss_edge, self.loss_norm))
        print('Train on tasks: Loss_dep: %.4f | Loss_norm: %.4f' %
              (self.loss_dep, self.loss_norm))
        scalars = {
            'loss_dep': self.loss_dep,
            'loss_norm': self.loss_norm  # ,
            # 'loss_edge': self.loss_edge
        }
        if self.cfg['USE_DA']:
            print(
                'Train for DA:   Loss_D_syn: %.4f | Loss_D_real: %.4f | Loss_DA: %.4f'
                % (self.loss_D_syn, self.loss_D_real, self.loss_DA))
            scalars['loss_D_syn'] = self.loss_D_syn
            scalars['loss_D_real'] = self.loss_D_real
            scalars['loss_DA'] = self.loss_DA

        # for tag, value in scalars.items():
        #     self.logger.scalar_summary(tag, value, self.total_steps)

    # save models to the disk
    def save_networks(self, which_epoch):
        """Save the state dict of every network named in ``self.model_names``.

        Each net is moved to CPU for serialization and moved back to the
        training device afterwards when CUDA is available.

        Args:
            which_epoch: epoch tag embedded in the checkpoint filename.
        """
        for name in self.model_names:
            save_filename = '%s_ep%s.pth' % (name, which_epoch)
            save_path = os.path.join(self.save_dir, save_filename)
            net = getattr(self, name)
            if isinstance(net, torch.nn.DataParallel):
                # Unwrap DataParallel so the checkpoint loads without it.
                torch.save(net.module.cpu().state_dict(), save_path)
            else:
                torch.save(net.cpu().state_dict(), save_path)
            print('==> Saved networks to %s' % save_path)
            # Bug fix: is_available is a function; the bare attribute is
            # always truthy, so the old code tried to move nets to CUDA even
            # on CPU-only machines.
            if torch.cuda.is_available():
                net.cuda(self.device)

    # load models from the disk
    def load_networks(self, which_epoch):
        """Load state dicts for every network named in ``self.model_names``.

        Also redirects ``save_dir`` to ``load_dir`` so later saves land next
        to the loaded checkpoint (original behavior, kept as-is).

        Args:
            which_epoch: epoch tag in the checkpoint filename; the string
                'None' aborts the process.
        """
        self.save_dir = self.load_dir
        print('loading networks...')
        if which_epoch == 'None':
            print('epoch is None')
            exit()
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_ep%s.pth' % (name, which_epoch)
                load_path = os.path.join(self.load_dir, load_filename)
                net = getattr(self, name)
                # Bug fix: the load calls were previously nested inside this
                # isinstance check, so plain (non-DataParallel) networks were
                # silently never restored. Unwrap DataParallel if present,
                # then always load.
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path,
                                        map_location=str(self.device))
                net.load_state_dict(state_dict)

    # Toggle requires_grad on one network or a list of networks; passing
    # False freezes them so the backward pass skips their parameters.
    def set_requires_grad(self, nets, requires_grad=False):
        """Enable/disable gradient tracking for the given network(s)."""
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is None:
                continue  # tolerate placeholder entries
            for param in net.parameters():
                param.requires_grad = requires_grad
コード例 #21
0
def train():
    """Train the mask-guided CycleGAN until interrupted.

    Resumes from an existing checkpoint directory when ``FLAGS.load_model``
    is set; otherwise creates a fresh timestamped one. Checkpoints every
    3000 steps and once more on shutdown.
    """
    if FLAGS.load_model is not None:
        # Bug fix: the original used str.lstrip("checkpoints/"), which strips
        # any leading characters from the SET {c,h,e,k,p,o,i,n,t,s,/} and so
        # mangles model names starting with one of them. Strip the literal
        # prefix instead.
        model_name = FLAGS.load_model
        if model_name.startswith("checkpoints/"):
            model_name = model_name[len("checkpoints/"):]
        checkpoints_dir = "checkpoints/" + model_name
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        try:
            os.makedirs(checkpoints_dir)
        except os.error:
            pass

    graph = tf.Graph()
    # add segmentation work-flow yw3025
    segementation = Segementation(None)
    with graph.as_default():
        cycle_gan = CycleGAN(
            X_train_file=FLAGS.X,
            Y_train_file=FLAGS.Y,
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            lambda2=FLAGS.lambda2,
            learning_rate=FLAGS.learning_rate,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf,
        )
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x, y, x = cycle_gan.model(
        )
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            # Restore the latest checkpoint; the step counter is parsed from
            # the checkpoint filename ("model.ckpt-<step>.meta").
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            step = int(meta_graph_path.split("-")[2].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # get previously generated images
                fake_y_val, fake_x_val, y_val, x_val = sess.run(
                    [fake_y, fake_x, y, x])
                # add segmentation work-flow yw3025
                mask_y = segementation.get_result_railway(y_val)
                mask_x = segementation.get_result_railway(x_val)
                mask_fake_x = segementation.get_result_railway(fake_x_val)
                mask_fake_y = segementation.get_result_railway(fake_y_val)
                # train yw3025
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [
                            optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                            summary_op
                        ],
                        feed_dict={
                            cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val),
                            cycle_gan.g_mask_x: mask_x,
                            cycle_gan.g_mask_y: mask_y,
                            cycle_gan.g_mask_fake_y: mask_fake_y,
                            cycle_gan.g_mask_fake_x: mask_fake_x,
                            cycle_gan.mask_x: mask_x,
                            cycle_gan.mask_y: mask_y
                        }))

                train_writer.add_summary(summary, step)
                train_writer.flush()

                if step % 100 == 0:
                    logging.info('-----------Step %d:-------------' % step)
                    logging.info('  G_loss   : {}'.format(G_loss_val))
                    logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info('  F_loss   : {}'.format(F_loss_val))
                    logging.info('  D_X_loss : {}'.format(D_X_loss_val))

                if step % 3000 == 0:
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            # Always checkpoint on the way out, then stop the queue runners.
            save_path = saver.save(sess,
                                   checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            logging.info("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
コード例 #22
0
            saver_ae_x.restore(sess, aex_latest_ckpt)
            saver_ae_y.restore(sess, aey_latest_ckpt)
        if FLAGS.whole_ckpt_dir != 'no':
            whole_latest_ckpt = tf.train.latest_checkpoint(
                FLAGS.whole_ckpt_dir)
            saver.restore(sess, whole_latest_ckpt)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        #GAN part
        if not FLAGS.BEGAN:
            np.set_printoptions(precision=3)
            try:
                step = 0
                fake_Y_pool = ImagePool(FLAGS.pool_size)
                fake_X_pool = ImagePool(FLAGS.pool_size)

                concat_img_list = []

                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

                x_last_test_predict_list = []
                for last_i in range(len(testimage_x_list)):
                    x_last_test_predict_list.append(0)
                y_last_test_predict_list = []
                for last_i in range(len(testimage_x_list)):
                    y_last_test_predict_list.append(0)

                while not coord.should_stop():
コード例 #23
0
ファイル: cyclechain.py プロジェクト: sjtuzq/Cycle_Dynamics
class CycleGANModel():
    """CycleGAN augmented with a forward-dynamics model (netF_A).

    Besides the usual A<->B cycles, a third "dynamics" cycle advances the
    translated state fake_At0 with (state, action) through netF_A and maps
    the result back, enforcing consistency with the true next state
    real_Bt1 (enabled via ``opt.dynamic``). Depends on project helpers
    GModel / Fmodel / DModel / GANLoss / ImagePool / Fdata defined
    elsewhere; CUDA is required.
    """

    def __init__(self, opt):
        """Build generators, discriminators, the dynamics model and optimizers.

        Args:
            opt: options namespace; fields read here are ``dynamic``,
                ``istrain`` and ``loss`` ('l1' or 'l2').
        """
        self.opt = opt
        self.dynamic = opt.dynamic
        self.isTrain = opt.istrain
        # NOTE(review): hard-codes CUDA tensors; CPU-only runs will fail.
        self.Tensor = torch.cuda.FloatTensor

        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = GModel(opt).cuda()
        self.netG_B = GModel(opt).cuda()
        # Forward-dynamics model: predicts the next state from (state, action).
        self.netF_A = Fmodel().cuda()
        self.dataF = Fdata.get_loader()

        if self.isTrain:
            self.netD_A = DModel(opt).cuda()
            self.netD_B = DModel(opt).cuda()

        if self.isTrain:
            self.fake_A_pool = ImagePool(pool_size=128)
            self.fake_B_pool = ImagePool(pool_size=128)
            # define loss functions
            self.criterionGAN = GANLoss(tensor=self.Tensor).cuda()
            if opt.loss == 'l1':
                self.criterionCycle = torch.nn.L1Loss()
                self.criterionIdt = torch.nn.L1Loss()
            elif opt.loss == 'l2':
                self.criterionCycle = torch.nn.MSELoss()
                self.criterionIdt = torch.nn.MSELoss()
            # initialize optimizers
            # self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()))
            # netF_A is included with lr=0.0 so it stays frozen during the
            # GAN update (it is trained separately in train_forward()).
            self.optimizer_G = torch.optim.Adam([{
                'params':
                self.netG_A.parameters(),
                'lr':
                1e-3
            }, {
                'params':
                self.netF_A.parameters(),
                'lr':
                0.0
            }, {
                'params':
                self.netG_B.parameters(),
                'lr':
                1e-3
            }])
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters())
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters())
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                # self.schedulers.append(networks.get_scheduler(optimizer, opt))
                # NOTE(review): schedulers currently holds the optimizers
                # themselves (real scheduler creation is commented out).
                self.schedulers.append(optimizer)

        print('---------- Networks initialized -------------')
        # networks.print_network(self.netG_A)
        # networks.print_network(self.netG_B)
        # if self.isTrain:
        #     networks.print_network(self.netD_A)
        #     networks.print_network(self.netD_B)
        print('-----------------------------------------------')

    def train_forward(self):
        """Pre-train the dynamics model netF_A on (state0, action, state1)."""
        optimizer = torch.optim.Adam(self.netF_A.parameters(), lr=1e-3)
        loss_fn = torch.nn.MSELoss()
        # NOTE(review): the MSE loss above is immediately replaced; only the
        # L1 loss below is actually used.
        loss_fn = torch.nn.L1Loss()
        for epoch in range(100):
            epoch_loss = 0
            for i, item in enumerate(self.dataF):
                state0 = item[0].float().cuda()
                action = item[1].float().cuda()
                state1 = item[2].float().cuda()
                pred = self.netF_A(state0, action)
                loss = loss_fn(pred, state1)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
            print('epoch:{} loss:{:.7f}'.format(epoch,
                                                epoch_loss / len(self.dataF)))
        print('forward model has been trained!')

    def set_input(self, input):
        """Stash one batch: input[0] is domain A; input[1] is (Bt0, action, Bt1)."""
        # AtoB = self.opt.which_direction == 'AtoB'
        # input_A = input['A' if AtoB else 'B']
        # input_B = input['B' if AtoB else 'A']
        self.input_A = input[0]
        self.input_Bt0 = input[1][0]
        self.input_Bt1 = input[1][2]
        self.action = input[1][1]

    def forward(self):
        """Move the stashed batch onto the GPU as float Variables."""
        self.real_A = Variable(self.input_A).float().cuda()
        self.real_Bt0 = Variable(self.input_Bt0).float().cuda()
        self.real_Bt1 = Variable(self.input_Bt1).float().cuda()
        self.action = Variable(self.action).float().cuda()

    def test(self):
        """Run both generator directions for inference and cache the outputs."""
        self.forward()
        # NOTE(review): volatile=True was removed in PyTorch >= 0.4;
        # torch.no_grad() is the modern equivalent — confirm torch version.
        real_A = Variable(self.input_A, volatile=True).float().cuda()
        fake_B = self.netG_A(real_A)
        self.rec_A = self.netG_B(fake_B).data
        self.fake_B = fake_B.data

        real_B = Variable(self.input_Bt0, volatile=True).float().cuda()
        fake_A = self.netG_B(real_B)
        self.rec_B = self.netG_A(fake_A).data
        self.fake_A = fake_A.data

    def backward_D_basic(self, netD, real, fake):
        """Standard D loss: real -> True, fake (detached) -> False; backprop."""
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Update D_A using a pooled fake_B sample (fake_B set by backward_G)."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_Bt0, fake_B)
        self.loss_D_A = loss_D_A.item()

    def backward_D_B(self):
        """Update D_B using a pooled fake_At0 sample (set by backward_G)."""
        fake_A = self.fake_A_pool.query(self.fake_At0)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.item()

    def backward_G(self):
        """Generator/dynamics update: identity terms, A->B cycle, B->A cycle,
        and the dynamics-consistent third cycle (added when self.dynamic)."""
        lambda_idt = 0.5
        lambda_A = 100.0
        lambda_B = 100.0
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_Bt0)
            loss_idt_A = self.criterionIdt(
                idt_A, self.real_Bt0) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(idt_B,
                                           self.real_A) * lambda_A * lambda_idt

            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.item()
            self.loss_idt_B = loss_idt_B.item()
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        lambda_G = 1.0

        # --------first cycle-----------#
        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True) * lambda_G
        # Forward cycle loss
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A

        # ---------second cycle---------#
        # GAN loss D_B(G_B(B))
        fake_At0 = self.netG_B(self.real_Bt0)
        pred_fake = self.netD_B(fake_At0)
        loss_G_Bt0 = self.criterionGAN(pred_fake, True) * lambda_G
        # Backward cycle loss
        rec_Bt0 = self.netG_A(fake_At0)
        loss_cycle_Bt0 = self.criterionCycle(rec_Bt0, self.real_Bt0) * lambda_B

        # ---------third cycle---------#
        # Dynamics cycle: advance the translated state with netF_A and the
        # action, then map back and compare with the true next state.
        fake_At1 = self.netF_A(fake_At0, self.action)
        pred_fake = self.netD_B(fake_At1)
        loss_G_Bt1 = self.criterionGAN(pred_fake, True) * lambda_G
        # Backward cycle loss
        rec_Bt1 = self.netG_A(fake_At1)
        loss_cycle_Bt1 = self.criterionCycle(rec_Bt1, self.real_Bt1) * lambda_B

        # combined loss
        loss_G = loss_idt_A + loss_idt_B
        loss_G = loss_G + loss_G_A + loss_cycle_A
        loss_G = loss_G + loss_G_Bt0 + loss_cycle_Bt0
        if self.dynamic:
            loss_G = loss_G + loss_G_Bt1 + loss_cycle_Bt1
        loss_G.backward()

        # Cache tensors for the discriminator updates and visualization.
        self.fake_B = fake_B.data
        self.fake_At0 = fake_At0.data
        self.fake_At1 = fake_At1.data
        self.rec_A = rec_A.data
        self.rec_Bt0 = rec_Bt0.data
        self.rec_Bt1 = rec_Bt1.data

        self.loss_G_A = loss_G_A.item()
        self.loss_G_Bt0 = loss_G_Bt0.item()
        self.loss_G_Bt1 = loss_G_Bt1.item()
        self.loss_cycle_A = loss_cycle_A.item()
        self.loss_cycle_Bt0 = loss_cycle_Bt0.item()
        self.loss_cycle_Bt1 = loss_cycle_Bt1.item()

    def optimize_parameters(self):
        """One full training step: G (netF_A frozen via lr=0), then D_A, D_B."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        """Return the latest scalar losses as an OrderedDict for logging."""
        ret_errors = OrderedDict([('D_A', self.loss_D_A),
                                  ('G_A', self.loss_G_A),
                                  ('Cyc_A', self.loss_cycle_A),
                                  ('D_B', self.loss_D_B),
                                  ('G_B', self.loss_G_Bt0),
                                  ('Cyc_B', self.loss_cycle_Bt0)])
        # if self.opt.identity > 0.0:
        ret_errors['idt_A'] = self.loss_idt_A
        ret_errors['idt_B'] = self.loss_idt_B
        return ret_errors

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, path):
        """Save one network's state dict as model_<label>.pth under *path*."""
        save_filename = 'model_{}.pth'.format(network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    def save(self, path):
        """Save both generators and both discriminators."""
        self.save_network(self.netG_A, 'G_A', path)
        self.save_network(self.netD_A, 'D_A', path)
        self.save_network(self.netG_B, 'G_B', path)
        self.save_network(self.netD_B, 'D_B', path)

    def load_network(self, network, network_label, path):
        """Load one network's state dict from model_<label>.pth under *path*."""
        weight_filename = 'model_{}.pth'.format(network_label)
        weight_path = os.path.join(path, weight_filename)
        network.load_state_dict(torch.load(weight_path))

    def load(self, path):
        """Load the generators only (discriminators are not restored here)."""
        self.load_network(self.netG_A, 'G_A', path)
        self.load_network(self.netG_B, 'G_B', path)

    def plot_points(self, item, label):
        """Scatter an (N, 2) tensor of points on the current matplotlib axes."""
        item = item.cpu().data.numpy()
        plt.scatter(item[:, 0], item[:, 1], label=label)

    def visual(self, path):
        """Plot real_A / fake_B / rec_A point sets with correspondence lines
        between real_A and fake_B, then save the figure to *path*."""
        plt.rcParams['figure.figsize'] = (8.0, 3.0)
        plt.xlim(-4, 4)
        plt.ylim(-1.5, 1.5)
        self.plot_points(self.real_A, 'realA')
        self.plot_points(self.fake_B, 'fake_B')
        self.plot_points(self.rec_A, 'rec_A')
        # self.plot_points(self.real_B,'real_B')
        for p1, p2 in zip(self.real_A, self.fake_B):
            p1, p2 = p1.cpu().data.numpy(), p2.cpu().data.numpy()
            plt.plot([p1[0], p2[0]], [p1[1], p2[1]])
        plt.legend()
        plt.savefig(path)
        plt.cla()
        plt.clf()
コード例 #24
0
    def __init__(self, p):
        """Build CycleGAN generators, discriminators, losses and optimizers.

        Args:
            p: options namespace (architecture choices, ``no_lsgan``,
                ``continue_train``, learning-rate / Adam hyper-parameters).
        """

        super(CycleGAN, self).__init__(p)
        # NOTE(review): nb and size are unused in this constructor —
        # possibly kept for subclasses; confirm before removing.
        nb = p.batchSize
        size = p.cropSize

        # load/define models
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = networks.define_G(p.input_nc, p.output_nc, p.ngf,
                                        p.which_model_netG, p.norm,
                                        not p.no_dropout, p.init_type,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(p.output_nc, p.input_nc, p.ngf,
                                        p.which_model_netG, p.norm,
                                        not p.no_dropout, p.init_type,
                                        self.gpu_ids)

        if self.isTrain:
            # LSGAN uses raw scores; a vanilla GAN needs a sigmoid output.
            use_sigmoid = p.no_lsgan
            self.netD_A = networks.define_D(p.output_nc, p.ndf,
                                            p.which_model_netD, p.n_layers_D,
                                            p.norm, use_sigmoid, p.init_type,
                                            self.gpu_ids)
            self.netD_B = networks.define_D(p.input_nc, p.ndf,
                                            p.which_model_netD, p.n_layers_D,
                                            p.norm, use_sigmoid, p.init_type,
                                            self.gpu_ids)

        if not self.isTrain or p.continue_train:
            # Resume: generators always; discriminators only when training.
            which_epoch = p.which_epoch
            self.load_model(self.netG_A, 'G_A', which_epoch)
            self.load_model(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_model(self.netD_A, 'D_A', which_epoch)
                self.load_model(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = p.lr
            self.fake_A_pool = ImagePool(p.pool_size)
            self.fake_B_pool = ImagePool(p.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not p.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=p.lr,
                                                betas=(p.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=p.lr,
                                                  betas=(p.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                  lr=p.lr,
                                                  betas=(p.beta1, 0.999))

            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, p))
コード例 #25
0
def train():
  """Train a CycleGAN variant whose critic is updated 5x per generator step.

  Resumes from ``checkpoints/<FLAGS.load_model>`` when --load_model is set;
  otherwise a fresh timestamped checkpoint directory is created.  The
  learning rate is fed through the feed dict and decays as
  ``lr * 0.5 ** max(0, step / decay_step - 2)``.
  """
  if FLAGS.load_model is not None:
    checkpoints_dir = "checkpoints/" + FLAGS.load_model
  else:
    current_time = datetime.now().strftime("%Y%m%d-%H%M")
    checkpoints_dir = "checkpoints/{}".format(current_time)
    try:
      os.makedirs(checkpoints_dir)
    except os.error:
      pass

  graph = tf.Graph()
  with graph.as_default():
    cycle_gan = CycleGAN(
        X_train_file=FLAGS.X,
        Y_train_file=FLAGS.Y,
        batch_size=FLAGS.batch_size,
        image_size_w=FLAGS.image_size_w,
        image_size_h=FLAGS.image_size_h,
        use_lsgan=FLAGS.use_lsgan,
        norm=FLAGS.norm,
        lambda1=FLAGS.lambda1,
        # BUG FIX: was `lambda2=FLAGS.lambda1`, which silently ignored the
        # user-supplied --lambda2 flag.
        lambda2=FLAGS.lambda2,
        beta1=FLAGS.beta1,
        ngf=FLAGS.ngf,
    )
    G_loss, C_loss, fake_y, fake_x = cycle_gan.model()
    G_optimizer, C_optimizer = cycle_gan.optimize(G_loss, C_loss)

    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
    saver = tf.train.Saver()

  with tf.Session(graph=graph) as sess:
    if FLAGS.load_model is not None:
      # Restore graph + weights; recover the step counter from the
      # checkpoint filename (model.ckpt-<step>.meta).
      checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
      meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
      restore = tf.train.import_meta_graph(meta_graph_path)
      restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
      step = int(meta_graph_path.split("-")[2].split(".")[0])
    else:
      sess.run(tf.global_variables_initializer())
      step = 0

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
      # Pools of previously generated images, used to stabilise the critic.
      fake_Y_pool = ImagePool(FLAGS.pool_size)
      fake_X_pool = ImagePool(FLAGS.pool_size)

      while not coord.should_stop():
        # get previously generated images
        fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

        # decay schedule: lr * 0.5 ** max(0, step/decay_step - 2)
        adjusted_lr = (FLAGS.learning_rate *
                           0.5 ** max(0, (step / FLAGS.decay_step) - 2))
        feed_ = {cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                    cycle_gan.fake_x: fake_X_pool.query(fake_x_val),
                    cycle_gan.learning_rate: adjusted_lr}
        # update D 5 times before update G
        for i in range(5):
            _ = sess.run(C_optimizer, feed_dict=feed_)
        _ = sess.run(G_optimizer, feed_dict=feed_)

        G_loss_val, C_loss_val, summary = (
              sess.run(
                  [G_loss, C_loss, summary_op],
                  feed_dict=feed_
              )
        )

        train_writer.add_summary(summary, step)
        train_writer.flush()

        if step % 100 == 0:
          logging.info('-----------Step %d:-------------' % step)
          logging.info('  G_loss   : {}'.format(G_loss_val))
          logging.info('  C_loss   : {}'.format(C_loss_val))

        if step % 10000 == 0:
          save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
          logging.info("Model saved in file: %s" % save_path)

        step += 1

    except KeyboardInterrupt:
      logging.info('Interrupted')
      coord.request_stop()
    except Exception as e:
      coord.request_stop(e)
    finally:
      # Always checkpoint on the way out, then shut the queue threads down.
      save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
      logging.info("Model saved in file: %s" % save_path)
      # When done, ask the threads to stop.
      coord.request_stop()
      coord.join(threads)
コード例 #26
0
def train():
  """Train a standard CycleGAN (two generators, two discriminators).

  Resumes from ``checkpoints/<FLAGS.load_model>`` when --load_model is set;
  otherwise a fresh timestamped checkpoint directory is created.  Runs until
  the input queue is exhausted or the process is interrupted, checkpointing
  every 10000 steps and on exit.
  """
  if FLAGS.load_model is not None:
    checkpoints_dir = "checkpoints/" + FLAGS.load_model
  else:
    current_time = datetime.now().strftime("%Y%m%d-%H%M")
    checkpoints_dir = "checkpoints/{}".format(current_time)
    try:
      os.makedirs(checkpoints_dir)
    except os.error:
      pass

  graph = tf.Graph()
  with graph.as_default():
    cycle_gan = CycleGAN(
        X_train_file=FLAGS.X,
        Y_train_file=FLAGS.Y,
        batch_size=FLAGS.batch_size,
        image_size=FLAGS.image_size,
        use_lsgan=FLAGS.use_lsgan,
        norm=FLAGS.norm,
        lambda1=FLAGS.lambda1,
        # BUG FIX: was `lambda2=FLAGS.lambda1`, which silently ignored the
        # user-supplied --lambda2 flag.
        lambda2=FLAGS.lambda2,
        learning_rate=FLAGS.learning_rate,
        beta1=FLAGS.beta1,
        ngf=FLAGS.ngf
    )
    G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
    optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
    saver = tf.train.Saver()

  with tf.Session(graph=graph) as sess:
    if FLAGS.load_model is not None:
      # Restore graph + weights; recover the step counter from the
      # checkpoint filename (model.ckpt-<step>.meta).
      checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
      meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
      restore = tf.train.import_meta_graph(meta_graph_path)
      restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
      step = int(meta_graph_path.split("-")[2].split(".")[0])
    else:
      sess.run(tf.global_variables_initializer())
      step = 0

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
      # Pools of previously generated images, used to stabilise the
      # discriminators.
      fake_Y_pool = ImagePool(FLAGS.pool_size)
      fake_X_pool = ImagePool(FLAGS.pool_size)

      while not coord.should_stop():
        # get previously generated images
        fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

        # train: one fused step over all four optimizers
        _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
              sess.run(
                  [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],
                  feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                             cycle_gan.fake_x: fake_X_pool.query(fake_x_val)}
              )
        )
        if step % 100 == 0:
          train_writer.add_summary(summary, step)
          train_writer.flush()

        if step % 100 == 0:
          logging.info('-----------Step %d:-------------' % step)
          logging.info('  G_loss   : {}'.format(G_loss_val))
          logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
          logging.info('  F_loss   : {}'.format(F_loss_val))
          logging.info('  D_X_loss : {}'.format(D_X_loss_val))

        if step % 10000 == 0:
          save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
          logging.info("Model saved in file: %s" % save_path)

        step += 1

    except KeyboardInterrupt:
      logging.info('Interrupted')
      coord.request_stop()
    except Exception as e:
      coord.request_stop(e)
    finally:
      # Always checkpoint on the way out, then shut the queue threads down.
      save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
      logging.info("Model saved in file: %s" % save_path)
      # When done, ask the threads to stop.
      coord.request_stop()
      coord.join(threads)
コード例 #27
0
    def initialize(self, cfg):
        """Build the networks, losses, optimizers and LR schedulers.

        Args:
            cfg: configuration mapping with keys such as ``GPU_IDS``,
                ``ARCHI``, ``TRAIN``, ``USE_DA``, ``DA_LAYER``,
                ``OUTPUT_PATH``, ``LR``, ``BETA1``/``BETA2``, ``MOMENTUM``,
                ``WEIGHT_DECAY``, ``POOL_SIZE``, ``CKPT_PATH``, ``TEST``,
                ``RESUME`` and ``EPOCH_LOAD``.

        Raises:
            NotImplementedError: for the 'vgg16' architecture.
            ValueError: for any other unknown architecture.
        """
        self.cfg = cfg

        ## set devices
        if cfg['GPU_IDS']:
            assert (torch.cuda.is_available())
            self.device = torch.device('cuda:{}'.format(cfg['GPU_IDS'][0]))
            # cudnn autotuner; beneficial when input sizes are fixed per run
            torch.backends.cudnn.benchmark = True
            print('Using %d GPUs' % len(cfg['GPU_IDS']))
        else:
            self.device = torch.device('cpu')

        # define network: backbone (netB), head (netH) and, when domain
        # adaptation is on at train time, a discriminator (netD) attached
        # at cfg['DA_LAYER']
        if cfg['ARCHI'] == 'alexnet':
            self.netB = networks.netB_alexnet()
            self.netH = networks.netH_alexnet()
            if self.cfg['USE_DA'] and self.cfg['TRAIN']:
                self.netD = networks.netD_alexnet(self.cfg['DA_LAYER'])
        elif cfg['ARCHI'] == 'vgg16':
            # BUG FIX: the construction code that used to follow this raise
            # was unreachable dead code (and referenced an unqualified
            # `netD_vgg16`); it has been removed.
            raise NotImplementedError
        elif 'resnet' in cfg['ARCHI']:
            self.netB = networks.netB_resnet34()
            self.netH = networks.netH_resnet34()
            if self.cfg['USE_DA'] and self.cfg['TRAIN']:
                self.netD = networks.netD_resnet(self.cfg['DA_LAYER'])
        else:
            raise ValueError('Un-supported network')

        ## initialize network param.
        self.netB = networks.init_net(self.netB, cfg['GPU_IDS'], 'xavier')
        self.netH = networks.init_net(self.netH, cfg['GPU_IDS'], 'xavier')

        if self.cfg['USE_DA'] and self.cfg['TRAIN']:
            self.netD = networks.init_net(self.netD, cfg['GPU_IDS'], 'xavier')

        # loss, optimizer, and scheduler
        if cfg['TRAIN']:
            self.total_steps = 0
            ## Output path: <OUTPUT_PATH>/<ARCHI>/<timestamp>
            self.save_dir = os.path.join(
                cfg['OUTPUT_PATH'], cfg['ARCHI'],
                datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
            if not os.path.isdir(self.save_dir):
                os.makedirs(self.save_dir)
            # self.logger = Logger(self.save_dir)

            ## model names
            self.model_names = ['netB', 'netH']
            ## loss
            self.criterionGAN = networks.GANLoss().to(self.device)
            self.criterionDepth1 = torch.nn.MSELoss().to(self.device)
            self.criterionNorm = torch.nn.CosineEmbeddingLoss().to(self.device)
            self.criterionEdge = torch.nn.BCELoss().to(self.device)

            ## optimizers: Adam for backbone/head
            self.lr = cfg['LR']
            self.optimizers = []
            self.optimizer_B = torch.optim.Adam(self.netB.parameters(),
                                                lr=cfg['LR'],
                                                betas=(cfg['BETA1'],
                                                       cfg['BETA2']))
            self.optimizer_H = torch.optim.Adam(self.netH.parameters(),
                                                lr=cfg['LR'],
                                                betas=(cfg['BETA1'],
                                                       cfg['BETA2']))
            self.optimizers.append(self.optimizer_B)
            self.optimizers.append(self.optimizer_H)
            if cfg['USE_DA']:
                self.real_pool = ImagePool(cfg['POOL_SIZE'])
                self.syn_pool = ImagePool(cfg['POOL_SIZE'])
                self.model_names.append('netD')
                ## use SGD for discriminator
                self.optimizer_D = torch.optim.SGD(
                    self.netD.parameters(),
                    lr=cfg['LR'],
                    momentum=cfg['MOMENTUM'],
                    weight_decay=cfg['WEIGHT_DECAY'])
                self.optimizers.append(self.optimizer_D)
            ## LR scheduler, one per optimizer
            self.schedulers = [
                networks.get_scheduler(optimizer, cfg)
                for optimizer in self.optimizers
            ]
        else:
            ## testing: only the metrics needed for evaluation
            self.model_names = ['netB', 'netH']
            self.criterionDepth1 = torch.nn.MSELoss().to(self.device)
            self.criterionNorm = torch.nn.CosineEmbeddingLoss(
                reduction='none').to(self.device)

        self.load_dir = os.path.join(cfg['CKPT_PATH'])
        self.criterionNorm_eval = torch.nn.CosineEmbeddingLoss(
            reduction='none').to(self.device)

        if cfg['TEST'] or cfg['RESUME']:
            self.load_networks(cfg['EPOCH_LOAD'])
コード例 #28
0
ファイル: train.py プロジェクト: sasayabaku/pytorch-training
def main():
    """Train a CycleGAN (ResNet generators, patch discriminators).

    Per-epoch losses are averaged over batches and written to TensorBoard;
    networks are checkpointed every ``save_freq`` epochs.
    """

    args = args_initialize()

    save_freq = args.save_freq
    epochs = args.num_epoch
    cuda = args.cuda

    train_dataset = UnalignedDataset(is_train=True)
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0
    )

    net_G_A = ResNetGenerator(input_nc=3, output_nc=3)   # A -> B
    net_G_B = ResNetGenerator(input_nc=3, output_nc=3)   # B -> A
    net_D_A = Discriminator()
    net_D_B = Discriminator()

    if args.cuda:
        net_G_A = net_G_A.cuda()
        net_G_B = net_G_B.cuda()
        net_D_A = net_D_A.cuda()
        net_D_B = net_D_B.cuda()

    # history pools of generated images, used to stabilise the discriminators
    fake_A_pool = ImagePool(50)
    fake_B_pool = ImagePool(50)

    criterionGAN = GANLoss(cuda=cuda)
    criterionCycle = torch.nn.L1Loss()
    criterionIdt = torch.nn.L1Loss()

    optimizer_G = torch.optim.Adam(
        itertools.chain(net_G_A.parameters(), net_G_B.parameters()),
        lr=args.lr,
        betas=(args.beta1, 0.999)
    )
    optimizer_D_A = torch.optim.Adam(net_D_A.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
    optimizer_D_B = torch.optim.Adam(net_D_B.parameters(), lr=args.lr, betas=(args.beta1, 0.999))

    log_dir = './logs'
    checkpoints_dir = './checkpoints'
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoints_dir, exist_ok=True)

    writer = SummaryWriter(log_dir)

    for epoch in range(epochs):

        # [G_A, D_A, G_B, D_B, cycle_A, cycle_B, idt_A, idt_B]
        running_loss = np.zeros(8)
        for batch_idx, data in enumerate(train_loader):

            input_A = data['A']
            input_B = data['B']

            if cuda:
                input_A = input_A.cuda()
                input_B = input_B.cuda()

            real_A = Variable(input_A)
            real_B = Variable(input_B)


            """
            Backward net_G
            """
            optimizer_G.zero_grad()
            lambda_idt = 0.5
            lambda_A = 10.0
            lambda_B = 10.0

            # Identity loss: feeding a generator an image already in its
            # output domain should, ideally, leave it unchanged.
            idt_B = net_G_A(real_B)
            loss_idt_A = criterionIdt(idt_B, real_B) * lambda_B * lambda_idt

            idt_A = net_G_B(real_A)
            loss_idt_B = criterionIdt(idt_A, real_A) * lambda_A * lambda_idt

            # GAN loss = D_A(G_A(A)): the generator wants its fakes to be
            # judged real (True) by the discriminator.
            fake_B = net_G_A(real_A)
            pred_fake = net_D_A(fake_B)
            loss_G_A = criterionGAN(pred_fake, True)

            fake_A = net_G_B(real_B)
            pred_fake = net_D_B(fake_A)
            loss_G_B = criterionGAN(pred_fake, True)

            # cycle-consistency: A -> B -> A and B -> A -> B
            rec_A = net_G_B(fake_B)
            loss_cycle_A = criterionCycle(rec_A, real_A) * lambda_A

            rec_B = net_G_A(fake_A)
            loss_cycle_B = criterionCycle(rec_B, real_B) * lambda_B

            loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
            loss_G.backward()

            optimizer_G.step()

            """
            update D_A
            """
            optimizer_D_A.zero_grad()
            fake_B = fake_B_pool.query(fake_B.data)

            pred_real = net_D_A(real_B)
            loss_D_real = criterionGAN(pred_real, True)

            pred_fake = net_D_A(fake_B.detach())
            loss_D_fake = criterionGAN(pred_fake, False)

            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()

            optimizer_D_A.step()

            """
            update D_B
            """
            optimizer_D_B.zero_grad()
            fake_A = fake_A_pool.query(fake_A.data)

            pred_real = net_D_B(real_A)
            loss_D_real = criterionGAN(pred_real, True)

            pred_fake = net_D_B(fake_A.detach())
            loss_D_fake = criterionGAN(pred_fake, False)

            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()


            optimizer_D_B.step()

            ret_loss = np.array([
                loss_G_A.data.detach().cpu().numpy(), loss_D_A.data.detach().cpu().numpy(),
                loss_G_B.data.detach().cpu().numpy(), loss_D_B.data.detach().cpu().numpy(),
                loss_cycle_A.data.detach().cpu().numpy(), loss_cycle_B.data.detach().cpu().numpy(),
                loss_idt_A.data.detach().cpu().numpy(), loss_idt_B.data.detach().cpu().numpy()
            ])
            running_loss += ret_loss

        """
        Save checkpoints
        """
        # BUG FIX: checkpointing used to sit inside the batch loop, rewriting
        # the same files once per batch on every save epoch.  Save once per
        # epoch instead.
        if (epoch + 1) % save_freq == 0:
            save_network(net_G_A, 'G_A', str(epoch + 1))
            save_network(net_D_A, 'D_A', str(epoch + 1))
            save_network(net_G_B, 'G_B', str(epoch + 1))
            save_network(net_D_B, 'D_B', str(epoch + 1))

        running_loss /= len(train_loader)
        losses = running_loss
        print('epoch %d, losses: %s' % (epoch + 1, running_loss))

        writer.add_scalar('loss_G_A', losses[0], epoch)
        writer.add_scalar('loss_D_A', losses[1], epoch)
        writer.add_scalar('loss_G_B', losses[2], epoch)
        writer.add_scalar('loss_D_B', losses[3], epoch)
        writer.add_scalar('loss_cycle_A', losses[4], epoch)
        writer.add_scalar('loss_cycle_B', losses[5], epoch)
        writer.add_scalar('loss_idt_A', losses[6], epoch)
        writer.add_scalar('loss_idt_B', losses[7], epoch)
コード例 #29
0
ファイル: cyclechain.py プロジェクト: kiminh/Cycle_Dynamics
class CycleGANModel():
    """CycleGAN between a state domain (A) and an image domain (B), with an
    extra "dynamics" cycle that rolls the translated state forward one step
    through a pretrained forward model (netF_A).

    Naming follows the original CycleGAN code (paper): G_A (G), G_B (F),
    D_A (D_Y), D_B (D_X).
    """

    def __init__(self,opt):
        self.opt = opt
        self.dynamic = opt.dynamic   # enables the third (dynamics) cycle
        self.isTrain = opt.istrain
        self.Tensor = torch.cuda.FloatTensor

        # load/define networks
        self.netG_A = state2img().cuda()   # A (state) -> B (image)
        self.netG_B = img2state().cuda()   # B (image) -> A (state)
        self.netF_A = Fmodel().cuda()      # forward dynamics in state space
        self.dataF = CDFdata.get_loader(opt)
        self.train_forward()

        if self.isTrain:
            self.netD_A = imgDmodel().cuda()    # discriminates images
            self.netD_B = stateDmodel().cuda()  # discriminates states

        if self.isTrain:
            self.fake_A_pool = ImagePool(pool_size=128)
            self.fake_B_pool = ImagePool(pool_size=128)
            # define loss functions
            self.criterionGAN = GANLoss(tensor=self.Tensor).cuda()
            if opt.loss == 'l1':
                self.criterionCycle = torch.nn.L1Loss()
                self.criterionIdt = torch.nn.L1Loss()
            elif opt.loss == 'l2':
                self.criterionCycle = torch.nn.MSELoss()
                self.criterionIdt = torch.nn.MSELoss()
            # initialize optimizers; the forward model is kept frozen via
            # a zero learning rate in its parameter group
            self.optimizer_G = torch.optim.Adam([{'params':self.netG_A.parameters(),'lr':1e-3},
                                                 {'params':self.netF_A.parameters(),'lr':0.0},
                                                 {'params':self.netG_B.parameters(),'lr':1e-3}])
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters())
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters())
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                # NOTE(review): the optimizer itself is appended here in
                # place of a real LR scheduler (original scheduler call is
                # commented out) — self.schedulers is effectively unused.
                self.schedulers.append(optimizer)

        print('---------- Networks initialized -------------')
        print('-----------------------------------------------')


    def train_forward(self,pretrained=True):
        """Load (or train from scratch) the forward dynamics model netF_A.

        With ``pretrained=True`` the weights are read from ./pred.pth;
        otherwise the model is fit for 100 epochs on (state, action) ->
        next-state triples with an L1 loss.
        """
        if pretrained:
            self.netF_A.load_state_dict(torch.load('./pred.pth'))
            return None
        optimizer = torch.optim.Adam(self.netF_A.parameters(),lr=1e-3)
        loss_fn = torch.nn.L1Loss()
        for epoch in range(100):
            epoch_loss = 0
            for i,item in enumerate(tqdm(self.dataF)):
                state, action, result = item[1]
                state = state.float().cuda()
                action = action.float().cuda()
                result = result.float().cuda()
                out = self.netF_A(state, action)
                loss = loss_fn(out, result)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
            print('epoch:{} loss:{:.7f}'.format(epoch,epoch_loss/len(self.dataF)))
        print('forward model has been trained!')

    def set_input(self, input):
        """Unpack one data batch.

        Layout (as produced by the loader): input[1][0] is the state (A);
        input[0] holds (image t0, action, image t1) for domain B;
        input[2] holds the ground-truth states at t0 and t1.
        """
        # A is state
        self.input_A = input[1][0]

        # B is img
        self.input_Bt0 = input[0][0]
        self.input_Bt1 = input[0][2]
        self.action = input[0][1]
        self.gt0 = input[2][0].float().cuda()
        self.gt1 = input[2][1].float().cuda()


    def forward(self):
        """Move the current batch onto the GPU as float Variables."""
        self.real_A = Variable(self.input_A).float().cuda()
        self.real_Bt0 = Variable(self.input_Bt0).float().cuda()
        self.real_Bt1 = Variable(self.input_Bt1).float().cuda()
        self.action = Variable(self.action).float().cuda()


    def test(self):
        """Run both translation cycles without optimizing (inference)."""
        self.forward()
        # NOTE(review): `volatile=True` is a pre-0.4 PyTorch idiom; on modern
        # torch it is ignored — `torch.no_grad()` would be the replacement.
        real_A = Variable(self.input_A, volatile=True).float().cuda()
        fake_B = self.netG_A(real_A)
        self.rec_A = self.netG_B(fake_B).data
        self.fake_B = fake_B.data

        real_B = Variable(self.input_Bt0, volatile=True).float().cuda()
        fake_A = self.netG_B(real_B)
        self.rec_B = self.netG_A(fake_A).data
        self.fake_A = fake_A.data


    def backward_D_basic(self, netD, real, fake):
        """Standard discriminator loss: real vs. (detached) fake, halved."""
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Update image discriminator against pooled fake images."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_Bt0, fake_B)
        self.loss_D_A = loss_D_A.item()

    def backward_D_B(self):
        """Update state discriminator against pooled fake states."""
        fake_A = self.fake_A_pool.query(self.fake_At0)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.item()

    def backward_G(self):
        """Generator update over up to three cycles (A->B->A, B->A->B at t0,
        and the dynamics cycle at t1 when self.dynamic is set)."""
        # NOTE(review): the negative value disables the identity branch
        # below (guard is `> 0`) — confirm this is intentional.
        lambda_idt = -0.5
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        lambda_F = self.opt.lambda_F
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_Bt0)
            loss_idt_A = self.criterionIdt(idt_A, self.real_Bt0) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt

            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.item()
            self.loss_idt_B = loss_idt_B.item()
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        lambda_G_A = 1.0
        lambda_G_B = 1.0

        # --------first cycle-----------#
        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True) * lambda_G_A
        # Forward cycle loss
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A

        # ---------second cycle---------#
        # GAN loss D_B(G_B(B))
        fake_At0 = self.netG_B(self.real_Bt0)
        pred_fake = self.netD_B(fake_At0)
        loss_G_Bt0 = self.criterionGAN(pred_fake, True) * lambda_G_B
        # Backward cycle loss
        rec_Bt0 = self.netG_A(fake_At0)
        loss_cycle_Bt0 = self.criterionCycle(rec_Bt0, self.real_Bt0) * lambda_B

        # ---------third cycle---------#
        # roll the translated state forward with the (frozen) dynamics model
        fake_At1 = self.netF_A(fake_At0,self.action)
        pred_fake = self.netD_B(fake_At1)
        loss_G_Bt1 = self.criterionGAN(pred_fake, True) * lambda_G_B
        # Backward cycle loss
        rec_Bt1 = self.netG_A(fake_At1)
        loss_cycle_Bt1 = self.criterionCycle(rec_Bt1, self.real_Bt1) * lambda_F


        # combined loss (the dynamics cycle only counts when enabled)
        loss_G = loss_idt_A + loss_idt_B
        loss_G = loss_G + loss_G_A + loss_cycle_A
        loss_G = loss_G + loss_G_Bt0 + loss_cycle_Bt0
        if self.dynamic:
            loss_G = loss_G + loss_G_Bt1 + loss_cycle_Bt1
        loss_G.backward()

        self.fake_B = fake_B.data
        self.fake_At0 = fake_At0.data
        self.fake_At1 = fake_At1.data
        self.rec_A = rec_A.data
        self.rec_Bt0 = rec_Bt0.data
        self.rec_Bt1 = rec_Bt1.data

        self.loss_G_A = loss_G_A.item()
        self.loss_G_Bt0 = loss_G_Bt0.item()
        self.loss_G_Bt1 = loss_G_Bt1.item()
        self.loss_cycle_A = loss_cycle_A.item()
        self.loss_cycle_Bt0 = loss_cycle_Bt0.item()
        self.loss_cycle_Bt1 = loss_cycle_Bt1.item()

        # BUG FIX: `self.gt` is never assigned anywhere (set_input stores
        # gt0/gt1), which raised AttributeError here.  The translated t0
        # state is compared against its t0 ground truth.
        self.loss_state_l1 = self.criterionCycle(self.fake_At0, self.gt0).item()


    def optimize_parameters(self):
        """One full optimization step: G, then D_A, then D_B."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        """Return the latest scalar losses keyed for logging."""
        ret_errors = OrderedDict([('L1',self.loss_state_l1), ('D_A', self.loss_D_A), ('G_A', self.loss_G_A),
                                  ('Cyc_A', self.loss_cycle_A), ('D_B', self.loss_D_B),
                                  ('G_B0', self.loss_G_Bt0), ('G_B1', self.loss_G_Bt1),
                                  ('Cyc_B0',  self.loss_cycle_Bt0), ('Cyc_B1',  self.loss_cycle_Bt1)])
        return ret_errors

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, path):
        save_filename = 'model_{}.pth'.format(network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    def save(self, path):
        """Checkpoint both generators and both discriminators to *path*."""
        self.save_network(self.netG_A, 'G_A', path)
        self.save_network(self.netD_A, 'D_A', path)
        self.save_network(self.netG_B, 'G_B', path)
        self.save_network(self.netD_B, 'D_B', path)

    def load_network(self, network, network_label, path):
        weight_filename = 'model_{}.pth'.format(network_label)
        weight_path = os.path.join(path, weight_filename)
        network.load_state_dict(torch.load(weight_path))

    def load(self,path):
        """Restore generators only (discriminators are not reloaded)."""
        self.load_network(self.netG_A, 'G_A', path)
        self.load_network(self.netG_B, 'G_B', path)

    def plot_points(self,item,label):
        """Scatter-plot the first two state dimensions of *item*."""
        item = item.cpu().data.numpy()
        plt.scatter(item[:,0],item[:,1],label=label)

    def visual(self,path):
        """Save a tiled comparison image: (real/rec t0, real/rec t1, fake_B)
        per batch row, rescaled from [-1, 1] to [0, 1]."""
        imgs = []
        for i in range(self.real_A.shape[0]):
            imgs_i = [self.real_Bt0[i], self.rec_Bt0[i]]
            imgs_i += [self.real_Bt1[i], self.rec_Bt1[i]]
            imgs_i += [self.fake_B[i]]
            imgs_i = torch.cat(imgs_i, 2).cpu()
            imgs.append(imgs_i)
        imgs = torch.cat(imgs, 1)
        imgs = (imgs + 1) / 2
        imgs = transforms.ToPILImage()(imgs)
        imgs.save(path)
コード例 #30
0
def train(epochs=100, batch_size=1):
    """Train a Keras CycleGAN, resuming from ./weights when present.

    Args:
        epochs: number of passes over the dataset.
        batch_size: images per training batch.
    """
    # Generators
    netG = CycleGAN()
    netG_XY, real_X, fake_Y = netG.generator()
    netG_YX, real_Y, fake_X = netG.generator()

    reconstruct_X = netG_YX(fake_Y)
    reconstruct_Y = netG_XY(fake_X)
    # Discriminators
    netD = CycleGAN()
    netD_X = netD.discriminator()
    netD_Y = netD.discriminator()

    netD_X_predict_fake = netD_X(fake_X)
    netD_Y_predict_fake = netD_Y(fake_Y)
    netD_X_predict_real = netD_X(real_X)
    netD_Y_predict_real = netD_Y(real_Y)
    # Optimizer (shared by generator and discriminator models)
    optimizer = Adam(lr=0.001,
                     beta_1=0.5,
                     beta_2=0.999,
                     epsilon=None,
                     decay=0.01)
    # GAN: freeze the discriminators while training the generators
    netD_X.trainable = False
    netD_Y.trainable = True if False else False  # kept frozen for G training
    netD_Y.trainable = False
    netG_loss_inputs = [
        netD_X_predict_fake, reconstruct_X, real_X, netD_Y_predict_fake,
        reconstruct_Y, real_Y
    ]
    netG_train = Model([real_X, real_Y], Lambda(netG_loss)(netG_loss_inputs))
    netG_train.compile(loss='mae', optimizer=optimizer, metrics=['accuracy'])

    _fake_X_inputs = Input(shape=(256, 256, 3))
    _fake_Y_inputs = Input(shape=(256, 256, 3))
    _netD_X_predict_fake = netD_X(_fake_X_inputs)
    _netD_Y_predict_fake = netD_Y(_fake_Y_inputs)
    netD_X.trainable = True
    netD_X_train = Model(
        [real_X, _fake_X_inputs],
        Lambda(netD_loss)([netD_X_predict_real, _netD_X_predict_fake]))
    netD_X_train.compile(loss='mae', optimizer=optimizer,
                         metrics=['accuracy'])

    netD_X.trainable = False
    netD_Y.trainable = True
    netD_Y_train = Model(
        [real_Y, _fake_Y_inputs],
        Lambda(netD_loss)([netD_Y_predict_real, _netD_Y_predict_fake]))
    netD_Y_train.compile(loss='mae', optimizer=optimizer, metrics=['accuracy'])

    dataloader = Dataloader()
    fake_X_pool = ImagePool()
    fake_Y_pool = ImagePool()

    netG_X_function = get_G_function(netG_XY)
    netG_Y_function = get_G_function(netG_YX)
    if len(os.listdir('./weights')):
        netG_train.load_weights('./weights/netG.h5')
        netD_X_train.load_weights('./weights/netD_X.h5')
        # BUG FIX: `load_weights` was referenced without being called, so
        # netD_Y always restarted from scratch.
        netD_Y_train.load_weights('./weights/netD_Y.h5')

    print('Info: Start Training\n')
    for epoch in range(epochs):

        target_label = np.zeros((batch_size, 1))

        for batch_i, (imgs_X,
                      imgs_Y) in enumerate(dataloader.load_batch(batch_size)):
            start_time = time.time()
            # NOTE(review): num_batch is reset every iteration, so the
            # per-epoch time in the log below always equals the batch time.
            num_batch = 0
            tmp_fake_X = netG_X_function([imgs_X])[0]
            tmp_fake_Y = netG_Y_function([imgs_Y])[0]

            # draw (possibly historical) fakes from the image pools
            _fake_X = fake_X_pool.action(tmp_fake_X)
            _fake_Y = fake_Y_pool.action(tmp_fake_Y)
            if batch_i % 2 == 0:
                save_image('fake_X_' + str(epoch) + '_' + str(batch_i),
                           _fake_X[0])
                save_image('fake_Y_' + str(epoch) + '_' + str(batch_i),
                           _fake_Y[0])
            _netG_loss = netG_train.train_on_batch([imgs_X, imgs_Y],
                                                   target_label)
            netD_X_loss = netD_X_train.train_on_batch([imgs_X, _fake_X],
                                                      target_label)
            netD_Y_loss = netD_Y_train.train_on_batch([imgs_Y, _fake_Y],
                                                      target_label)
            num_batch += 1
            diff = time.time() - start_time
            print('Epoch:{}/{},netG_loss:{}, netD_loss:{},{}, time_cost_per_epoch:{}/epoch'\
              .format(epoch+1, epochs, _netG_loss, netD_X_loss, netD_Y_loss, diff, diff/num_batch))

        netG_train.save_weights('./weights/netG.h5')
        netD_X_train.save_weights('./weights/netD_X.h5')
        # BUG FIX: saved to './weights/netD_Y.hs' (typo), so the reload path
        # above could never find it.
        netD_Y_train.save_weights('./weights/netD_Y.h5')
        print('Model saved!\n')
    pass
コード例 #31
0
ファイル: cycleGAN.py プロジェクト: roshandhakal/Bokeh-Effect
class cycleGan():

    def __init__(self, opt):
        """Build the CycleGAN generators, discriminators, pools and optimizers.

        Args:
            opt: options namespace providing input_nc, output_nc, ngf, ndf,
                norm, no_dropout, n_blocks, padding_type, n_layers_D,
                pool_size, lr and beta1.
        """
        super(cycleGan, self).__init__()
        self.opt = opt
        # G_XtoY maps X-domain images (input_nc ch) to the Y domain (output_nc ch).
        self.G_XtoY = RensetGenerator(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.norm, not opt.no_dropout,
                                      opt.n_blocks,
                                      opt.padding_type).to(device)
        # G_YtoX maps Y-domain images back to the X domain.
        self.G_YtoX = RensetGenerator(opt.output_nc, opt.input_nc, opt.ngf,
                                      opt.norm, not opt.no_dropout,
                                      opt.n_blocks,
                                      opt.padding_type).to(device)
        # BUG FIX: D_X judges X-domain images (inputX / fake_X, see
        # backward_D_X), which have input_nc channels; D_Y judges Y-domain
        # images (output_nc channels).  The original passed opt.output_nc to
        # D_X and opt.input_nc to D_Y — swapped — which only worked when
        # input_nc == output_nc.
        self.D_X = PatchDiscriminator(opt.input_nc, opt.ndf, opt.n_layers_D,
                                      opt.norm).to(device)
        self.D_Y = PatchDiscriminator(opt.output_nc, opt.ndf, opt.n_layers_D,
                                      opt.norm).to(device)
        print(self.G_XtoY)
        print("Parameters: ", len(list(self.G_XtoY.parameters())))
        print(self.G_YtoX)
        print("Parameters: ", len(list(self.G_YtoX.parameters())))
        print(self.D_X)
        print("Parameters: ", len(list(self.D_X.parameters())))
        print(self.D_Y)
        print("Parameters: ", len(list(self.D_Y.parameters())))
        # History buffers of previously generated fakes fed to the discriminators.
        self.fake_X_pool = ImagePool(opt.pool_size)
        self.fake_Y_pool = ImagePool(opt.pool_size)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()  # For implementing Identity Loss
        # One optimizer over both generators, plus one per discriminator.
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.G_XtoY.parameters(), self.G_YtoX.parameters()),
                                            lr=opt.lr,
                                            betas=[opt.beta1, 0.999])
        self.optimizer_DX = torch.optim.Adam(self.D_X.parameters(),
                                             lr=opt.lr,
                                             betas=[opt.beta1, 0.999])
        self.optimizer_DY = torch.optim.Adam(self.D_Y.parameters(),
                                             lr=opt.lr,
                                             betas=[opt.beta1, 0.999])

    def get_input(self, inputX, inputY):
        self.inputX = inputX
        self.inputY = inputY

    def forward(self):
        """Run both translation cycles, caching fakes and reconstructions.

        Populates self.fake_X/fake_Y (cross-domain translations) and
        self.rec_X/rec_Y (round-trip reconstructions).
        """
        # X -> Y -> X cycle
        self.fake_Y = self.G_XtoY(self.inputX).to(device)
        self.rec_X = self.G_YtoX(self.fake_Y).to(device)
        # Y -> X -> Y cycle
        self.fake_X = self.G_YtoX(self.inputY).to(device)
        self.rec_Y = self.G_XtoY(self.fake_X).to(device)

    def backward_D_basic(self, netD, real, fake):
        """Generic discriminator update: real -> 1, detached fake -> 0.

        Backpropagates the combined MSE GAN loss through netD and returns
        the scalar loss tensor.
        """
        real_pred = netD(real)
        fake_pred = netD(fake.detach())  # detach: no gradient flows into G
        real_term = real_mse_loss(real_pred)
        fake_term = fake_mse_loss(fake_pred)
        # Combine the two terms and halve before backprop.
        total = 0.5 * (real_term + fake_term)
        total.backward()
        return total

    def backward_D_X(self):
        """Calculate GAN loss for discriminator D_X (docstring previously
        said D_A, copied from another implementation).

        When opt.pool is enabled, the fake sample is drawn from the
        image-history pool instead of the freshly generated one.
        """
        # Truthiness test instead of the non-idiomatic `== True` comparison.
        if self.opt.pool:
            fake_X = self.fake_X_pool.query(self.fake_X)
            self.loss_D_X = self.backward_D_basic(self.D_X, self.inputX,
                                                  fake_X)
        else:
            self.loss_D_X = self.backward_D_basic(self.D_X, self.inputX,
                                                  self.fake_X)

    def backward_D_Y(self):
        """Calculate GAN loss for discriminator D_Y (docstring previously
        said D_A, copied from another implementation).

        When opt.pool is enabled, the fake sample is drawn from the
        image-history pool instead of the freshly generated one.
        """
        # Truthiness test instead of the non-idiomatic `== True` comparison.
        if self.opt.pool:
            fake_Y = self.fake_Y_pool.query(self.fake_Y)
            self.loss_D_Y = self.backward_D_basic(self.D_Y, self.inputY,
                                                  fake_Y)
        else:
            self.loss_D_Y = self.backward_D_basic(self.D_Y, self.inputY,
                                                  self.fake_Y)

    def backward_G(self):
        """Generator update: adversarial loss on both fakes plus both
        cycle-consistency terms, backpropagated through both generators.
        (Identity loss is not applied here.)
        """
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B

        # Adversarial terms: each generator tries to make its discriminator
        # score the fake as real.
        self.loss_G_X = real_mse_loss(self.D_X(self.fake_X))
        self.loss_G_Y = real_mse_loss(self.D_Y(self.fake_Y))
        # Cycle-consistency: the round trip should reproduce the input.
        self.loss_cycle_X = lambda_A * self.criterionCycle(self.rec_X,
                                                           self.inputX)
        self.loss_cycle_Y = lambda_B * self.criterionCycle(self.rec_Y,
                                                           self.inputY)
        # Combined generator objective; backward() computes all G gradients.
        self.loss_G = (self.loss_G_X + self.loss_G_Y
                       + self.loss_cycle_X + self.loss_cycle_Y)
        self.loss_G.backward()

    def change_lr(self, new_lr):
        self.opt.lr = new_lr
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.G_XtoY.parameters(), self.G_YtoX.parameters()),
                                            lr=self.opt.lr,
                                            betas=[self.opt.beta1, 0.999])
        self.optimizer_DX = torch.optim.Adam(self.D_X.parameters(),
                                             lr=self.opt.lr,
                                             betas=[self.opt.beta1, 0.999])
        self.optimizer_DY = torch.optim.Adam(self.D_Y.parameters(),
                                             lr=self.opt.lr,
                                             betas=[self.opt.beta1, 0.999])

    def set_requires_grad(self, nets, requires_grad=False):

        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def optimize(self):
        """One full optimization step on the current inputs: update the
        generators first, then both discriminators.

        BUG FIX: the original ran the forward pass + generator update twice
        back-to-back (a verbatim copy-paste duplicate), stepping G twice per
        call and recomputing the graph for nothing; the duplicate block is
        removed.  Discriminator updates are unaffected: backward_D_basic
        detaches the fakes, so the fakes from the single forward pass suffice.
        """
        # Generators: freeze the discriminators so backward_G does not
        # populate their gradients.
        self.set_requires_grad([self.D_X, self.D_Y], False)
        self.forward()
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

        # Discriminators: re-enable their gradients and update each one.
        self.set_requires_grad([self.D_X, self.D_Y], True)
        self.optimizer_DX.zero_grad()
        self.optimizer_DY.zero_grad()
        self.backward_D_X()
        self.backward_D_Y()
        self.optimizer_DX.step()
        self.optimizer_DY.step()