Example 1
    def __init__(self,
                 lambda_ABA=settings.lambda_ABA,
                 lambda_BAB=settings.lambda_BAB,
                 lambda_local=settings.lambda_local,
                 pool_size=settings.pool_size,
                 max_crop_side=settings.max_crop_side,
                 decay_start=settings.decay_start,
                 epochs_to_zero_lr=settings.epochs_to_zero_lr,
                 warm_epochs=settings.warmup_epochs):
        super(GAN, self).__init__()

        self.r = 0
        self.lambda_ABA = lambda_ABA
        self.lambda_BAB = lambda_BAB
        self.lambda_local = lambda_local
        self.max_crop_side = max_crop_side

        self.netG_A = Generator(input_nc=4, output_nc=3)
        self.netG_B = Generator(input_nc=4, output_nc=3)
        self.netD_A = NLayerDiscriminator(input_nc=3)
        self.netD_B = NLayerDiscriminator(input_nc=3)
        self.localD = NLayerDiscriminator(input_nc=3)
        self.crop_drones = CropDrones()
        self.criterionGAN = GANLoss("lsgan")
        self.criterionCycle = nn.L1Loss()

        init_weights(self.netG_A)
        init_weights(self.netG_B)
        init_weights(self.netD_A)
        init_weights(self.netD_B)
        init_weights(self.localD)

        self.fake_B_pool = ImagePool(pool_size)
        self.fake_A_pool = ImagePool(pool_size)
        self.fake_drones_pool = ImagePool(pool_size)
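Every snippet in this collection builds one or more ImagePool buffers, but none of them shows the class itself. The sketch below is the widely used history-buffer idea from the SimGAN/CycleGAN papers: query() returns a roughly 50/50 mix of freshly generated fakes and fakes stored from earlier iterations, so the discriminator is not trained only on the generator's most recent outputs. This is only a minimal sketch; the ImagePool actually imported by each example may differ in its details (constructor arguments, return types, or being called directly instead of via query()).

import random
import torch


class ImagePool:
    """Minimal history buffer for generated images (a sketch, not the exact
    class used by the examples). query() returns a mix of new and previously
    stored fakes so the discriminator also sees older generator outputs."""

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:  # buffer disabled: pass images straight through
            return images
        out = []
        for image in images:
            image = image.unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(image)  # fill the buffer first
                out.append(image)
            elif random.random() > 0.5:
                idx = random.randint(0, self.pool_size - 1)
                out.append(self.images[idx].clone())  # return an old fake...
                self.images[idx] = image              # ...and store the new one
            else:
                out.append(image)  # return the current fake unchanged
        return torch.cat(out, 0)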
Example 2
    def __init__(self,
                 num_iter=100,
                 num_iter_decay=100,
                 lambda_A=10,
                 lambda_B=10,
                 lambda_identity=0.5):
        super(CycleGANModel, self).__init__()
        self.name = None

        self.epoch_count = torch.tensor(1)
        self.num_iter = torch.tensor(num_iter)
        self.num_iter_decay = torch.tensor(num_iter_decay)

        self.lambda_A = torch.tensor(lambda_A)
        self.lambda_B = torch.tensor(lambda_B)
        self.lambda_identity = torch.tensor(lambda_identity)

        self.netG_A = define_G(num_res_blocks=9)
        self.netG_B = define_G(num_res_blocks=9)

        self.netD_A = define_D()
        self.netD_B = define_D()

        self.fake_A_pool = ImagePool(pool_size=50)
        self.fake_B_pool = ImagePool(pool_size=50)

        self.criterionGAN = define_GAN_loss()
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()

        self.optimizer_G_A = optim.Adam(self.netG_A.parameters(),
                                        lr=0.0002,
                                        betas=(0.5, 0.999))
        self.optimizer_G_B = optim.Adam(self.netG_B.parameters(),
                                        lr=0.0002,
                                        betas=(0.5, 0.999))
        self.optimizer_D_A = optim.Adam(self.netD_A.parameters(),
                                        lr=0.0002,
                                        betas=(0.5, 0.999))
        self.optimizer_D_B = optim.Adam(self.netD_B.parameters(),
                                        lr=0.0002,
                                        betas=(0.5, 0.999))

        lambda_rule = lambda epoch: 1.0 - max(
            0, epoch + self.epoch_count - self.num_iter) / float(
                self.num_iter_decay + 1)

        self.scheduler_G_A = scheduler.LambdaLR(self.optimizer_G_A,
                                                lr_lambda=lambda_rule)
        self.scheduler_G_B = scheduler.LambdaLR(self.optimizer_G_B,
                                                lr_lambda=lambda_rule)
        self.scheduler_D_A = scheduler.LambdaLR(self.optimizer_D_A,
                                                lr_lambda=lambda_rule)
        self.scheduler_D_B = scheduler.LambdaLR(self.optimizer_D_B,
                                                lr_lambda=lambda_rule)
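The lambda_rule above keeps the learning rate at its base value (0.0002) for the first num_iter epochs and then decays it linearly towards zero over num_iter_decay further epochs; LambdaLR multiplies the base rate by the returned factor each epoch. A standalone recomputation (not part of the original snippet) with the defaults used here, written with plain floats for readability:

epoch_count, num_iter, num_iter_decay = 1, 100, 100
base_lr = 0.0002

def lambda_rule(epoch):
    return 1.0 - max(0, epoch + epoch_count - num_iter) / float(num_iter_decay + 1)

for epoch in (0, 99, 149, 199):
    print(epoch, base_lr * lambda_rule(epoch))
# 0   -> 0.0002     (no decay yet)
# 99  -> 0.0002     (decay only starts once epoch + epoch_count exceeds num_iter)
# 149 -> ~0.000101  (about halfway through the decay window)
# 199 -> ~0.000002  (essentially zero at the end of training)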
Example 3
    def __init__(self, opt):
        super(FeatureLoss, self).__init__()
        self.opt = opt
        self.isTrain = opt.isTrain
        self.vgg = VGG.vgg16(pretrained=True)
        self.Tensor = torch.cuda.FloatTensor if use_gpu else torch.Tensor

        self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)

        # Assuming norm_type = batch
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
        # model of Generator Net is unet_256
        self.GeneratorNet = Generator(
            opt.input_nc, opt.output_nc, 8, opt.ngf,
            norm_layer=norm_layer, use_dropout=not opt.no_dropout)
        if use_gpu:
            self.GeneratorNet.cuda(0)
        self.GeneratorNet.apply(init_weights)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            # model of Discriminator Net is basic
            self.DiscriminatorNet = Discriminator(
                opt.input_nc + opt.output_nc, opt.ndf, n_layers=3,
                norm_layer=norm_layer, use_sigmoid=use_sigmoid)
            if use_gpu:
                self.DiscriminatorNet.cuda(0)
            self.DiscriminatorNet.apply(init_weights)

        if not self.isTrain or opt.continue_train:
            self.load_network(self.GeneratorNet, 'Generator', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.DiscriminatorNet, 'Discriminator', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.learning_rate = opt.lr
            # defining loss functions
            self.criterionGAN = GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionFV = loss_FV.FeatureVectorLoss()

            self.MySchedulers = []  # initialising schedulers
            self.MyOptimizers = []  # initialising optimizers
            self.generator_optimizer = torch.optim.Adam(
                self.GeneratorNet.parameters(), lr=self.learning_rate, betas=(opt.beta1, 0.999))
            self.discriminator_optimizer = torch.optim.Adam(
                self.DiscriminatorNet.parameters(), lr=self.learning_rate, betas=(opt.beta1, 0.999))
            self.MyOptimizers.append(self.generator_optimizer)
            self.MyOptimizers.append(self.discriminator_optimizer)
            # linear decay of the learning rate to zero after opt.niter epochs
            # (assumes opt.lr_policy == 'lambda')
            def lambda_rule(epoch):
                return 1.0 - max(0, epoch - opt.niter) / float(opt.niter_decay + 1)

            for optimizer in self.MyOptimizers:
                self.MySchedulers.append(lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule))


        print('<============ NETWORKS INITIALIZED ============>')
        print_net(self.GeneratorNet)
        if self.isTrain:
            print_net(self.DiscriminatorNet)
        print('<=============================================>')
Example 4
    def __init__(self, sess, args):

        self.start_time = time.time()
        self.sess = sess
        self.pool = ImagePool(max_size=args.max_size)
        self.img_size = (args.image_size, args.image_size)
        self.load_size = (args.load_size, args.load_size)
        self.img_channels = args.image_channel
        self.il = ImageLoader(load_size=self.load_size, img_size=self.img_size,
                              data_dir=args.data_dir, target_dir=args.target_dir)
        self.data_dir = args.data_dir
        self.target_dir = args.target_dir
        self.video_dir = args.video_dir
        self.sample_dir = args.sample_dir
        self.checkpoint_dir = args.checkpoint_dir
        self.log_dir = args.log_dir
        self.output_data_dir = os.path.join('results', args.output_data_dir)
        self.output_target_dir = os.path.join('results', args.output_target_dir)
        self.gf_dim = args.gf_dim
        self.df_dim = args.df_dim
        self.l1_lambda = args.l1_lambda
        self.learning_rate = args.learning_rate
        self.bata1 = args.bata1
        self.epoch_num = args.epoch_num
        self.batch_size = args.batch_size
        self.data_batch_num = self.il.get_image_num() // self.batch_size
        self.target_batch_num = self.il.get_image_num(is_data=False) // self.batch_size
        self.batch_num = min(self.data_batch_num, self.target_batch_num)
        self.global_step = 0

        if args.clear_all_memory:
            print('start clearing all memory...')

            def clear_files(clear_dir):
                shutil.rmtree(clear_dir)
                os.mkdir(clear_dir)

            clear_files(self.log_dir)
            clear_files(self.checkpoint_dir)
            clear_files(self.sample_dir)

            print('successfully cleared all memory')


        if not os.path.exists('results'):
            os.makedirs('results')

        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        self._build(args)


        self.saver = tf.train.Saver()
Example 5
    def initialize(self, n_input_channels, n_output_channels, n_blocks,
                   initial_filters, dropout_value, lr, batch_size, image_width,
                   image_height, gpu_ids, gan, pool_size, n_blocks_discr):

        self.input_img = self.tensor(batch_size, n_input_channels,
                                     image_height, image_width)
        self.input_gt = self.tensor(batch_size, n_output_channels,
                                    image_height, image_width)

        self.generator = UNetV2(n_input_channels,
                                n_output_channels,
                                n_blocks,
                                initial_filters,
                                gpu_ids=gpu_ids)

        if gan:
            self.discriminator = ImageDiscriminatorConv(
                n_output_channels,
                initial_filters,
                dropout_value,
                gpu_ids=gpu_ids,
                n_blocks=n_blocks_discr)
            self.criterion_gan = GANLoss(tensor=self.tensor)
            self.optimizer_dis = torch.optim.Adam(
                self.discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
            self.fake_mask_pool = ImagePool(pool_size)

        if self.load_network:
            self._load_network(self.generator, 'Model', self.load_epoch)
            if gan:
                self._load_network(self.discriminator, 'Discriminator',
                                   self.load_epoch)

        self.criterion_seg = BCELoss2d()
        self.optimizer_seg = torch.optim.Adam(self.generator.parameters(),
                                              lr=lr,
                                              betas=(0.5, 0.999))

        print('---------- Network initialized -------------')
        self.print_network(self.generator)
        if gan:
            self.print_network(self.discriminator)
        print('-----------------------------------------------')
Example 6
    input, label = input.cuda(), label.cuda()
    noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

fixed_noise = Variable(fixed_noise)

# setup optimizer
optimizerD = optim.Adam(netD.parameters(),
                        lr=2 * opt.lr,
                        betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
if not opt.withoutE:
    optimizerE = optim.Adam(netE.parameters(),
                            lr=opt.lr,
                            betas=(opt.beta1, 0.999))

fake_pool = ImagePool(50)

schedulers = []
schedulers.append(lr_scheduler.StepLR(optimizerD, step_size=40, gamma=0.5))
schedulers.append(lr_scheduler.StepLR(optimizerG, step_size=40, gamma=0.5))
if not opt.withoutE:
    schedulers.append(lr_scheduler.StepLR(optimizerE, step_size=40, gamma=0.5))

for epoch in range(opt.niter):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu, _ = data
Example 7
	def train(self):
		self.build_network()	
		self.get_data_loaders()

		''' Initialize the image buffer '''
		self.image_pool = ImagePool(self.cfg.buffer_size)

		''' If no saved weights are found,
			pretrain the refiner / discriminator '''
		if not self.weights_loaded:
			self.pretrain_refiner()
			self.pretrain_discriminator()
		
		''' Check if step is valid '''
		assert self.current_step < self.cfg.train_steps, 'Target step is smaller than current step'


		for step in range((self.current_step + 1), self.cfg.train_steps):
			
			self.current_step = step
			
			''' Train Refiner ''' 
			self.D.eval()
			self.D.train_mode(False)
			
			self.R.train()
			self.R.train_mode(True)
			
			for idx in range(self.cfg.k_r):
				''' update refiner and return some important info for printing '''
				self.update_refiner(pretrain=False)

			

			''' Train Discriminator '''
			self.R.eval()
			self.R.train_mode(False)

			self.D.train()
			self.D.train_mode(True)

			for idx in range(self.cfg.k_d):
				''' update discriminator and return some important info for printing '''
				self.update_discriminator(pretrain=False)

				
			if step % self.cfg.print_interval == 0 and step > 0:
				self.print_refiner_info(step, pretrain=False)
				self.print_discriminator_info(step, pretrain=False)
			
			
			if self.cfg.log and (step % self.cfg.log_interval == 0 or step == 0):
				synthetic_images, _ = next(self.synthetic_data_iter)
				synthetic_images = synthetic_images.cuda(device=self.cfg.cuda_num)
				refined_images = self.R(synthetic_images)

				figure = np.stack([
					var_to_np(synthetic_images[:32]),
					var_to_np(refined_images[:32]),
					], axis=1)
				print('fig 0 shape {}'.format(np.shape(figure)))
				figure = figure.transpose((0, 1, 3, 4, 2))
				print('fig 5 shape {}'.format(np.shape(figure)))
				figure = figure.reshape((4, 8) + figure.shape[1:])
				print('fig 10 shape {}'.format(np.shape(figure)))
				figure = stack_images(figure)
				print('fig 15 shape {}'.format(np.shape(figure)))

				# figure = np.squeeze(figure, axis=2)  # not needed for 3-channel images
				figure = np.clip(figure * 255, 0, 255).astype('uint8')

				cv2.imwrite(self.cfg.checkpoint_path + 'images/' + 'eyes_' + str(step) + '_.jpg', figure)

			if step % self.cfg.save_interval == 0:
				print('Saving checkpoints, Step : {}'.format(step))	
				torch.save(self.R.state_dict(), os.path.join(self.cfg.checkpoint_path, self.cfg.R_path % step))
				torch.save(self.D.state_dict(), os.path.join(self.cfg.checkpoint_path, self.cfg.D_path % step))

				state = {
					'step': step,
					'optD' : self.discriminator_optimizer.state_dict(),
					'optR' : self.refiner_optimizer.state_dict()
				}
				
				torch.save(state, os.path.join(self.cfg.checkpoint_path, self.cfg.optimizer_path))
Example 8
def train(dataset, start_epoch, max_epochs, lr_d, lr_g, batch_size, lmda_cyc,
          lmda_idt, pool_size, context):
    mx.random.seed(int(time.time()))

    print("Loading dataset...", flush=True)
    training_set_a = load_dataset(dataset, "trainA")
    training_set_b = load_dataset(dataset, "trainB")

    gen_ab = ResnetGenerator()
    dis_b = PatchDiscriminator()
    gen_ba = ResnetGenerator()
    dis_a = PatchDiscriminator()
    bce_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
    l1_loss = mx.gluon.loss.L1Loss()

    gen_ab_params_file = "model/{}.gen_ab.params".format(dataset)
    dis_b_params_file = "model/{}.dis_b.params".format(dataset)
    gen_ab_state_file = "model/{}.gen_ab.state".format(dataset)
    dis_b_state_file = "model/{}.dis_b.state".format(dataset)
    gen_ba_params_file = "model/{}.gen_ba.params".format(dataset)
    dis_a_params_file = "model/{}.dis_a.params".format(dataset)
    gen_ba_state_file = "model/{}.gen_ba.state".format(dataset)
    dis_a_state_file = "model/{}.dis_a.state".format(dataset)

    if os.path.isfile(gen_ab_params_file):
        gen_ab.load_parameters(gen_ab_params_file, ctx=context)
    else:
        gen_ab.initialize(GANInitializer(), ctx=context)

    if os.path.isfile(dis_b_params_file):
        dis_b.load_parameters(dis_b_params_file, ctx=context)
    else:
        dis_b.initialize(GANInitializer(), ctx=context)

    if os.path.isfile(gen_ba_params_file):
        gen_ba.load_parameters(gen_ba_params_file, ctx=context)
    else:
        gen_ba.initialize(GANInitializer(), ctx=context)

    if os.path.isfile(dis_a_params_file):
        dis_a.load_parameters(dis_a_params_file, ctx=context)
    else:
        dis_a.initialize(GANInitializer(), ctx=context)

    print("Learning rate of discriminator:", lr_d, flush=True)
    print("Learning rate of generator:", lr_g, flush=True)
    trainer_gen_ab = mx.gluon.Trainer(gen_ab.collect_params(), "Nadam", {
        "learning_rate": lr_g,
        "beta1": 0.5
    })
    trainer_dis_b = mx.gluon.Trainer(dis_b.collect_params(), "Nadam", {
        "learning_rate": lr_d,
        "beta1": 0.5
    })
    trainer_gen_ba = mx.gluon.Trainer(gen_ba.collect_params(), "Nadam", {
        "learning_rate": lr_g,
        "beta1": 0.5
    })
    trainer_dis_a = mx.gluon.Trainer(dis_a.collect_params(), "Nadam", {
        "learning_rate": lr_d,
        "beta1": 0.5
    })

    if os.path.isfile(gen_ab_state_file):
        trainer_gen_ab.load_states(gen_ab_state_file)

    if os.path.isfile(dis_b_state_file):
        trainer_dis_b.load_states(dis_b_state_file)

    if os.path.isfile(gen_ba_state_file):
        trainer_gen_ba.load_states(gen_ba_state_file)

    if os.path.isfile(dis_a_state_file):
        trainer_dis_a.load_states(dis_a_state_file)

    fake_a_pool = ImagePool(pool_size)
    fake_b_pool = ImagePool(pool_size)

    print("Training...", flush=True)
    for epoch in range(start_epoch, max_epochs):
        ts = time.time()

        random.shuffle(training_set_a)
        random.shuffle(training_set_b)

        training_dis_a_L = 0.0
        training_dis_b_L = 0.0
        training_gen_L = 0.0
        training_batch = 0

        for real_a, real_b in get_batches(training_set_a,
                                          training_set_b,
                                          batch_size,
                                          ctx=context):
            training_batch += 1

            fake_a, _ = gen_ba(real_b)
            fake_b, _ = gen_ab(real_a)

            with mx.autograd.record():
                real_a_y, real_a_cam_y = dis_a(real_a)
                real_a_L = bce_loss(real_a_y,
                                    mx.nd.ones_like(real_a_y, ctx=context))
                real_a_cam_L = bce_loss(
                    real_a_cam_y, mx.nd.ones_like(real_a_cam_y, ctx=context))
                fake_a_y, fake_a_cam_y = dis_a(fake_a_pool.query(fake_a))
                fake_a_L = bce_loss(fake_a_y,
                                    mx.nd.zeros_like(fake_a_y, ctx=context))
                fake_a_cam_L = bce_loss(
                    fake_a_cam_y, mx.nd.zeros_like(fake_a_cam_y, ctx=context))
                L = real_a_L + real_a_cam_L + fake_a_L + fake_a_cam_L
                L.backward()
            trainer_dis_a.step(batch_size)
            dis_a_L = mx.nd.mean(L).asscalar()
            if dis_a_L != dis_a_L:  # NaN check: NaN never compares equal to itself
                raise ValueError("discriminator A loss is NaN")

            with mx.autograd.record():
                real_b_y, real_b_cam_y = dis_b(real_b)
                real_b_L = bce_loss(real_b_y,
                                    mx.nd.ones_like(real_b_y, ctx=context))
                real_b_cam_L = bce_loss(
                    real_b_cam_y, mx.nd.ones_like(real_b_cam_y, ctx=context))
                fake_b_y, fake_b_cam_y = dis_b(fake_b_pool.query(fake_b))
                fake_b_L = bce_loss(fake_b_y,
                                    mx.nd.zeros_like(fake_b_y, ctx=context))
                fake_b_cam_L = bce_loss(
                    fake_b_cam_y, mx.nd.zeros_like(fake_b_cam_y, ctx=context))
                L = real_b_L + real_b_cam_L + fake_b_L + fake_b_cam_L
                L.backward()
            trainer_dis_b.step(batch_size)
            dis_b_L = mx.nd.mean(L).asscalar()
            if dis_b_L != dis_b_L:  # NaN check
                raise ValueError("discriminator B loss is NaN")

            with mx.autograd.record():
                fake_a, gen_a_cam_y = gen_ba(real_b)
                fake_a_y, fake_a_cam_y = dis_a(fake_a)
                gan_a_L = bce_loss(fake_a_y,
                                   mx.nd.ones_like(fake_a_y, ctx=context))
                gan_a_cam_L = bce_loss(
                    fake_a_cam_y, mx.nd.ones_like(fake_a_cam_y, ctx=context))
                rec_b, _ = gen_ab(fake_a)
                cyc_b_L = l1_loss(rec_b, real_b)
                idt_a, idt_a_cam_y = gen_ba(real_a)
                idt_a_L = l1_loss(idt_a, real_a)
                gen_a_cam_L = bce_loss(
                    gen_a_cam_y, mx.nd.ones_like(
                        gen_a_cam_y, ctx=context)) + bce_loss(
                            idt_a_cam_y,
                            mx.nd.zeros_like(idt_a_cam_y, ctx=context))
                gen_ba_L = gan_a_L + gan_a_cam_L + cyc_b_L * lmda_cyc + idt_a_L * lmda_cyc * lmda_idt + gen_a_cam_L
                fake_b, gen_b_cam_y = gen_ab(real_a)
                fake_b_y, fake_b_cam_y = dis_b(fake_b)
                gan_b_L = bce_loss(fake_b_y,
                                   mx.nd.ones_like(fake_b_y, ctx=context))
                gan_b_cam_L = bce_loss(
                    fake_b_cam_y, mx.nd.ones_like(fake_b_cam_y, ctx=context))
                rec_a, _ = gen_ba(fake_b)
                cyc_a_L = l1_loss(rec_a, real_a)
                idt_b, idt_b_cam_y = gen_ab(real_b)
                idt_b_L = l1_loss(idt_b, real_b)
                gen_b_cam_L = bce_loss(
                    gen_b_cam_y, mx.nd.ones_like(
                        gen_b_cam_y, ctx=context)) + bce_loss(
                            idt_b_cam_y,
                            mx.nd.zeros_like(idt_b_cam_y, ctx=context))
                gen_ab_L = gan_b_L + gan_b_cam_L + cyc_a_L * lmda_cyc + idt_b_L * lmda_cyc * lmda_idt + gen_b_cam_L
                L = gen_ba_L + gen_ab_L
                L.backward()
            trainer_gen_ba.step(batch_size)
            trainer_gen_ab.step(batch_size)
            gen_L = mx.nd.mean(L).asscalar()
            if gen_L != gen_L:  # NaN check
                raise ValueError("generator loss is NaN")

            training_dis_a_L += dis_a_L
            training_dis_b_L += dis_b_L
            training_gen_L += gen_L
            print(
                "[Epoch %d  Batch %d]  dis_a_loss %.10f  dis_b_loss %.10f  gen_loss %.10f  elapsed %.2fs"
                % (epoch, training_batch, dis_a_L, dis_b_L, gen_L,
                   time.time() - ts),
                flush=True)

        print(
            "[Epoch %d]  training_dis_a_loss %.10f  training_dis_b_loss %.10f  training_gen_loss %.10f  duration %.2fs"
            % (epoch + 1, training_dis_a_L / training_batch,
               training_dis_b_L / training_batch,
               training_gen_L / training_batch, time.time() - ts),
            flush=True)

        gen_ab.save_parameters(gen_ab_params_file)
        gen_ba.save_parameters(gen_ba_params_file)
        dis_a.save_parameters(dis_a_params_file)
        dis_b.save_parameters(dis_b_params_file)
        trainer_gen_ab.save_states(gen_ab_state_file)
        trainer_gen_ba.save_states(gen_ba_state_file)
        trainer_dis_a.save_states(dis_a_state_file)
        trainer_dis_b.save_states(dis_b_state_file)
Example 9
def train():
    FLAGS = tf.flags.FLAGS
    graph = tf.Graph()
    output_model_dir = get_output_model_dir()

    # Sketch dataset handler
    data_handler_S = SketchDataHandler(get_data_dir(), FLAGS.S,
                                       FLAGS.batch_size, FLAGS.target_size)

    # Pen dataset handler
    if FLAGS.data_type == 'bezier':
        data_handler_P = BezierDataHandler(FLAGS.batch_size, FLAGS.target_size)
    elif FLAGS.data_type == 'line':
        data_handler_P = LineDataHandler(FLAGS.batch_size, FLAGS.target_size)
    else:
        print("no match dataset for %s" % FLAGS.data_type)
        exit(-1)

    # Model type
    if FLAGS.model_type == 'cycle_gan':
        model = models.cycle_gan
    elif FLAGS.model_type == 'our_cycle_gan':
        model = models.our_cycle_gan
    else:
        print("no match model for %s" % FLAGS.model_type)
        exit(-1)

    fake_pen_pool = ImagePool()
    fake_sketch_pool = ImagePool()

    try:
        with graph.as_default():
            input_S = tf.placeholder(tf.float32,
                                     shape=data_handler_S.get_batch_shape(),
                                     name='input_S')
            input_P = tf.placeholder(tf.float32,
                                     shape=data_handler_P.get_batch_shape(),
                                     name='input_P')

            input_FP_pool = tf.placeholder(tf.float32,
                                           shape=(FLAGS.batch_size,
                                                  FLAGS.target_size,
                                                  FLAGS.target_size, 1),
                                           name='input_FP_pool')
            input_FS_pool = tf.placeholder(tf.float32,
                                           shape=(FLAGS.batch_size,
                                                  FLAGS.target_size,
                                                  FLAGS.target_size, 1),
                                           name='input_FS_pool')

            # Model here
            [train_op, losses,
             predictions] = model.build_model(input_S, input_P, input_FS_pool,
                                              input_FP_pool)

            # choose summary
            summary_list = [
                tf.summary.image("S/input_S", input_S),
                tf.summary.image("S/P_from_S",
                                 predictions['P_from_S']),  # output
                tf.summary.image("S/S_cycled", predictions['S_cycled']),
                tf.summary.image("S/noisy_S", predictions['noisy_S']),
                tf.summary.image("P/input_P", input_P),
                tf.summary.image("P/noisy_P", predictions['noisy_P']),
                tf.summary.image("P/S_from_P", predictions['S_from_P']),
                tf.summary.image("P/P_cycled", predictions['P_cycled']),
                tf.summary.image("Debug/P_from_S", predictions['P_from_S']),
                tf.summary.image("Debug/score_fakeP", predictions['fake_DP']),
                tf.summary.image("Debug/score_realP", predictions['real_DP']),
                tf.summary.scalar("loss/loss_cycle_S", losses['loss_cycle_S']),
                tf.summary.scalar("loss/loss_cycle_P", losses['loss_cycle_P']),
                tf.summary.scalar("loss/loss_cycle", losses['loss_cycle']),
                tf.summary.scalar("loss/loss_DS", losses['loss_DS']),
                tf.summary.scalar("loss/loss_DP", losses['loss_DP']),
                tf.summary.scalar("loss/loss_F", losses['loss_F']),
                tf.summary.scalar("loss/loss_G", losses['loss_G']),
            ]
            if model == models.our_cycle_gan:
                summary_list.extend([
                    tf.summary.image("S/extra", predictions['extra']),
                ])

            summary_op = tf.summary.merge(summary_list)
            summary_writer = tf.summary.FileWriter(output_model_dir)
            model_saver = tf.train.Saver(max_to_keep=1000)

        with tf.Session(graph=graph) as sess:
            if FLAGS.restore_model_dir is not None:
                checkpoint = tf.train.get_checkpoint_state(
                    FLAGS.restore_model_dir)
                model_saver.restore(
                    sess, tf.train.latest_checkpoint(FLAGS.restore_model_dir))
                meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
                step = int(meta_graph_path.split("-")[1].split(".")[0])
            else:
                sess.run(tf.global_variables_initializer())
                step = 0

            try:
                while True:  # We manually shut down
                    # First, generate fake image and update the fake pool. Then train.
                    tmp_S = data_handler_S.next()
                    tmp_P = data_handler_P.next()
                    FP, FS = sess.run(
                        [predictions['P_from_S'], predictions['S_from_P']],
                        feed_dict={
                            input_S: tmp_S,
                            input_P: tmp_P,
                        })

                    # Now train using data and the fake pools.
                    fetch_dict = {
                        "train_op": train_op,
                        "loss": losses['loss'],
                        "P_from_S": predictions['P_from_S'],
                        "S_from_P": predictions['S_from_P'],
                    }

                    if step % FLAGS.log_step == 0:
                        fetch_dict.update({
                            "summary": summary_op,
                        })

                    result = sess.run(fetch_dict,
                                      feed_dict={
                                          input_S: tmp_S,
                                          input_P: tmp_P,
                                          input_FS_pool: fake_sketch_pool(FS),
                                          input_FP_pool: fake_pen_pool(FP),
                                      })

                    if step % FLAGS.log_step == 0:
                        summary_writer.add_summary(result["summary"], step)
                        summary_writer.flush()

                    if step % FLAGS.save_step == 0:
                        save_path = model_saver.save(sess,
                                                     os.path.join(
                                                         output_model_dir,
                                                         "model.ckpt"),
                                                     global_step=step)

                    print("Iter %d, loss %f" % (step, result["loss"]))
                    step += 1

            finally:
                save_path = model_saver.save(sess,
                                             os.path.join(
                                                 output_model_dir,
                                                 "model.ckpt"),
                                             global_step=step)

    finally:
        data_handler_S.kill()
        data_handler_P.kill()
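Unlike the other examples, this one feeds the pools by calling them directly (fake_sketch_pool(FS), fake_pen_pool(FP)) on the NumPy batches coming out of sess.run, rather than going through a query() method, so the ImagePool it imports presumably defines __call__. A pool along those lines would fit that usage; this is an assumed sketch, not the implementation behind the snippet:

import random


class ImagePool:
    """Callable history buffer assumed to back the usage above: it stores
    whole batches and, once full, returns either the new batch or a
    previously stored one with equal probability."""

    def __init__(self, max_size=50):
        self.max_size = max_size
        self.images = []

    def __call__(self, image):
        if self.max_size <= 0:  # buffering disabled
            return image
        if len(self.images) < self.max_size:
            self.images.append(image)  # fill the buffer first
            return image
        if random.random() > 0.5:
            idx = random.randrange(self.max_size)
            old, self.images[idx] = self.images[idx], image  # swap new in, old out
            return old
        return image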
Example 10
    def train(self, epochs, batch_size=1, sample_interval=50, pool_size=50):

        start_time = datetime.datetime.now()

        # Adversarial loss ground truths
        valid = np.ones((batch_size, ) + self.disc_patch)
        fake = np.zeros((batch_size, ) + self.disc_patch)

        fake_a_pool = ImagePool(pool_size)
        fake_b_pool = ImagePool(pool_size)

        tensorboard = TensorBoard(batch_size=batch_size, write_grads=True)
        tensorboard.set_model(self.combined)

        def named_logs(model, logs):
            result = {}
            for l in zip(model.metrics_names, logs):
                result[l[0]] = l[1]
            return result

        for epoch in range(epochs):
            for batch_i, (imgs_A, imgs_B) in enumerate(
                    self.data_loader.load_batch(batch_size)):

                # ----------------------
                #  Train Discriminators
                # ----------------------

                # Translate images to opposite domain
                fake_B = fake_b_pool.query(self.g_AB.predict(imgs_A))
                fake_A = fake_a_pool.query(self.g_BA.predict(imgs_B))

                # Train the discriminators (original images = real / translated = Fake)
                dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
                dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
                dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

                dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
                dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
                dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

                # Total discriminator loss
                d_loss = 0.5 * np.add(dA_loss, dB_loss)

                # ------------------
                #  Train Generators
                # ------------------

                # Train the generators
                g_loss = self.combined.train_on_batch(
                    [imgs_A, imgs_B],
                    [valid, valid, imgs_A, imgs_B, imgs_A, imgs_B])

                elapsed_time = datetime.datetime.now() - start_time

                # K.clear_session()

                # Plot the progress
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s " \
                    % (epoch, epochs,
                       batch_i, self.data_loader.n_batches,
                       d_loss[0], 100 * d_loss[1],
                       g_loss[0],
                       np.mean(g_loss[1:3]),
                       np.mean(g_loss[3:5]),
                       np.mean(g_loss[5:6]),
                       elapsed_time))

                # If at save interval => save generated image samples
                if batch_i % sample_interval == 0:
                    self.sample_images(epoch, batch_i)

            if epoch % 1 == 0:
                self.combined.save_weights(
                    f"saved_model/{self.dataset_name}/{epoch}.h5")
Example 11
    def initialize(self, opt, tensor):
        self.criterionGAN = GANLoss(use_l1=False, tensor=tensor)
        self.fake_AB_pool = ImagePool(opt.pool_size)
Example 12
    def __init__(self, opt, gpu_ids=[0], continue_run=None):
        self.opt = opt
        self.kt = 0
        self.lamk = 0.001
        self.lambdaImg = 100
        self.lambdaGan = 1.0
        self.model_names = ['netD', 'netG']
        self.gpu_ids = gpu_ids

        if not continue_run:
            expname = '-'.join([
                'b_' + str(self.opt.batchSize), 'ngf_' + str(self.opt.ngf),
                'ndf_' + str(self.opt.ndf), 'gm_' + str(self.opt.gamma)
            ])
            self.rundir = (self.opt.rundir + '/pix2pixBEGAN-' +
                           datetime.now().strftime('%B%d-%H-%M-%S') +
                           expname + self.opt.comment)
            if not os.path.isdir(self.rundir):
                os.mkdir(self.rundir)
            with open(self.rundir + '/options.pkl', 'wb') as file:
                pickle.dump(opt, file)
        else:
            self.rundir = continue_run
            if os.path.isfile(self.rundir + '/options.pkl'):
                with open(self.rundir + '/options.pkl', 'rb') as file:
                    tmp = opt.rundir
                    tmp_lr = opt.lr
                    self.opt = pickle.load(file)
                    self.opt.rundir = tmp
                    self.opt.lr = tmp_lr

        self.netG = UnetGenerator(input_nc=3,
                                  output_nc=3,
                                  num_downs=7,
                                  ngf=self.opt.ngf,
                                  norm_layer=nn.BatchNorm2d,
                                  use_dropout=True)
        self.netD = UnetDescriminator(input_nc=3,
                                      output_nc=3,
                                      num_downs=7,
                                      ngf=self.opt.ndf,
                                      norm_layer=nn.BatchNorm2d,
                                      use_dropout=True)

        # Decide which device we want to run on
        self.device = torch.device("cuda:0" if (
            torch.cuda.is_available()) else "cpu")

        init_net(self.netG, 'normal', 0.002, [0])
        init_net(self.netD, 'normal', 0.002, [0])

        self.netG.to(self.device)
        self.netD.to(self.device)
        self.imagePool = ImagePool(self.opt.pool_size)

        self.criterionL1 = torch.nn.L1Loss()

        if continue_run:
            self.load_networks('latest')

        self.writer = Logger(self.rundir)
        self.start_step, self.opt.lr = self.writer.get_latest(
            'misc/lr', self.opt.lr)

        # initialize optimizers
        self.optimG = torch.optim.Adam(self.netG.parameters(),
                                       lr=self.opt.lr,
                                       betas=(self.opt.beta1, 0.999))
        self.optimD = torch.optim.Adam(self.netD.parameters(),
                                       lr=self.opt.lr,
                                       betas=(self.opt.beta1, 0.999))
Example 13
    def __init__(
        self, name="experiment", phase="train", which_epoch="latest",
        batch_size=1, image_size=128, map_nc=1, input_nc=3, output_nc=3,
        num_downs=7, ngf=64, ndf=64, norm_layer="batch", pool_size=50,
        lr=0.0002, beta1=0.5, lambda_D=0.5, lambda_MSE=10,
        lambda_P=5.0, use_dropout=True, gpu_ids=[], n_layers=3,
        use_sigmoid=False, use_lsgan=True, upsampling="nearest",
        continue_train=False, checkpoints_dir="checkpoints/"
    ):
        # Define input data that will be consumed by networks
        self.input_A = torch.FloatTensor(
            batch_size, 3, image_size, image_size
        )
        self.input_map = torch.FloatTensor(
            batch_size, map_nc, image_size, image_size
        )
        norm_layer = nn.BatchNorm2d \
            if norm_layer == "batch" else nn.InstanceNorm2d

        # Define netD and netG
        self.netG = networks.UnetGenerator(
            input_nc=input_nc, output_nc=map_nc,
            num_downs=num_downs, ngf=ngf,
            use_dropout=use_dropout, gpu_ids=gpu_ids, norm_layer=norm_layer,
            upsampling_layer=upsampling
        )
        self.netD = networks.NLayerDiscriminator(
            input_nc=input_nc + map_nc, ndf=ndf,
            n_layers=n_layers, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids
        )

        # Transfer data to GPU
        if len(gpu_ids) > 0:
            self.input_A = self.input_A.cuda()
            self.input_map = self.input_map.cuda()
            self.netD.cuda()
            self.netG.cuda()

        # Initialize parameters of netD and netG
        self.netG.apply(networks.weights_init)
        self.netD.apply(networks.weights_init)

        # Load trained netD and netG
        if phase == "test" or continue_train:
            netG_checkpoint_file = os.path.join(
                checkpoints_dir, name, "netG_{}.pth".format(which_epoch)
            )
            self.netG.load_state_dict(
                torch.load(netG_checkpoint_file)
            )
            print("Restoring netG from {}".format(netG_checkpoint_file))

        if continue_train:
            netD_checkpoint_file = os.path.join(
                checkpoints_dir, name, "netD_{}.pth".format(which_epoch)
            )
            self.netD.load_state_dict(
                torch.load(netD_checkpoint_file)
            )
            print("Restoring netD from {}".format(netD_checkpoint_file))

        self.name = name
        self.gpu_ids = gpu_ids
        self.checkpoints_dir = checkpoints_dir

        # Criterions
        if phase == "train":
            self.count = 0
            self.lr = lr
            self.lambda_D = lambda_D
            self.lambda_MSE = lambda_MSE

            self.image_pool = ImagePool(pool_size)
            self.criterionGAN = networks.GANLoss(use_lsgan=use_lsgan)
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionMSE = torch.nn.MSELoss()  # Landmark loss

            self.optimizer_G = torch.optim.Adam(
                self.netG.parameters(), lr=self.lr, betas=(beta1, 0.999)
            )
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(), lr=self.lr, betas=(beta1, 0.999)
            )

            print('---------- Networks initialized -------------')
            networks.print_network(self.netG)
            networks.print_network(self.netD)
            print('-----------------------------------------------')