Example #1
    def __init__(self, input_nc=3, output_nc=3, gpu_id=None):
        self.device = torch.device(
            f"cuda:{gpu_id}" if gpu_id is not None else 'cpu')
        print(f"Using device {self.device}")

        # Hyperparameters
        self.lambda_idt = 0.5
        self.lambda_A = 10.0
        self.lambda_B = 10.0

        # Define generator networks
        self.netG_A = networks.define_netG(input_nc,
                                           output_nc,
                                           ngf=64,
                                           n_blocks=9,
                                           device=self.device)
        self.netG_B = networks.define_netG(output_nc,
                                           input_nc,
                                           ngf=64,
                                           n_blocks=9,
                                           device=self.device)

        # Define discriminator networks
        self.netD_A = networks.define_netD(output_nc,
                                           ndf=64,
                                           n_layers=3,
                                           device=self.device)
        self.netD_B = networks.define_netD(input_nc,
                                           ndf=64,
                                           n_layers=3,
                                           device=self.device)

        # Define image pools
        self.fake_A_pool = utils.ImagePool(pool_size=50)
        self.fake_B_pool = utils.ImagePool(pool_size=50)

        # Define loss functions
        self.criterionGAN = networks.GANLoss().to(self.device)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()

        # Define optimizers
        netG_params = itertools.chain(self.netG_A.parameters(),
                                      self.netG_B.parameters())
        netD_params = itertools.chain(self.netD_A.parameters(),
                                      self.netD_B.parameters())
        self.optimizer_G = torch.optim.Adam(netG_params,
                                            lr=0.0002,
                                            betas=(0.5, 0.999))
        self.optimizer_D = torch.optim.Adam(netD_params,
                                            lr=0.0002,
                                            betas=(0.5, 0.999))

        # Learning rate schedulers
        self.scheduler_G = utils.get_lr_scheduler(self.optimizer_G)
        self.scheduler_D = utils.get_lr_scheduler(self.optimizer_D)
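
Every example in this collection relies on a utils.ImagePool helper that none of the snippets define. Below is a minimal sketch, assuming the image-buffer strategy from the CycleGAN paper (Zhu et al., 2017) and the query()-style interface seen in Examples #8 and #9: query() passes the incoming fakes through half of the time and returns previously stored fakes the other half, so the discriminator trains against a history of generated images. Names and defaults here are illustrative, not taken from any of the repositories above.

import random
import torch

class ImagePool:
    """History buffer of generated images (sketch)."""

    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:          # pool disabled: pass images through
            return images
        out = []
        for img in images:
            img = img.unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(img)  # fill the pool first
                out.append(img)
            elif random.random() > 0.5:  # swap with a stored image
                idx = random.randrange(self.pool_size)
                out.append(self.images[idx].clone())
                self.images[idx] = img
            else:                        # pass the current image through
                out.append(img)
        return torch.cat(out, dim=0)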
Example #2
# Test data
# test_data = DatasetFromFolder(data_dir, subfolder='img_test', transform=transform, resize_scale=params.input_size,
#                               crop_size=params.crop_size, fliplr=params.fliplr, yuv=True)
test_data = DatasetFromFolder2(data_dir,
                               subfolder='lfw_test.txt',
                               transform=transform,
                               resize_scale=params.input_size)
test_data_loader = torch.utils.data.DataLoader(dataset=test_data,
                                               batch_size=1,
                                               shuffle=False)

# image pool
num_pool = 50
# cover_pool = utils.ImagePool(num_pool)
stego_pool = utils.ImagePool(num_pool)
secret_pool = utils.ImagePool(num_pool)

# optimizers
encoder_optimizer = torch.optim.Adam(encoder.parameters(),
                                     lr=params.lr,
                                     betas=(params.beta1, params.beta2),
                                     weight_decay=params.weight_decay)
decoder_optimizer = torch.optim.Adam(decoder.parameters(),
                                     lr=params.lr,
                                     betas=(params.beta1, params.beta2),
                                     weight_decay=params.weight_decay)
# encoder_optimizer = torch.optim.SGD(encoder.parameters(), lr=params.lr, weight_decay=1e-8)
# decoder_optimizer = torch.optim.SGD(decoder.parameters(), lr=params.lr, weight_decay=1e-8)
discriminator_optimizer = torch.optim.SGD(discriminator.parameters(),
                                          lr=params.lr / 3)
# (any further SGD arguments were truncated in the source)
Example #3
    def __init__(self, args, dataloaders):

        self.dataloaders = dataloaders
        self.net_D1 = cycnet.define_D(input_nc=6,
                                      ndf=64,
                                      netD='n_layers',
                                      n_layers_D=2).to(device)
        self.net_D2 = cycnet.define_D(input_nc=6,
                                      ndf=64,
                                      netD='n_layers',
                                      n_layers_D=2).to(device)
        self.net_D3 = cycnet.define_D(input_nc=6,
                                      ndf=64,
                                      netD='n_layers',
                                      n_layers_D=3).to(device)
        self.net_G = cycnet.define_G(input_nc=3,
                                     output_nc=6,
                                     ngf=args.ngf,
                                     netG=args.net_G,
                                     use_dropout=False,
                                     norm='none').to(device)
        # M.Amintoosi norm='instance'
        # self.net_G = cycnet.define_G(
        #     input_nc=3, output_nc=6, ngf=args.ngf, netG=args.net_G, use_dropout=False, norm='instance').to(device)

        # Learning rate and Beta1 for Adam optimizers
        self.lr = args.lr

        # define optimizers
        self.optimizer_G = optim.Adam(self.net_G.parameters(),
                                      lr=self.lr,
                                      betas=(0.5, 0.999))
        self.optimizer_D1 = optim.Adam(self.net_D1.parameters(),
                                       lr=self.lr,
                                       betas=(0.5, 0.999))
        self.optimizer_D2 = optim.Adam(self.net_D2.parameters(),
                                       lr=self.lr,
                                       betas=(0.5, 0.999))
        self.optimizer_D3 = optim.Adam(self.net_D3.parameters(),
                                       lr=self.lr,
                                       betas=(0.5, 0.999))

        # define lr schedulers
        self.exp_lr_scheduler_G = lr_scheduler.StepLR(
            self.optimizer_G,
            step_size=args.exp_lr_scheduler_stepsize,
            gamma=0.1)
        self.exp_lr_scheduler_D1 = lr_scheduler.StepLR(
            self.optimizer_D1,
            step_size=args.exp_lr_scheduler_stepsize,
            gamma=0.1)
        self.exp_lr_scheduler_D2 = lr_scheduler.StepLR(
            self.optimizer_D2,
            step_size=args.exp_lr_scheduler_stepsize,
            gamma=0.1)
        self.exp_lr_scheduler_D3 = lr_scheduler.StepLR(
            self.optimizer_D3,
            step_size=args.exp_lr_scheduler_stepsize,
            gamma=0.1)

        # coefficient to balance loss functions
        self.lambda_L1 = args.lambda_L1
        self.lambda_adv = args.lambda_adv

        # based on which metric to update the "best" ckpt
        self.metric = args.metric

        # define some other vars to record the training states
        self.running_acc = []
        self.epoch_acc = 0
        if 'mse' in self.metric:
            self.best_val_acc = 1e9  # for mse, rmse, a lower score is better
        else:
            self.best_val_acc = 0.0  # for others (ssim, psnr), a higher score is better
        self.best_epoch_id = 0
        self.epoch_to_start = 0
        self.max_num_epochs = args.max_num_epochs
        self.G_pred1 = None
        self.G_pred2 = None
        self.batch = None
        self.G_loss = None
        self.D_loss = None
        self.is_training = False
        self.batch_id = 0
        self.epoch_id = 0
        self.checkpoint_dir = args.checkpoint_dir
        self.vis_dir = args.vis_dir
        self.D1_fake_pool = utils.ImagePool(pool_size=50)
        self.D2_fake_pool = utils.ImagePool(pool_size=50)
        self.D3_fake_pool = utils.ImagePool(pool_size=50)

        # define the loss functions
        if args.pixel_loss == 'minimum_pixel_loss':
            self._pxl_loss = loss.MinimumPixelLoss(
                opt=1)  # 1 for L1 and 2 for L2
        elif args.pixel_loss == 'pixel_loss':
            self._pxl_loss = loss.PixelLoss(opt=1)  # 1 for L1 and 2 for L2
        else:
            raise NotImplementedError(
                'pixel loss function [%s] is not implemented' % args.pixel_loss)
        self._gan_loss = loss.GANLoss(gan_mode='vanilla').to(device)
        self._exclusion_loss = loss.ExclusionLoss()
        self._kurtosis_loss = loss.KurtosisLoss()
        # enable some losses?
        self.with_d1d2 = args.enable_d1d2
        self.with_d3 = args.enable_d3
        self.with_exclusion_loss = args.enable_exclusion_loss
        self.with_kurtosis_loss = args.enable_kurtosis_loss

        # m-th epoch to activate adversarial training
        self.m_epoch_activate_adv = int(self.max_num_epochs / 20) + 1

        # output auto-enhancement?
        self.output_auto_enhance = args.output_auto_enhance

        # use synfake to train D?
        self.synfake = args.enable_synfake

        # check and create model dir
        if not os.path.exists(self.checkpoint_dir):
            os.mkdir(self.checkpoint_dir)
        if not os.path.exists(self.vis_dir):
            os.mkdir(self.vis_dir)

        # visualize model
        if args.print_models:
            self._visualize_models()
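
The metric handling above initializes best_val_acc differently for error metrics and quality metrics. A small hypothetical helper (not part of the snippet) makes the implied checkpoint rule explicit:

def is_better(metric: str, score: float, best: float) -> bool:
    """True if `score` should replace `best` under this metric."""
    if 'mse' in metric:   # mse / rmse: lower is better
        return score < best
    return score > best   # e.g. ssim / psnr: higher is better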
Example #4
def train(network_gen: nn.Module,
          network_dis: nn.Module,
          dataloader,
          checkpoint_path,
          weight_gan=0.01,
          weight_l1=1.0,
          weight_fm=10.0,
          weight_vgg=10.0,
          device=torch.device('cuda:0'),
          n_critic=3,
          n_gen=1,
          fake_pool_size=256,
          lr_gen=1e-4,
          lr_dis=5e-4,
          updates_per_epoch=10000,
          record_freq=1000,
          total_updates=1000000,
          gradient_accumulate=4,
          enable_fp16=False,
          resume=False):
    if enable_fp16:
        print(' -- FP16 AMP enabled')
    print(' -- Initializing losses')
    loss_gan = ganloss.GANLossSoftLS(device)
    loss_vgg = vgg_loss.VGG19LossWithStyle().to(device)

    opt_gen = optim.Adam(network_gen.parameters(),
                         lr=lr_gen,
                         betas=(0.5, 0.99),
                         weight_decay=1e-6)
    opt_dis = optim.Adam(network_dis.parameters(),
                         lr=lr_dis,
                         betas=(0.5, 0.99),
                         weight_decay=1e-6)

    sch_gen = optim.lr_scheduler.ReduceLROnPlateau(opt_gen,
                                                   'min',
                                                   factor=0.5,
                                                   patience=4,
                                                   verbose=True,
                                                   min_lr=1e-6)
    sch_dis = optim.lr_scheduler.ReduceLROnPlateau(opt_dis,
                                                   'min',
                                                   factor=0.5,
                                                   patience=4,
                                                   verbose=True,
                                                   min_lr=1e-6)
    sch_meter = utils.AvgMeter()

    scaler_gen = amp.GradScaler(enabled=enable_fp16)
    scaler_dis = amp.GradScaler(enabled=enable_fp16)

    loss_dis_real_meter = utils.AvgMeter()
    loss_dis_fake_meter = utils.AvgMeter()
    loss_dis_meter = utils.AvgMeter()

    loss_gen_l1_meter = utils.AvgMeter()
    loss_gen_l1_coarse_meter = utils.AvgMeter()
    loss_gen_vgg_meter = utils.AvgMeter()
    loss_gen_vgg_coarse_meter = utils.AvgMeter()
    loss_gen_gan_meter = utils.AvgMeter()
    loss_gen_fm_meter = utils.AvgMeter()
    loss_gen_meter = utils.AvgMeter()

    writer = SummaryWriter(os.path.join(checkpoint_path, 'tb_summary'))
    os.makedirs(os.path.join(checkpoint_path, 'checkpoints'), exist_ok=True)

    fakepool = utils.ImagePool(fake_pool_size, device)

    counter_start = 0
    if resume:
        checkpoints = os.listdir(os.path.join(checkpoint_path, 'checkpoints'))
        last_checkpoint = sorted(
            checkpoints, key=lambda item: (len(item), item)
        )[-1] if 'latest.ckpt' not in checkpoints else 'latest.ckpt'
        print(f' -- Loading checkpoint {last_checkpoint}')
        ckpt = torch.load(
            os.path.join(checkpoint_path, 'checkpoints', last_checkpoint))
        network_gen.load_state_dict(ckpt['gen'])
        network_dis.load_state_dict(ckpt['dis'])
        opt_gen.load_state_dict(ckpt['gen_opt'])
        opt_dis.load_state_dict(ckpt['dis_opt'])
        counter_start = ckpt['counter'] + 1
        print(f' -- Resume training from update {counter_start}')
    else:
        print(' -- Start training from scratch')

    dataloader = iter(dataloader)

    print(' -- Training start')
    try:
        for counter in tqdm(range(counter_start, total_updates)):
            dataloader_meter = utils.AvgMeter()
            # train discriminator
            for critic in range(n_critic):
                opt_dis.zero_grad()
                for _ in range(gradient_accumulate):
                    start_time = time.time()
                    real_img, mask = next(dataloader)
                    end_time = time.time()
                    dataloader_meter(end_time - start_time)
                    real_img, mask = real_img.to(device), mask.to(device)
                    real_img_masked = mask_image(real_img, mask)
                    if np.random.randint(0,
                                         2) == 0 or not fakepool.available():
                        with torch.no_grad(), amp.autocast(
                                enabled=enable_fp16):
                            fake_img = network_gen(real_img_masked, mask)
                        fakepool.put(fake_img)
                    else:
                        fake_img = fakepool.sample()
                    with amp.autocast(enabled=enable_fp16):
                        real_logits, _ = network_dis(real_img)
                        fake_logits, _ = network_dis(fake_img)
                        loss_dis_real = loss_gan(real_logits, 'real', None)
                        mask_inv = 1 - F.interpolate(
                            mask,
                            size=(real_logits.shape[2], real_logits.shape[3]),
                            mode='bicubic',
                            align_corners=False)
                        loss_dis_fake = loss_gan(fake_logits, 'fake', mask_inv)
                        loss_dis = 0.5 * (loss_dis_real + loss_dis_fake)
                    if torch.isnan(loss_dis) or torch.isinf(loss_dis):
                        raise RuntimeError('discriminator loss is NaN/Inf')

                    scaler_dis.scale(loss_dis /
                                     float(gradient_accumulate)).backward()

                    loss_dis_real_meter(loss_dis_real.item())
                    loss_dis_fake_meter(loss_dis_fake.item())
                    loss_dis_meter(loss_dis.item())
                scaler_dis.unscale_(opt_dis)
                #torch.nn.utils.clip_grad_norm_(network_dis.parameters(), 0.1)
                scaler_dis.step(opt_dis)
                scaler_dis.update()
            # train generator
            for gen in range(n_gen):
                opt_gen.zero_grad()
                for _ in range(gradient_accumulate):
                    start_time = time.time()
                    real_img, mask = next(dataloader)
                    end_time = time.time()
                    dataloader_meter(end_time - start_time)
                    real_img, mask = real_img.to(device), mask.to(device)
                    real_img_masked = mask_image(real_img, mask)
                    with amp.autocast(enabled=enable_fp16):
                        inpainted_result = network_gen(real_img_masked, mask)
                        #inpainted_result_coarse, inpainted_result = network_gen(real_img_masked, mask)
                        loss_gen_l1 = F.l1_loss(inpainted_result, real_img)
                        loss_vgg_combined = loss_vgg(inpainted_result,
                                                     real_img)
                        generator_logits, dis_features_fake = network_dis(
                            inpainted_result)
                        with torch.no_grad():
                            _, dis_features_real = network_dis(real_img)
                        loss_fm = 0
                        for (fm_fake, fm_real) in zip(dis_features_fake,
                                                      dis_features_real):
                            loss_fm += F.mse_loss(fm_fake, fm_real)
                        loss_fm = loss_fm / len(dis_features_real)
                        loss_gen_gan = loss_gan(generator_logits, 'generator',
                                                None)
                        loss_gen = weight_l1 * (
                            loss_gen_l1) + weight_fm * loss_fm + weight_vgg * (
                                loss_vgg_combined) + weight_gan * loss_gen_gan
                    if torch.isnan(loss_gen) or torch.isinf(loss_gen):
                        raise RuntimeError('generator loss is NaN/Inf')

                    scaler_gen.scale(loss_gen /
                                     float(gradient_accumulate)).backward()

                    loss_gen_meter(loss_gen.item())
                    loss_gen_l1_meter(loss_gen_l1.item())
                    sch_meter(loss_gen_l1.item()
                              )  # use L1 loss as lr scheduler metric
                    loss_gen_vgg_meter(loss_vgg_combined.item())
                    loss_gen_gan_meter(loss_gen_gan.item())
                    loss_gen_fm_meter(loss_fm.item())
                scaler_gen.unscale_(opt_gen)
                #torch.nn.utils.clip_grad_norm_(network_gen.parameters(), 0.1)
                scaler_gen.step(opt_gen)
                scaler_gen.update()
            if counter % record_freq == 0:
                tqdm.write(f' -- Record at update {counter}')
                writer.add_scalar('discriminator/all',
                                  loss_dis_meter(reset=True), counter)
                writer.add_scalar('discriminator/real',
                                  loss_dis_real_meter(reset=True), counter)
                writer.add_scalar('discriminator/fake',
                                  loss_dis_fake_meter(reset=True), counter)
                writer.add_scalar('generator/all', loss_gen_meter(reset=True),
                                  counter)
                writer.add_scalar('generator/l1',
                                  loss_gen_l1_meter(reset=True), counter)
                writer.add_scalar('generator/vgg',
                                  loss_gen_vgg_meter(reset=True), counter)
                writer.add_scalar('generator/gan',
                                  loss_gen_gan_meter(reset=True), counter)
                writer.add_scalar('generator/fm',
                                  loss_gen_fm_meter(reset=True), counter)
                writer.add_image('original/image',
                                 img_unscale(real_img),
                                 counter,
                                 dataformats='NCHW')
                writer.add_image('original/mask',
                                 mask,
                                 counter,
                                 dataformats='NCHW')
                writer.add_image('original/masked',
                                 img_unscale(real_img_masked),
                                 counter,
                                 dataformats='NCHW')
                writer.add_image('inpainted/refined',
                                 img_unscale(inpainted_result),
                                 counter,
                                 dataformats='NCHW')
                torch.save(
                    {
                        'dis': network_dis.state_dict(),
                        'gen': network_gen.state_dict(),
                        'dis_opt': opt_dis.state_dict(),
                        'gen_opt': opt_gen.state_dict(),
                        'counter': counter
                    },
                    os.path.join(checkpoint_path, 'checkpoints',
                                 f'update_{counter}.ckpt'))
            if counter > 0 and counter % updates_per_epoch == 0:
                tqdm.write(f' -- Epoch finished at update {counter}')
                # epoch finished
                loss_epoch = sch_meter(reset=True)
                sch_gen.step(loss_epoch)
                sch_dis.step(loss_epoch)
            #tqdm.write(f'Dataloader overhead avg {int(dataloader_meter(reset = True) * 1000)}ms')
    except KeyboardInterrupt:
        print(' -- Training interrupted, saving latest model ..')
        torch.save(
            {
                'dis': network_dis.state_dict(),
                'gen': network_gen.state_dict(),
                'dis_opt': opt_dis.state_dict(),
                'gen_opt': opt_gen.state_dict(),
                'counter': counter
            }, os.path.join(checkpoint_path, 'checkpoints', 'latest.ckpt'))
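
Note that the pool in this example is used through available()/put()/sample() with a device argument, a different interface from the query()-style pool elsewhere in this collection. A sketch of what such a put/sample pool might look like, with all names inferred from the call sites above rather than taken from the source:

import random
import torch

class SampleImagePool:
    """Hypothetical put/sample-style pool matching the call sites above."""

    def __init__(self, pool_size, device):
        self.pool_size = pool_size
        self.device = device
        self.images = []  # stored as individual CHW tensors

    def available(self):
        return len(self.images) > 0

    def put(self, images):
        for img in images.detach():
            self.images.append(img.to(self.device))
        # keep only the most recent pool_size entries
        self.images = self.images[-self.pool_size:]

    def sample(self, batch_size=1):
        picks = random.choices(self.images, k=batch_size)
        return torch.stack(picks, dim=0)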
Example #5
    def _build_net(self):
        # tfph: TensorFlow PlaceHolder
        self.x_test_tfph = tf.placeholder(tf.float32,
                                          shape=[None, *self.img_size],
                                          name='x_test_tfph')
        self.y_test_tfph = tf.placeholder(tf.float32,
                                          shape=[None, *self.img_size],
                                          name='y_test_tfph')
        self.fake_x_tfph = tf.placeholder(tf.float32,
                                          shape=[None, *self.img_size],
                                          name='fake_x_tfph')
        self.fake_y_tfph = tf.placeholder(tf.float32,
                                          shape=[None, *self.img_size],
                                          name='fake_y_tfph')

        self.G_gen = Generator(name='G',
                               ngf=self.ngf,
                               norm=self.norm,
                               image_size=self.img_size,
                               _ops=self._G_gen_train_ops)
        self.Dy_dis = Discriminator(name='Dy',
                                    ndf=self.ndf,
                                    norm=self.norm,
                                    _ops=self._Dy_dis_train_ops,
                                    use_sigmoid=self.use_sigmoid)
        self.F_gen = Generator(name='F',
                               ngf=self.ngf,
                               norm=self.norm,
                               image_size=self.img_size,
                               _ops=self._F_gen_train_ops)
        self.Dx_dis = Discriminator(name='Dx',
                                    ndf=self.ndf,
                                    norm=self.norm,
                                    _ops=self._Dx_dis_train_ops,
                                    use_sigmoid=self.use_sigmoid)

        data_reader = Reader(self.data_path,
                             name='data',
                             image_size=self.img_size,
                             batch_size=self.flags.batch_size,
                             is_train=self.flags.is_train)
        # self.x_imgs_ori and self.y_imgs_ori are the images before data augmentation
        self.x_imgs, self.y_imgs, self.x_imgs_ori, self.y_imgs_ori, self.img_name = \
            data_reader.feed()

        self.fake_x_pool_obj = utils.ImagePool(pool_size=50)
        self.fake_y_pool_obj = utils.ImagePool(pool_size=50)

        # cycle consistency loss
        cycle_loss = self.cycle_consistency_loss(self.x_imgs, self.y_imgs)

        # X -> Y
        self.fake_y_imgs = self.G_gen(self.x_imgs)
        self.G_gen_loss = self.generator_loss(self.Dy_dis,
                                              self.fake_y_imgs,
                                              use_lsgan=self.use_lsgan)
        self.G_loss = self.G_gen_loss + cycle_loss
        self.Dy_dis_loss = self.discriminator_loss(self.Dy_dis,
                                                   self.y_imgs,
                                                   self.fake_y_tfph,
                                                   use_lsgan=self.use_lsgan)

        # Y -> X
        self.fake_x_imgs = self.F_gen(self.y_imgs)
        self.F_gen_loss = self.generator_loss(self.Dx_dis,
                                              self.fake_x_imgs,
                                              use_lsgan=self.use_lsgan)
        self.F_loss = self.F_gen_loss + cycle_loss
        self.Dx_dis_loss = self.discriminator_loss(self.Dx_dis,
                                                   self.x_imgs,
                                                   self.fake_x_tfph,
                                                   use_lsgan=self.use_lsgan)

        G_optim = self.optimizer(loss=self.G_loss,
                                 variables=self.G_gen.variables,
                                 name='Adam_G')
        Dy_optim = self.optimizer(loss=self.Dy_dis_loss,
                                  variables=self.Dy_dis.variables,
                                  name='Adam_Dy')
        F_optim = self.optimizer(loss=self.F_loss,
                                 variables=self.F_gen.variables,
                                 name='Adam_F')
        Dx_optim = self.optimizer(loss=self.Dx_dis_loss,
                                  variables=self.Dx_dis.variables,
                                  name='Adam_Dx')
        self.optims = tf.group([G_optim, Dy_optim, F_optim, Dx_optim])
        # with tf.control_dependencies([G_optim, Dy_optim, F_optim, Dx_optim]):
        #     self.optims = tf.no_op(name='optimizers')

        # for sampling function
        self.fake_y_sample = self.G_gen(self.x_test_tfph)
        self.fake_x_sample = self.F_gen(self.y_test_tfph)
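
This snippet builds fake_x_tfph / fake_y_tfph placeholders for the discriminator losses but never shows how they are fed. One plausible training step, assuming a tf.Session named sess, a model instance built with _build_net, and query()-style pools (all three are assumptions, not shown in the source):

# materialize the current generator outputs, then feed pooled fakes
fake_x, fake_y = sess.run([model.fake_x_imgs, model.fake_y_imgs])
sess.run(model.optims, feed_dict={
    model.fake_x_tfph: model.fake_x_pool_obj.query(fake_x),
    model.fake_y_tfph: model.fake_y_pool_obj.query(fake_y),
})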
Example #6
# (head of this snippet was truncated in the source; the D_A optimizer below is
#  reconstructed from the decay logic further down, which references D_A_optimizer)
D_A_optimizer = torch.optim.Adam(D_A.parameters(),
                                 lr=params.lrD,
                                 betas=(params.beta1, params.beta2))
D_B_optimizer = torch.optim.Adam(D_B.parameters(),
                                 lr=params.lrD,
                                 betas=(params.beta1, params.beta2))

# Training GAN
D_A_avg_losses = []
D_B_avg_losses = []
G_A_avg_losses = []
G_B_avg_losses = []
cycle_A_avg_losses = []
cycle_B_avg_losses = []

# Generated image pool
num_pool = 50
fake_A_pool = utils.ImagePool(num_pool)
fake_B_pool = utils.ImagePool(num_pool)

step = 0
for epoch in range(params.num_epochs):
    D_A_losses = []
    D_B_losses = []
    G_A_losses = []
    G_B_losses = []
    cycle_A_losses = []
    cycle_B_losses = []

    # learning rate decay
    if (epoch + 1) > params.decay_epoch:
        D_A_optimizer.param_groups[0]['lr'] -= params.lrD / (
            params.num_epochs - params.decay_epoch)
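
The manual lr subtraction above implements a linear decay toward zero between decay_epoch and num_epochs. The same schedule can be written with LambdaLR; a sketch, assuming epochs counted from 0 and the variables defined above:

decay_span = params.num_epochs - params.decay_epoch
linear = lambda e: 1.0 - max(0, e + 1 - params.decay_epoch) / decay_span
D_A_scheduler = torch.optim.lr_scheduler.LambdaLR(D_A_optimizer, lr_lambda=linear)
# then call D_A_scheduler.step() once per epoch instead of mutating param_groups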
Example #7
                    action='store_true',
                    help='use pre-trained model')

source_prediction_max_result = []
target_prediction_max_result = []
best_prec_result = torch.tensor(0, dtype=torch.float32)

args = parser.parse_args()

os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

cuda = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor

fake_A_buffer = utils.ImagePool(max_size=args.max_buffer)
fake_B_buffer = utils.ImagePool(max_size=args.max_buffer)


def main():
    global args, best_prec_result

    utils.default_model_dir = args.dir
    start_time = time.time()

    Source_train_loader, Source_test_loader = dataset_selector(args.sd)
    Target_train_loader, Target_test_loader = dataset_selector(args.td)

    state_info = utils.model_optim_state_info()
    state_info.model_init()
    state_info.model_cuda_init()
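
The FloatTensor/LongTensor aliases above are a legacy CUDA-dispatch idiom; the same effect is usually achieved with an explicit device object (a sketch, with tensor shapes chosen only for illustration):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x = torch.zeros(8, 3, dtype=torch.float32, device=device)  # replaces FloatTensor
y = torch.zeros(8, dtype=torch.int64, device=device)       # replaces LongTensor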
Example #8
	def train(self,args):

		self.dataset_A = loadPickleFile("cache_check/coded_sps_A_norm.pickle")
		self.dataset_B = loadPickleFile("cache_check/coded_sps_B_norm.pickle")
		n_samples = len(self.dataset_A)

		dataset = trainingDataset(datasetA=self.dataset_A,
								  datasetB=self.dataset_B,
								  n_frames=128)
		train_loader = torch.utils.data.DataLoader(dataset=dataset,
												   batch_size=args.batch_size,
												   shuffle=True,
												   drop_last=False,
												   num_workers=4)
		
		a_fake_sample = utils.ImagePool(50)
		b_fake_sample = utils.ImagePool(50)

		for epoch in range(self.start_epoch, args.epochs):

			lr = self.g_optimizer.param_groups[0]['lr']
			print('learning rate = %.7f' % lr)

			for i, (a_real, b_real) in enumerate(train_loader):
				# step
				a_real = a_real.float()
				b_real = b_real.float()

		

				step = epoch * len(train_loader) + i + 1
				print(step)

				# Generator Computations
				##################################################

				set_grad([self.Da, self.Db], False)
				self.g_optimizer.zero_grad()


				# Forward pass through generators
				##################################################
				a_fake = self.g_AB(a_real.cuda())
				b_fake = self.g_BA(b_real.cuda())

				a_recon = self.g_AB(b_fake)
				b_recon = self.g_BA(a_fake)


				a_idt = self.g_AB(a_real.cuda())
				b_idt = self.g_BA(b_real.cuda())


				# Identity losses
				###################################################
				a_idt_loss = self.L1(a_idt, a_real.cuda()) * args.lamda * args.idt_coef
				b_idt_loss = self.L1(b_idt, b_real.cuda()) * args.lamda * args.idt_coef

				if self.loss_type == 'lsgan':

					# Adversarial losses
					###################################################
					a_fake_dis = self.Da(a_fake)
					b_fake_dis = self.Db(b_fake)

					real_label = utils.cuda(Variable(torch.ones(a_fake_dis.size())))


					a_gen_loss = self.MSE(a_fake_dis, real_label)
					b_gen_loss = self.MSE(b_fake_dis, real_label)
				elif self.loss_type == 'wgan':

					# Wasserstein-GAN loss
					# G_A(A)
					a_gen_loss = self.criterionGAN(b_fake, generator_loss=True)

					# G_B(B)
					b_gen_loss = self.criterionGAN(a_fake, generator_loss=True)

				# Cycle consistency losses
				###################################################
				a_cycle_loss = self.L1(a_recon, a_real.cuda()) * args.lamda
				b_cycle_loss = self.L1(b_recon, b_real.cuda()) * args.lamda

				# Total generators losses
				###################################################
				gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

				# Update generators
				###################################################
				gen_loss.backward(retain_graph=True)
				self.g_optimizer.step()


				# Discriminator Computations
				#################################################


				set_grad([self.Da, self.Db], True)
				self.d_optimizer.zero_grad()

				# Sample from history of generated images
				#################################################
				a_fake = a_fake_sample.query(a_fake)
				b_fake = b_fake_sample.query(b_fake)
				a_fake, b_fake = utils.cuda([a_fake, b_fake])

				print ("A_R_Size",a_fake.size())
				print ("B_R_Size",b_fake.size())  


				if self.loss_type == 'lsgan':

					# Forward pass through discriminators
					################################################# 
					a_real_dis = self.Da(a_real.cuda())
					a_fake_dis = self.Da(a_fake)
					b_real_dis = self.Db(b_real.cuda())
					b_fake_dis = self.Db(b_fake)

					real_label = utils.cuda(Variable(torch.ones(a_real_dis.size())))
					fake_label = utils.cuda(Variable(torch.zeros(a_fake_dis.size())))

					# Discriminator losses
					##################################################
					a_dis_real_loss = self.MSE(a_real_dis, real_label)
					a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
					b_dis_real_loss = self.MSE(b_real_dis, real_label)
					b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)

					# Total discriminators losses
					a_dis_loss = (a_dis_real_loss + a_dis_fake_loss)*0.5
					b_dis_loss = (b_dis_real_loss + b_dis_fake_loss)*0.5

				elif self.loss_type == 'wgan':
					for i_critic in range(self.wgan_n_critic):
						# Clip the discriminator parameters for k-Lipschitz continuity
						for p in self.Da.parameters():
							p.data.clamp_(self.wgan_clamp_lower, self.wgan_clamp_upper)
						for p in self.Db.parameters():
							p.data.clamp_(self.wgan_clamp_lower, self.wgan_clamp_upper)
						# D_A
						a_dis_loss = self.backward_D_wasserstein(self.Da, a_real.cuda(), a_fake)
						# D_B
						b_dis_loss = self.backward_D_wasserstein(self.Db, b_real.cuda(), b_fake)

				# Update discriminators
				##################################################
				a_dis_loss.backward(retain_graph=True)
				b_dis_loss.backward(retain_graph=True)
				self.d_optimizer.step()
				
				writer.add_scalar('DisA loss',  a_dis_loss,
						epoch * len(train_loader) + i)
				writer.add_scalar('DisB loss',  b_dis_loss,
						epoch * len(train_loader) + i)
				
				writer.add_scalar('Generator loss',  gen_loss / 1000,
						epoch * len(train_loader) + i)



				print("Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e" %(epoch, i + 1, len(train_loader), gen_loss,a_dis_loss+b_dis_loss))
				
			# Override the latest checkpoint
			#######################################################
			utils.save_checkpoint({'epoch': epoch + 1,
								   'Da': self.Da.state_dict(),
								   'Db': self.Db.state_dict(),
								   'Gab': self.g_AB.state_dict(),
								   'Gba': self.g_BA.state_dict(),
								   'd_optimizer': self.d_optimizer.state_dict(),
								   'g_optimizer': self.g_optimizer.state_dict()},
								  '%s/w_gan_2.ckpt' % (args.checkpoint_dir))

			# Update learning rates
			########################
			self.g_lr_scheduler.step()
			self.d_lr_scheduler.step()
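
Note that this example backpropagates the discriminator losses through tensors that still carry the generator graph, which is why both backward calls need retain_graph=True. The usual CycleGAN pattern detaches the fakes before the pool query; a sketch against the names above:

# detach so the discriminator pass does not retain the generator graph
a_fake_d = a_fake_sample.query(a_fake.detach())
b_fake_d = b_fake_sample.query(b_fake.detach())
# a_dis_loss / b_dis_loss computed on a_fake_d / b_fake_d can then use
# plain .backward() with no retain_graph arguments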
Example #9
def main():
    # Get training options
    opt = get_opt()

    # Define the networks
    # netG_A: used to transfer image from domain A to domain B
    # netG_B: used to transfer image from domain B to domain A
    netG_A = networks.Generator(opt.input_nc, opt.output_nc, opt.ngf,
                                opt.n_res, opt.dropout)
    netG_B = networks.Generator(opt.output_nc, opt.input_nc, opt.ngf,
                                opt.n_res, opt.dropout)
    if opt.u_net:
        netG_A = networks.U_net(opt.input_nc, opt.output_nc, opt.ngf)
        netG_B = networks.U_net(opt.output_nc, opt.input_nc, opt.ngf)

    # netD_A: used to test whether an image is from domain B
    # netD_B: used to test whether an image is from domain A
    netD_A = networks.Discriminator(opt.input_nc, opt.ndf)
    netD_B = networks.Discriminator(opt.output_nc, opt.ndf)

    # Initialize the networks
    if opt.cuda:
        netG_A.cuda()
        netG_B.cuda()
        netD_A.cuda()
        netD_B.cuda()
    utils.init_weight(netG_A)
    utils.init_weight(netG_B)
    utils.init_weight(netD_A)
    utils.init_weight(netD_B)

    if opt.pretrained:
        netG_A.load_state_dict(torch.load('pretrained/netG_A.pth'))
        netG_B.load_state_dict(torch.load('pretrained/netG_B.pth'))
        netD_A.load_state_dict(torch.load('pretrained/netD_A.pth'))
        netD_B.load_state_dict(torch.load('pretrained/netD_B.pth'))

    # Define the loss functions
    criterion_GAN = utils.GANLoss()
    if opt.cuda:
        criterion_GAN.cuda()

    criterion_cycle = torch.nn.L1Loss()
    # Alternatively, can try MSE cycle consistency loss
    #criterion_cycle = torch.nn.MSELoss()
    criterion_identity = torch.nn.L1Loss()

    # Define the optimizers
    optimizer_G = torch.optim.Adam(itertools.chain(netG_A.parameters(),
                                                   netG_B.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.beta1, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                     lr=opt.lr,
                                     betas=(opt.beta1, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                     lr=opt.lr,
                                     betas=(opt.beta1, 0.999))

    # Create learning rate schedulers
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=utils.Lambda_rule(opt.epoch, opt.n_epochs,
                                    opt.n_epochs_decay).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=utils.Lambda_rule(opt.epoch, opt.n_epochs,
                                    opt.n_epochs_decay).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=utils.Lambda_rule(opt.epoch, opt.n_epochs,
                                    opt.n_epochs_decay).step)

    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batch_size, opt.input_nc, opt.sizeh, opt.sizew)
    input_B = Tensor(opt.batch_size, opt.output_nc, opt.sizeh, opt.sizew)

    # Define two image pools to store generated images
    fake_A_pool = utils.ImagePool()
    fake_B_pool = utils.ImagePool()

    # Define the transform, and load the data
    transform = transforms.Compose([
        transforms.Resize((opt.sizeh, opt.sizew)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, ))
    ])
    dataloader = DataLoader(ImageDataset(opt.rootdir,
                                         transform=transform,
                                         mode='train'),
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.n_cpu)

    # numpy arrays to store the loss of epoch
    loss_G_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)
    loss_D_A_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)
    loss_D_B_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)

    # Training
    for epoch in range(opt.epoch, opt.n_epochs + opt.n_epochs_decay):
        start = time.strftime("%H:%M:%S")
        print("current epoch :", epoch, " start time :", start)
        # Empty list to store the loss of each mini-batch
        loss_G_list = []
        loss_D_A_list = []
        loss_D_B_list = []

        for i, batch in enumerate(dataloader):
            if i % 50 == 1:
                print("current step: ", i)
                current = time.strftime("%H:%M:%S")
                print("current time :", current)
                print("last loss G:", loss_G_list[-1], "last loss D_A",
                      loss_D_A_list[-1], "last loss D_B", loss_D_B_list[-1])
            real_A = input_A.copy_(batch['A'])
            real_B = input_B.copy_(batch['B'])

            # Train the generator
            optimizer_G.zero_grad()

            # Compute fake images and reconstructed images
            fake_B = netG_A(real_A)
            fake_A = netG_B(real_B)

            if opt.identity_loss != 0:
                same_B = netG_A(real_B)
                same_A = netG_B(real_A)

            # discriminators require no gradients when optimizing generators
            utils.set_requires_grad([netD_A, netD_B], False)

            # Identity loss
            if opt.identity_loss != 0:
                loss_identity_A = criterion_identity(
                    same_A, real_A) * opt.identity_loss
                loss_identity_B = criterion_identity(
                    same_B, real_B) * opt.identity_loss

            # GAN loss
            prediction_fake_B = netD_B(fake_B)
            loss_gan_B = criterion_GAN(prediction_fake_B, True)
            prediction_fake_A = netD_A(fake_A)
            loss_gan_A = criterion_GAN(prediction_fake_A, True)

            # Cycle consistent loss
            recA = netG_B(fake_B)
            recB = netG_A(fake_A)
            loss_cycle_A = criterion_cycle(recA, real_A) * opt.cycle_loss
            loss_cycle_B = criterion_cycle(recB, real_B) * opt.cycle_loss

            # total loss without the identity loss
            loss_G = loss_gan_B + loss_gan_A + loss_cycle_A + loss_cycle_B

            if opt.identity_loss != 0:
                loss_G += loss_identity_A + loss_identity_B

            loss_G_list.append(loss_G.item())
            loss_G.backward()
            optimizer_G.step()

            # Train the discriminator
            utils.set_requires_grad([netD_A, netD_B], True)

            # Train the discriminator D_A
            optimizer_D_A.zero_grad()
            # real images
            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, True)

            # fake images
            fake_A = fake_A_pool.query(fake_A)
            pred_fake = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, False)

            #total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A_list.append(loss_D_A.item())
            loss_D_A.backward()
            optimizer_D_A.step()

            # Train the discriminator D_B
            optimizer_D_B.zero_grad()
            # real images
            pred_real = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, True)

            # fake images
            fake_B = fake_B_pool.query(fake_B)
            pred_fake = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, False)

            # total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B_list.append(loss_D_B.item())
            loss_D_B.backward()
            optimizer_D_B.step()

        # Update the learning rate
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        # Save models checkpoints
        torch.save(netG_A.state_dict(), 'model/netG_A.pth')
        torch.save(netG_B.state_dict(), 'model/netG_B.pth')
        torch.save(netD_A.state_dict(), 'model/netD_A.pth')
        torch.save(netD_B.state_dict(), 'model/netD_B.pth')

        # Save other checkpoint information
        checkpoint = {
            'epoch': epoch,
            'optimizer_G': optimizer_G.state_dict(),
            'optimizer_D_A': optimizer_D_A.state_dict(),
            'optimizer_D_B': optimizer_D_B.state_dict(),
            'lr_scheduler_G': lr_scheduler_G.state_dict(),
            'lr_scheduler_D_A': lr_scheduler_D_A.state_dict(),
            'lr_scheduler_D_B': lr_scheduler_D_B.state_dict()
        }
        torch.save(checkpoint, 'model/checkpoint.pth')

        # Update the numpy arrays that record the loss
        loss_G_array[epoch] = sum(loss_G_list) / len(loss_G_list)
        loss_D_A_array[epoch] = sum(loss_D_A_list) / len(loss_D_A_list)
        loss_D_B_array[epoch] = sum(loss_D_B_list) / len(loss_D_B_list)
        np.savetxt('model/loss_G.txt', loss_G_array)
        np.savetxt('model/loss_D_A.txt', loss_D_A_array)
        np.savetxt('model/loss_D_B.txt', loss_D_B_array)

        if epoch % 10 == 9:
            torch.save(netG_A.state_dict(),
                       'model/netG_A' + str(epoch) + '.pth')
            torch.save(netG_B.state_dict(),
                       'model/netG_B' + str(epoch) + '.pth')
            torch.save(netD_A.state_dict(),
                       'model/netD_A' + str(epoch) + '.pth')
            torch.save(netD_B.state_dict(),
                       'model/netD_B' + str(epoch) + '.pth')

        end = time.strftime("%H:%M:%S")
        print("current epoch :", epoch, " end time :", end)
        print("G loss :", loss_G_array[epoch], "D_A loss :",
              loss_D_A_array[epoch], "D_B loss :", loss_D_B_array[epoch])
Example #10
MSE_loss = nn.MSELoss().cuda()
L1_loss = nn.L1Loss().cuda()

# Adam optimizer
G_optimizer = optim.Adam(itertools.chain(G_A.parameters(), G_B.parameters()),
                         lr=opt.lrG,
                         betas=(opt.beta1, opt.beta2))
D_A_optimizer = optim.Adam(D_A.parameters(),
                           lr=opt.lrD,
                           betas=(opt.beta1, opt.beta2))
D_B_optimizer = optim.Adam(D_B.parameters(),
                           lr=opt.lrD,
                           betas=(opt.beta1, opt.beta2))

# image store
fakeA_store = utils.ImagePool(50)
fakeB_store = utils.ImagePool(50)

train_hist = utils.train_histogram_initialize()

print('**************************start training!**************************')
start_time = time.time()
for epoch in range(opt.train_epoch):
    D_A_losses = []
    D_B_losses = []
    G_A_losses = []
    G_B_losses = []
    A_cycle_losses = []
    B_cycle_losses = []
    epoch_start_time = time.time()
    num_iter = 0
Example #11
    def _build_net(self):
        # tfph: TensorFlow PlaceHolder
        self.x_test_tfph = tf.placeholder(tf.float32, shape=[None, *self.img_size], name='x_test_tfph')
        self.y_test_tfph = tf.placeholder(tf.float32, shape=[None, *self.img_size], name='y_test_tfph')
        self.xy_fake_pairs_tfph = tf.placeholder(tf.float32, shape=[None, self.img_size[0], self.img_size[1], 2],
                                                 name='xy_fake_pairs_tfph')
        self.yx_fake_pairs_tfph = tf.placeholder(tf.float32, shape=[None, self.img_size[0], self.img_size[1], 2],
                                                 name='yx_fake_pairs_tfph')

        self.G_gen = Generator(name='G', ngf=self.ngf, norm=self.norm, image_size=self.img_size,
                               _ops=self._G_gen_train_ops)
        self.Dy_dis = Discriminator(name='Dy', ndf=self.ndf, norm=self.norm, _ops=self._Dy_dis_train_ops,
                                    is_lsgan=self.is_lsgan)
        self.F_gen = Generator(name='F', ngf=self.ngf, norm=self.norm, image_size=self.img_size,
                               _ops=self._F_gen_train_ops)
        self.Dx_dis = Discriminator(name='Dx', ndf=self.ndf, norm=self.norm, _ops=self._Dx_dis_train_ops,
                                    is_lsgan=self.is_lsgan)
        self.vggModel = VGG16(name='VGG16_Pretrained')

        data_reader = Reader(self.data_path, name='data', image_size=self.img_size, batch_size=self.flags.batch_size,
                             is_train=self.flags.is_train)
        # self.x_imgs_ori and self.y_imgs_ori are the images before data augmentation
        self.x_imgs, self.y_imgs, self.x_imgs_ori, self.y_imgs_ori, self.img_name = data_reader.feed()

        self.fake_xy_pool_obj = utils.ImagePool(pool_size=50)
        self.fake_yx_pool_obj = utils.ImagePool(pool_size=50)

        # cycle consistency loss
        self.cycle_loss = self.cycle_consistency_loss(self.x_imgs, self.y_imgs)

        # concatenation
        self.fake_y_imgs = self.G_gen(self.x_imgs)
        self.xy_real_pairs = tf.concat([self.x_imgs, self.y_imgs], axis=3)
        self.xy_fake_pairs = tf.concat([self.x_imgs, self.fake_y_imgs], axis=3)

        self.fake_x_imgs = self.F_gen(self.y_imgs)
        self.yx_real_pairs = tf.concat([self.y_imgs, self.x_imgs], axis=3)
        self.yx_fake_pairs = tf.concat([self.y_imgs, self.fake_x_imgs], axis=3)

        # X -> Y
        self.G_gen_loss = self.generator_loss(self.Dy_dis, self.xy_fake_pairs, is_lsgan=self.is_lsgan)
        self.G_cond_loss = self.voxel_loss(preds=self.fake_y_imgs, gts=self.y_imgs)
        self.G_gdl_loss = self.gradient_difference_loss(preds=self.fake_y_imgs, gts=self.y_imgs)
        self.G_perceptual_loss = self.perceptual_loss_fn(preds=self.fake_y_imgs, gts=self.y_imgs)
        self.G_loss = self.G_gen_loss + self.G_cond_loss + self.cycle_loss + self.G_gdl_loss + self.G_perceptual_loss
        self.Dy_dis_loss = self.discriminator_loss(self.Dy_dis, self.xy_real_pairs, self.xy_fake_pairs_tfph,
                                                   is_lsgan=self.is_lsgan)

        # Y -> X
        self.F_gen_loss = self.generator_loss(self.Dx_dis, self.yx_fake_pairs, is_lsgan=self.is_lsgan)
        self.F_cond_loss = self.voxel_loss(preds=self.fake_x_imgs, gts=self.x_imgs)
        self.F_gdl_loss = self.gradient_difference_loss(preds=self.fake_x_imgs, gts=self.x_imgs)
        self.F_perceputal_loss = self.perceptual_loss_fn(preds=self.fake_x_imgs, gts=self.x_imgs)
        self.F_loss = self.F_gen_loss + self.F_cond_loss + self.cycle_loss + self.F_gdl_loss + self.F_perceputal_loss
        self.Dx_dis_loss = self.discriminator_loss(self.Dx_dis, self.yx_real_pairs, self.yx_fake_pairs_tfph,
                                                   is_lsgan=self.is_lsgan)

        G_optim = self.optimizer(loss=self.G_loss, variables=self.G_gen.variables, name='Adam_G')
        Dy_optim = self.optimizer(loss=self.Dy_dis_loss, variables=self.Dy_dis.variables, name='Adam_Dy')
        F_optim = self.optimizer(loss=self.F_loss, variables=self.F_gen.variables, name='Adam_F')
        Dx_optim = self.optimizer(loss=self.Dx_dis_loss, variables=self.Dx_dis.variables, name='Adam_Dx')
        self.optims = tf.group([G_optim, Dy_optim, F_optim, Dx_optim])

        # for sampling function
        self.fake_y_sample = self.G_gen(self.x_test_tfph)
        self.fake_x_sample = self.F_gen(self.y_test_tfph)
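
gradient_difference_loss is referenced but not defined in this snippet. A common formulation (Mathieu et al., 2016), assuming NHWC tensors; this is a sketch of one plausible implementation, not the repository's own:

import tensorflow as tf

def gradient_difference_loss(preds, gts, alpha=1.0):
    # finite differences along height and width (NHWC layout assumed)
    dy_pred = preds[:, 1:, :, :] - preds[:, :-1, :, :]
    dx_pred = preds[:, :, 1:, :] - preds[:, :, :-1, :]
    dy_gt = gts[:, 1:, :, :] - gts[:, :-1, :, :]
    dx_gt = gts[:, :, 1:, :] - gts[:, :, :-1, :]
    # penalize mismatch between predicted and target gradient magnitudes
    return (tf.reduce_mean(tf.abs(tf.abs(dy_pred) - tf.abs(dy_gt)) ** alpha) +
            tf.reduce_mean(tf.abs(tf.abs(dx_pred) - tf.abs(dx_gt)) ** alpha))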
Example #12
    def _build_net(self):
        # tfph: TensorFlow PlaceHolder
        self.x_test_tfph = placeholder(tf.float32, shape=[None, *self.img_size], name='x_test_tfph')
        self.y_test_tfph = placeholder(tf.float32, shape=[None, *self.img_size], name='y_test_tfph')

        # Supervised learning placeholders for Image Pool Tech.
        self.xy_fake_pairs_tfph = placeholder(tf.float32, shape=[None, self.img_size[0], self.img_size[1], 2],
                                              name='xy_fake_pairs_tfph')
        self.yx_fake_pairs_tfph = placeholder(tf.float32, shape=[None, self.img_size[0], self.img_size[1], 2],
                                              name='yx_fake_pairs_tfph')

        # Unsupervised learning placeholders for Image Pool Tech.
        self.xy_fake_unpairs_tfph = placeholder(tf.float32, shape=[None, self.img_size[0], self.img_size[1], 1],
                                                name='xy_fake_unpairs_tfph')
        self.yx_fake_unpairs_tfph = placeholder(tf.float32, shape=[None, self.img_size[0], self.img_size[1], 1],
                                                name='yx_fake_unpairs_tfph')

        self.G_gen = Generator(name='G', ngf=self.ngf, norm=self.norm, image_size=self.img_size,
                               _ops=self._G_gen_train_ops)

        self.Dy_dis_sup = Discriminator(
            name='Dy_sup', ndf=self.ndf, norm=self.norm, model=self.flags.dis_model, shared_reuse=False,
            _ops=self._Dy_dis_train_ops)

        self.Dy_dis_unsup = Discriminator(
            name='Dy_unsup', ndf=self.ndf, norm=self.norm, model=self.flags.dis_model, shared_reuse=True,
            _ops=self._Dy_dis_train_ops)

        self.F_gen = Generator(
            name='F', ngf=self.ngf, norm=self.norm, image_size=self.img_size, _ops=self._F_gen_train_ops)

        self.Dx_dis_sup = Discriminator(
            name='Dx_sup', ndf=self.ndf, norm=self.norm, model=self.flags.dis_model, shared_reuse=False,
            _ops=self._Dx_dis_train_ops)

        self.Dx_dis_unsup = Discriminator(
            name='Dx_unsup', ndf=self.ndf, norm=self.norm, model=self.flags.dis_model, shared_reuse=True,
            _ops=self._Dx_dis_train_ops)

        self.vggModel = VGG16(name='VGG16_Pretrained')

        data_reader = Reader(self.data_path, name='data', image_size=self.img_size, batch_size=self.flags.batch_size,
                             is_train=self.flags.is_train)
        # self.x_imgs_ori and self.y_imgs_ori are the images before data augmentation
        self.x_imgs, self.y_imgs, self.x_imgs_ori, self.y_imgs_ori, self.img_name = data_reader.feed()

        self.fake_xy_pool_obj_sup = utils.ImagePool(pool_size=50)
        self.fake_yx_pool_obj_sup = utils.ImagePool(pool_size=50)
        self.fake_xy_pool_obj_unsup = utils.ImagePool(pool_size=50)
        self.fake_yx_pool_obj_unsup = utils.ImagePool(pool_size=50)

        # cycle consistency loss
        self.cycle_loss = self.cycle_consistency_loss(self.x_imgs, self.y_imgs)

        # concatenation
        self.fake_y_imgs = self.G_gen(self.x_imgs)
        self.xy_real_pairs = tf.concat([self.x_imgs, self.y_imgs], axis=3)
        self.xy_fake_pairs = tf.concat([self.x_imgs, self.fake_y_imgs], axis=3)

        self.fake_x_imgs = self.F_gen(self.y_imgs)
        self.yx_real_pairs = tf.concat([self.y_imgs, self.x_imgs], axis=3)
        self.yx_fake_pairs = tf.concat([self.y_imgs, self.fake_x_imgs], axis=3)

        # X -> Y
        # Supervised learning
        self.G_gen_loss_sup = self.generator_loss(self.Dy_dis_sup, self.xy_fake_pairs)
        self.G_cond_loss = self.voxel_loss(preds=self.fake_y_imgs, gts=self.y_imgs)
        self.G_gdl_loss = self.gradient_difference_loss(preds=self.fake_y_imgs, gts=self.y_imgs)
        self.G_perceptual_loss = self.perceptual_loss_fn(preds=self.fake_y_imgs, gts=self.y_imgs)
        self.G_ssim_loss = self.ssim_loss_fn(preds=self.fake_y_imgs, gts=self.y_imgs)
        self.G_loss_sup = self.G_gen_loss_sup + self.G_cond_loss + self.cycle_loss + self.G_gdl_loss + \
                          self.G_perceptual_loss + self.G_ssim_loss
        self.Dy_dis_loss_sup = self.discriminator_loss(
            self.Dy_dis_sup, self.xy_real_pairs, self.xy_fake_pairs_tfph, is_lsgan=self.is_lsgan)

        # Unsupervised learning
        self.G_gen_loss_unsup = self.generator_loss(self.Dy_dis_unsup, self.fake_y_imgs)
        self.G_loss_unsup = self.G_gen_loss_unsup + self.cycle_loss
        self.Dy_dis_loss_unsup = self.discriminator_loss(
            self.Dy_dis_unsup, self.y_imgs, self.xy_fake_unpairs_tfph, is_lsgan=False)

        # Integrated optimization
        self.G_gen_loss_integrated = self.G_loss_sup + self.G_loss_unsup
        self.Dy_dis_loss_integrated = self.Dy_dis_loss_sup + self.Dy_dis_loss_unsup

        # Y -> X
        # Supervised learning
        self.F_gen_loss_sup = self.generator_loss(self.Dx_dis_sup, self.yx_fake_pairs)
        self.F_cond_loss = self.voxel_loss(preds=self.fake_x_imgs, gts=self.x_imgs)
        self.F_gdl_loss = self.gradient_difference_loss(preds=self.fake_x_imgs, gts=self.x_imgs)
        self.F_perceputal_loss = self.perceptual_loss_fn(preds=self.fake_x_imgs, gts=self.x_imgs)
        self.F_ssim_loss = self.ssim_loss_fn(preds=self.fake_x_imgs, gts=self.x_imgs)
        self.F_loss_sup = self.F_gen_loss_sup + self.F_cond_loss + self.cycle_loss + self.F_gdl_loss + \
                          self.F_perceputal_loss + self.F_ssim_loss
        self.Dx_dis_loss_sup = self.discriminator_loss(
            self.Dx_dis_sup, self.yx_real_pairs, self.yx_fake_pairs_tfph, is_lsgan=self.is_lsgan)

        # Unsupervised Learning
        self.F_gen_loss_unsup = self.generator_loss(self.Dx_dis_unsup, self.fake_x_imgs)
        self.F_loss_unsup = self.F_gen_loss_unsup + self.cycle_loss
        self.Dx_dis_loss_unsup = self.discriminator_loss(
            self.Dx_dis_unsup, self.x_imgs, self.yx_fake_unpairs_tfph, is_lsgan=False)

        # Integrated optimization
        self.F_gen_loss_integrated = self.F_loss_sup + self.F_loss_unsup
        self.Dx_dis_loss_integrated = self.Dx_dis_loss_sup + self.Dx_dis_loss_unsup

        # Supervised learning
        G_optim_sup = self.optimizer(
            loss=self.G_loss_sup, variables=self.G_gen.variables, name='Adam_G_sup')
        Dy_optim_sup = self.optimizer(
            loss=self.Dy_dis_loss_sup, variables=self.Dy_dis_sup.variables, name='Adam_Dy_sup')
        F_optim_sup = self.optimizer(
            loss=self.F_loss_sup, variables=self.F_gen.variables, name='Adam_F_sup')
        Dx_optim_sup = self.optimizer(
            loss=self.Dx_dis_loss_sup, variables=self.Dx_dis_sup.variables, name='Adam_Dx_sup')
        self.optims_sup = tf.group([G_optim_sup, Dy_optim_sup, F_optim_sup, Dx_optim_sup])

        # Unsupervised learning
        G_optim_unsup = self.optimizer(
            loss=self.G_loss_unsup, variables=self.G_gen.variables, name='Adam_G_unsup')
        Dy_optim_unsup = self.optimizer(
            loss=self.Dy_dis_loss_unsup, variables=self.Dy_dis_unsup.variables, name='Adam_Dy_unsup')
        F_optim_unsup = self.optimizer(
            loss=self.F_loss_unsup, variables=self.F_gen.variables, name='Adam_F_unsup')
        Dx_optim_unsup = self.optimizer(
            loss=self.Dx_dis_loss_unsup, variables=self.Dx_dis_unsup.variables, name='Adam_Dx_unsup')
        self.optims_unsup = tf.group([G_optim_unsup, Dy_optim_unsup, F_optim_unsup, Dx_optim_unsup])

        # Integrated optimization
        G_optim_integrated = self.optimizer(
            loss=self.G_gen_loss_integrated, variables=self.G_gen.variables, name='Adam_G_integrated')
        Dy_optim_integrated = self.optimizer(
            loss=self.Dy_dis_loss_integrated, variables=[self.Dy_dis_sup.variables, self.Dy_dis_unsup.variables],
            name='Adam_Dy_integrated')
        F_optim_integrated = self.optimizer(
            loss=self.F_gen_loss_integrated, variables=self.F_gen.variables, name='Adam_F_integrated')
        Dx_optim_integrated = self.optimizer(
            loss=self.Dx_dis_loss_integrated, variables=[self.Dx_dis_sup.variables, self.Dx_dis_unsup.variables],
            name='Adam_Dx_integrated')
        self.optims_integrated = tf.group(
            [G_optim_integrated, Dy_optim_integrated, F_optim_integrated, Dx_optim_integrated])

        # for sampling function
        self.fake_y_sample = self.G_gen(self.x_test_tfph)
        self.fake_x_sample = self.F_gen(self.y_test_tfph)

        self.print_network_vars(is_print=True)
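Note: every snippet on this page relies on a project-local utils.ImagePool whose exact API varies (query(), fetch_img(), or plain __call__; capacity passed as pool_size, max_size, or a positional argument). All of them implement the same history buffer from the CycleGAN paper: keep up to N previously generated images and, with probability 0.5, hand the discriminator an old fake instead of the current one. Below is a minimal sketch of that behavior, assuming single images and a query() method; it is not any one project's implementation.

import random

class ImagePool:
    """History buffer of generated images (CycleGAN-style)."""

    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def query(self, image):
        # A non-positive pool size means no buffering at all
        if self.pool_size <= 0:
            return image
        # Fill the pool first; until then, always return the fresh image
        if len(self.images) < self.pool_size:
            self.images.append(image)
            return image
        # Once full: 50/50 chance of swapping the new image in and
        # returning a randomly chosen old one instead
        if random.random() > 0.5:
            idx = random.randrange(self.pool_size)
            old = self.images[idx]
            self.images[idx] = image
            return old
        return image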
Example #13
def train(epochs):
    gan_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
    l1_loss = gluon.loss.L1Loss()

    trainer_G = gluon.Trainer(netG.collect_params(),
                              'adam',
                              optimizer_params={
                                  'learning_rate': 0.0002,
                                  'beta1': 0.5,
                                  'beta2': 0.999
                              })
    trainer_D = gluon.Trainer(netD.collect_params(),
                              'adam',
                              optimizer_params={
                                  'learning_rate': 0.0002,
                                  'beta1': 0.5,
                                  'beta2': 0.999
                              })

    ## configure the log file
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    fh = logging.FileHandler(os.path.join(log_dir, 'train.log'))
    logger.addHandler(fh)

    sw = SummaryWriter(logdir=os.path.join(log_dir, 'train_sw'))
    batch_len = train_iter.num_data // train_iter.batch_size

    image_pool = utils.ImagePool(50)

    global_step = 0
    for epch in range(epochs):
        train_iter.reset()
        epch_time = time.time()
        batch_time = time.time()
        for iter_step, databatch in enumerate(train_iter):
            data = databatch.data[0].as_in_context(ctx)
            label = databatch.label[0].as_in_context(ctx)

            ## train netD
            pred = netG(data)
            # fake_data =nd.concat(data, pred, dim=1)
            # fake_data = image_pool.fetch_img(fake_data)
            fake_data = image_pool.fetch_img(nd.concat(data, pred, dim=1))
            with autograd.record():
                # fake
                pred_fake = netD(fake_data)
                fake_label = nd.zeros_like(pred_fake)
                loss_fake = gan_loss(pred_fake, fake_label).sum()
                # real
                real_data = nd.concat(data, label, dim=1)
                pred_real = netD(real_data)
                real_label = nd.ones_like(pred_real)
                loss_real = gan_loss(pred_real, real_label).sum()

                loss_D = (loss_real + loss_fake) * 0.5
                loss_D.backward()
            trainer_D.step(data.shape[0])
            sw.add_scalar('lossD', loss_D.asscalar(), global_step)

            ## train netG
            with autograd.record():
                pred = netG(data)
                in_data = nd.concat(data, pred, dim=1)
                pred_real = netD(in_data)
                pred_label = nd.ones_like(pred_real)

                ganloss_g = gan_loss(pred_real, pred_label)
                l1loss_g = l1_loss(pred, label)
                loss_G = ganloss_g + l1loss_g * l1_lambda
                loss_G = loss_G.sum()
                loss_G.backward()
            trainer_G.step(data.shape[0])
            sw.add_scalar('lossG', loss_G.asscalar(), global_step)

            ## intra-epoch progress logging
            if (iter_step + 1) % log_iter_intervals == 0:
                logger.info(
                    '[Epoch {}][Iter {}] Done. Speed: {:.4f} samples/s'.format(
                        epch, iter_step,
                        data.shape[0] / (time.time() - batch_time)))

            batch_time = time.time()
            global_step += 1

        ## evaluate after every epoch
        fake_img = pred[0]
        img_arr = (fake_img - mx.nd.min(fake_img)) / (mx.nd.max(fake_img) -
                                                      mx.nd.min(fake_img))
        # img_arr = img_arr[::-1, :, :]
        sw.add_image('generated image', img_arr)
        eval(epch)

        ## save checkpoints after every epoch
        netG.save_parameters(ckpt_fmt.format('netG', str(epch)))
        netD.save_parameters(ckpt_fmt.format('netD', str(epch)))

        logger.info('[Epoch {}] Done. Cost: {:.4f} s'.format(
            epch, time.time() - epch_time))
Example #14
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')

source_prediction_max_result = []
target_prediction_max_result = []
best_prec_result = torch.tensor(0, dtype=torch.float32)

args = parser.parse_args()

os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

cuda = True if torch.cuda.is_available() else False
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor

fake_S_buffer = utils.ImagePool(max_size=args.max_buffer)
fake_T_buffer = utils.ImagePool(max_size=args.max_buffer)

# adversarial_loss = torch.nn.BCELoss()
criterion_GAN = torch.nn.MSELoss()
criterion_Cycle = torch.nn.L1Loss()
criterion_Recov = torch.nn.MSELoss()
criterion = nn.CrossEntropyLoss().cuda() if cuda else nn.CrossEntropyLoss()

def main():
    global args, best_prec_result
    
    utils.default_model_dir = args.dir
    start_time = time.time()

    Source_train_loader, Source_test_loader = dataset_selector(args.sd)
Example #15
def main(unused_argv):
    total_step = 0
    checkpoints_dir = './models/real2cartoon'
    summary_dir = './summary'

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(batch_size=FLAGS.batch_size,
                             image_size=256,
                             use_mse=FLAGS.use_mse,
                             lambda1=FLAGS.lambda1,
                             lambda2=FLAGS.lambda2,
                             learning_rate=FLAGS.learning_rate,
                             filters=FLAGS.filters,
                             beta1=FLAGS.beta1,
                             mse_label=FLAGS.mse_label,
                             file_x=FLAGS.file_x,
                             file_y=FLAGS.file_y)

        G_loss, F_loss, D_X_loss, D_Y_loss, fake_y, fake_x = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, F_loss, D_X_loss, D_Y_loss)

        summarys = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(summary_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        ckpt = tf.train.get_checkpoint_state(checkpoints_dir)

        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
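            # Parse the global step back out of the trailing number in the checkpoint name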
            total_step = int(
                next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))
            logger.info('loaded model successfully: ' + ckpt.model_checkpoint_path)
        else:
            sess.run(tf.global_variables_initializer())
            logger.info('start new model')

        # img_x = utils.get_img(FLAGS.file_x, FLAGS.output_height, FLAGS.output_width, FLAGS.batch_size)
        # img_y = utils.get_img(FLAGS.file_y, FLAGS.output_height, FLAGS.output_width, FLAGS.batch_size)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            fake_X_pool = utils.ImagePool(FLAGS.pool_size)
            fake_Y_pool = utils.ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # img_x, img_y = read_file()
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [
                            optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                            summarys
                        ],
                        feed_dict={
                            cycle_gan.x: fake_X_pool.query(fake_x_val),
                            cycle_gan.y: fake_Y_pool.query(fake_y_val)
                        }))

                train_writer.add_summary(summary, total_step)
                train_writer.flush()

                logger.info('step: {}'.format(total_step))
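                # Begin decaying the learning rate once past 100k steps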
                if total_step > 1e5:
                    sess.run(cycle_gan.learning_rate_decay_op())

                if total_step % 100 == 0:
                    logger.info('-----------Step %d:-------------' %
                                total_step)
                    logger.info('  G_loss   : {}'.format(G_loss_val))
                    logger.info('  D_Y_loss : {}'.format(D_Y_loss_val))
                    logger.info('  F_loss   : {}'.format(F_loss_val))
                    logger.info('  D_X_loss : {}'.format(D_X_loss_val))
                    logger.info('  learning_rate : {}'.format(
                        cycle_gan.learning_rate))

                if total_step % 10000 == 0:
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=total_step)
                    logger.info("Model saved in file: %s" % save_path)

                total_step += 1
        except KeyboardInterrupt:
            logger.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess,
                                   checkpoints_dir + "/model.ckpt",
                                   global_step=total_step)
            logger.info("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
Example #16
    def _build_graph(self):
        self.x_test_tfph = tf.compat.v1.placeholder(
            tf.float32, shape=[None, *self.input_shape], name='x_test_tfph')
        self.fake_pair_tfph = tf.compat.v1.placeholder(
            tf.float32,
            shape=[None, self.input_shape[0], self.input_shape[1], 2],
            name='fake_pairs_tfph')
        self.rate_tfph = tf.compat.v1.placeholder(tf.float32,
                                                  name='keep_prob_ph')

        # Initialize TFRecord reader
        train_reader = Reader(tfrecordsFile=self.data_path[0],
                              decodeImgShape=self.decode_img_shape,
                              imgShape=self.input_shape,
                              batchSize=self.batch_size,
                              name='train')

        # Initialize generator & discriminator
        self.gen_obj = Generator(name='G',
                                 gen_c=self.gen_c,
                                 norm='instance',
                                 logger=self.logger,
                                 _ops=None)
        self.dis_obj = Discriminator(name='D',
                                     dis_c=self.dis_c,
                                     norm='instance',
                                     logger=self.logger,
                                     _ops=None)

        # Random batch for training
        self.img_train, self.seg_img_train = train_reader.shuffle_batch()
        self.img_pool_obj = utils.ImagePool(pool_size=150)

        # Transform img_train and seg_img_train
        trans_seg_img_train = self.transform_seg(self.seg_img_train)
        trans_img_train = self.transform_img(self.img_train)

        # Concatenation
        self.g_sample = self.gen_obj(trans_seg_img_train, self.rate_tfph)
        self.real_pair = tf.concat([trans_seg_img_train, trans_img_train],
                                   axis=3)
        self.fake_pair = tf.concat([trans_seg_img_train, self.g_sample],
                                   axis=3)

        # Define generator loss
        self.gen_adv_loss = self.generator_loss(self.dis_obj, self.fake_pair)
        self.cond_loss = self.conditional_loss(pred=self.g_sample,
                                               gt=trans_img_train)
        self.gen_loss = self.gen_adv_loss + self.cond_loss

        # Define discriminator loss
        self.dis_loss = self.discriminator_loss(self.dis_obj, self.real_pair,
                                                self.fake_pair_tfph)

        # Optimizers
        self.gen_optim = self.init_optimizer(loss=self.gen_loss,
                                             variables=self.gen_obj.variables,
                                             name='Adam_gen')
        self.dis_optim = self.init_optimizer(loss=self.dis_loss,
                                             variables=self.dis_obj.variables,
                                             name='Adam_dis')
Example #17
    def _build_net(self):
        self.mae_record_placeholder = tf.placeholder(
            tf.float32, name='mae_record_placeholder')
        self.mae_record = tf.Variable(256.,
                                      trainable=False,
                                      dtype=tf.float32,
                                      name='mae_record')
        self.mae_record_assign_op = self.mae_record.assign(
            self.mae_record_placeholder)

        # tfph: TensorFlow PlaceHolder
        self.x_test_tfph = tf.placeholder(tf.float32,
                                          shape=[None, *self.image_size],
                                          name='x_test_tfph')
        self.y_test_tfph = tf.placeholder(tf.float32,
                                          shape=[None, *self.image_size],
                                          name='y_test_tfph')
        self.fake_x_tfph = tf.placeholder(tf.float32,
                                          shape=[None, *self.image_size],
                                          name='fake_x_tfph')
        self.fake_y_tfph = tf.placeholder(tf.float32,
                                          shape=[None, *self.image_size],
                                          name='fake_y_tfph')

        self.G_gen = Generator(name='G',
                               ngf=self.ngf,
                               norm=self.norm,
                               image_size=self.image_size,
                               _ops=self._G_gen_train_ops)
        self.Dy_dis = Discriminator(name='Dy',
                                    ndf=self.ndf,
                                    norm=self.norm,
                                    _ops=self._Dy_dis_train_ops,
                                    use_sigmoid=self.use_sigmoid)
        self.F_gen = Generator(name='F',
                               ngf=self.ngf,
                               norm=self.norm,
                               image_size=self.image_size,
                               _ops=self._F_gen_train_ops)
        self.Dx_dis = Discriminator(name='Dx',
                                    ndf=self.ndf,
                                    norm=self.norm,
                                    _ops=self._Dx_dis_train_ops,
                                    use_sigmoid=self.use_sigmoid)

        x_reader = Reader(self.x_path,
                          name='X',
                          image_size=self.image_size,
                          batch_size=self.flags.batch_size)
        y_reader = Reader(self.y_path,
                          name='Y',
                          image_size=self.image_size,
                          batch_size=self.flags.batch_size)
        self.x_imgs = x_reader.feed()
        self.y_imgs = y_reader.feed()

        self.fake_x_pool_obj = utils.ImagePool(pool_size=50)
        self.fake_y_pool_obj = utils.ImagePool(pool_size=50)

        # cycle consistency loss
        cycle_loss = self.cycle_consistency_loss(self.x_imgs, self.y_imgs)

        # X -> Y
        self.fake_y_imgs = self.G_gen(self.x_imgs)
        self.G_gen_loss = self.generator_loss(self.Dy_dis,
                                              self.fake_y_imgs,
                                              use_lsgan=self.use_lsgan)
        self.G_loss = self.G_gen_loss + cycle_loss
        self.Dy_dis_loss = self.discriminator_loss(self.Dy_dis,
                                                   self.y_imgs,
                                                   self.fake_y_tfph,
                                                   use_lsgan=self.use_lsgan)

        # Y -> X
        self.fake_x_imgs = self.F_gen(self.y_imgs)
        self.F_gen_loss = self.generator_loss(self.Dx_dis,
                                              self.fake_x_imgs,
                                              use_lsgan=self.use_lsgan)
        self.F_loss = self.F_gen_loss + cycle_loss
        self.Dx_dis_loss = self.discriminator_loss(self.Dx_dis,
                                                   self.x_imgs,
                                                   self.fake_x_tfph,
                                                   use_lsgan=self.use_lsgan)

        G_op = self.optimizer(loss=self.G_loss,
                              variables=self.G_gen.variables,
                              name='Adam_G')
        G_ops = [G_op] + self._G_gen_train_ops
        G_optim = tf.group(*G_ops)

        Dy_op = self.optimizer(loss=self.Dy_dis_loss,
                               variables=self.Dy_dis.variables,
                               name='Adam_Dy')
        Dy_ops = [Dy_op] + self._Dy_dis_train_ops
        Dy_optim = tf.group(*Dy_ops)

        F_op = self.optimizer(loss=self.F_loss,
                              variables=self.F_gen.variables,
                              name='Adam_F')
        F_ops = [F_op] + self._F_gen_train_ops
        F_optim = tf.group(*F_ops)

        Dx_op = self.optimizer(loss=self.Dx_dis_loss,
                               variables=self.Dx_dis.variables,
                               name='Adam_Dx')
        Dx_ops = [Dx_op] + self._Dx_dis_train_ops
        Dx_optim = tf.group(*Dx_ops)

        self.optims = tf.group([G_optim, Dy_optim, F_optim, Dx_optim])
        # with tf.control_dependencies([G_optim, Dy_optim, F_optim, Dx_optim]):
        #     self.optims = tf.no_op(name='optimizers')

        # for sampling function
        self.fake_y_sample = self.G_gen(self.x_test_tfph)
        self.fake_x_sample = self.F_gen(self.y_test_tfph)
        self.recon_x_sample = self.F_gen(self.G_gen(self.x_test_tfph))
Example #18
    def _build_net(self):
        # tfph: TensorFlow PlaceHolder
        self.x_test_tfph = tf.placeholder(
            tf.float32,
            shape=[None, self.img_size[0], self.img_size[1], self.mm_dim],
            name='x_test_tfph')
        self.y_test_tfph = tf.placeholder(
            tf.float32,
            shape=[None, self.img_size[0], self.img_size[1], self.fp_dim],
            name='y_test_tfph')
        self.xy_fake_pairs_tfph = tf.placeholder(tf.float32,
                                                 shape=[
                                                     None, self.img_size[0],
                                                     self.img_size[1],
                                                     self.mm_dim + self.fp_dim
                                                 ],
                                                 name='xy_fake_pairs_tfph')
        self.yx_fake_pairs_tfph = tf.placeholder(tf.float32,
                                                 shape=[
                                                     None, self.img_size[0],
                                                     self.img_size[1],
                                                     self.mm_dim + self.fp_dim
                                                 ],
                                                 name='yx_fake_pairs_tfph')

        self.G_gen = Generator(name='G',
                               ngf=self.ngf,
                               norm=self.norm,
                               image_size=self.img_size,
                               output_dim=self.fp_dim,
                               _ops=self._G_gen_train_ops)
        self.Dy_dis = Discriminator(name='Dy',
                                    ndf=self.ndf,
                                    norm=self.norm,
                                    _ops=self._Dy_dis_train_ops)
        self.F_gen = Generator(name='F',
                               ngf=self.ngf,
                               norm=self.norm,
                               image_size=self.img_size,
                               output_dim=self.mm_dim,
                               _ops=self._F_gen_train_ops)
        self.Dx_dis = Discriminator(name='Dx',
                                    ndf=self.ndf,
                                    norm=self.norm,
                                    _ops=self._Dx_dis_train_ops)

        data_reader = Reader(self.data_path,
                             name='data',
                             image_size=self.img_size,
                             batch_size=self.flags.batch_size,
                             is_train=self.flags.is_train)
        # self.x_imgs_ori and self.y_imgs_ori are the images before data augmentation
        self.x_imgs, self.y_imgs, self.x_imgs_ori, self.y_imgs_ori, self.img_name = \
            data_reader.feed()

        # slicing minutiae and fingerprint images to 2d and 1d
        self.x_imgs_2d = tf.slice(self.x_imgs,
                                  begin=[0, 0, 0, 0],
                                  size=[-1, -1, -1, 2],
                                  name='x_slice')
        self.y_imgs_1d = tf.slice(self.y_imgs,
                                  begin=[0, 0, 0, 0],
                                  size=[-1, -1, -1, 1],
                                  name='y_slice')

        self.fake_xy_pool_obj = utils.ImagePool(pool_size=50)
        self.fake_yx_pool_obj = utils.ImagePool(pool_size=50)

        # cycle consistency loss
        self.cycle_loss = self.cycle_consistency_loss(self.x_imgs_2d,
                                                      self.y_imgs_1d)

        # concatenation
        self.fake_y_imgs = self.G_gen(self.x_imgs_2d)
        self.xy_real_pairs = tf.concat([self.x_imgs_2d, self.y_imgs_1d],
                                       axis=3)
        self.xy_fake_pairs = tf.concat([self.x_imgs_2d, self.fake_y_imgs],
                                       axis=3)

        self.fake_x_imgs = self.F_gen(self.y_imgs_1d)
        self.yx_real_pairs = tf.concat([self.y_imgs_1d, self.x_imgs_2d],
                                       axis=3)
        self.yx_fake_pairs = tf.concat([self.y_imgs_1d, self.fake_x_imgs],
                                       axis=3)

        # X -> Y
        self.G_gen_loss = self.generator_loss(self.Dy_dis, self.xy_fake_pairs)
        self.G_cond_loss = self.voxel_loss(preds=self.fake_y_imgs,
                                           gts=self.y_imgs_1d,
                                           weight=self.L1_lambda)
        self.G_loss = self.G_gen_loss + self.G_cond_loss + self.cycle_loss
        self.Dy_dis_loss = self.discriminator_loss(self.Dy_dis,
                                                   self.xy_real_pairs,
                                                   self.xy_fake_pairs_tfph)

        # Y -> X
        self.F_gen_loss = self.generator_loss(self.Dx_dis, self.yx_fake_pairs)
        self.F_cond_loss = self.voxel_loss(preds=self.fake_x_imgs,
                                           gts=self.x_imgs_2d,
                                           weight=0.)
        self.F_loss = self.F_gen_loss + self.F_cond_loss + self.cycle_loss
        self.Dx_dis_loss = self.discriminator_loss(self.Dx_dis,
                                                   self.yx_real_pairs,
                                                   self.yx_fake_pairs_tfph)

        G_optim = self.optimizer(loss=self.G_loss,
                                 variables=self.G_gen.variables,
                                 name='Adam_G')
        Dy_optim = self.optimizer(loss=self.Dy_dis_loss,
                                  variables=self.Dy_dis.variables,
                                  name='Adam_Dy')
        F_optim = self.optimizer(loss=self.F_loss,
                                 variables=self.F_gen.variables,
                                 name='Adam_F')
        Dx_optim = self.optimizer(loss=self.Dx_dis_loss,
                                  variables=self.Dx_dis.variables,
                                  name='Adam_Dx')
        self.optims = tf.group([G_optim, Dy_optim, F_optim, Dx_optim])

        # for sampling function
        self.fake_y_sample = self.G_gen(self.x_test_tfph)
        self.fake_x_sample = self.F_gen(self.y_test_tfph)
Example #19
    def _build_net(self):
        self.mae_record_placeholder = tf.placeholder(
            tf.float32, name='mae_record_placeholder')
        self.mae_record = tf.Variable(256.,
                                      trainable=False,
                                      dtype=tf.float32,
                                      name='mae_record')
        self.mae_record_assign_op = self.mae_record.assign(
            self.mae_record_placeholder)

        # tfph: TensorFlow PlaceHolder
        self.x_test_tfph = tf.placeholder(tf.float32,
                                          shape=[None, *self.image_size],
                                          name='x_test_tfph')
        self.y_test_tfph = tf.placeholder(tf.float32,
                                          shape=[None, *self.image_size],
                                          name='y_test_tfph')
        self.fake_x_tfph = tf.placeholder(tf.float32,
                                          shape=[None, *self.image_size],
                                          name='fake_x_tfph')
        self.fake_y_tfph = tf.placeholder(tf.float32,
                                          shape=[None, *self.image_size],
                                          name='fake_y_tfph')

        self.G_gen = Generator(name='G',
                               ngf=self.ngf,
                               norm=self.norm,
                               image_size=self.image_size,
                               _ops=self._G_gen_train_ops)
        self.Dy_dis = Discriminator(name='Dy',
                                    ndf=self.ndf,
                                    norm=self.norm,
                                    _ops=self._Dy_dis_train_ops,
                                    use_sigmoid=self.use_sigmoid)
        self.F_gen = Generator(name='F',
                               ngf=self.ngf,
                               norm=self.norm,
                               image_size=self.image_size,
                               _ops=self._F_gen_train_ops)
        self.Dx_dis = Discriminator(name='Dx',
                                    ndf=self.ndf,
                                    norm=self.norm,
                                    _ops=self._Dx_dis_train_ops,
                                    use_sigmoid=self.use_sigmoid)

        x_reader = Reader(self.x_path,
                          name='X',
                          image_size=self.image_size,
                          batch_size=self.flags.batch_size)
        y_reader = Reader(self.y_path,
                          name='Y',
                          image_size=self.image_size,
                          batch_size=self.flags.batch_size)
        self.x_imgs = x_reader.feed()
        self.y_imgs = y_reader.feed()

        self.fake_x_pool_obj = utils.ImagePool(pool_size=50)
        self.fake_y_pool_obj = utils.ImagePool(pool_size=50)

        self._unpair_net()  # idea from cyclegan
        self._pair_net()  # idea from pix2pix

        # Optimizers
        # G generator for unpaired data
        G_op_unpair = self.optimizer(loss=self.G_loss_unpair,
                                     variables=self.G_gen.variables,
                                     name='Adam_G_unpair')
        G_ops_unpair = [G_op_unpair] + self._G_gen_train_ops
        G_optim_unpair = tf.group(*G_ops_unpair)

        # G generator for paired data
        G_op_pair = self.optimizer(loss=self.G_loss_pair,
                                   variables=self.G_gen.variables,
                                   name='Adam_G_pair')
        G_ops_pair = [G_op_pair] + self._G_gen_train_ops
        self.G_optim_pair = tf.group(*G_ops_pair)

        # Dy discriminator for unpaired data
        Dy_op_unpair = self.optimizer(loss=self.Dy_dis_loss_unpair,
                                      variables=[
                                          self.Dy_dis.share_variables,
                                          self.Dy_dis.unpair_variables
                                      ],
                                      name='Adam_Dy_unpair')
        Dy_ops_unpair = [Dy_op_unpair] + self._Dy_dis_train_ops
        Dy_optim_unpair = tf.group(*Dy_ops_unpair)

        # Dy discriminator for paired data
        Dy_op_pair = self.optimizer(loss=self.Dy_dis_loss_pair,
                                    variables=[
                                        self.Dy_dis.share_variables,
                                        self.Dy_dis.pair_variables
                                    ],
                                    name='Adam_Dy_pair')
        Dy_ops_pair = [Dy_op_pair] + self._Dy_dis_train_ops
        self.Dy_optim_pair = tf.group(*Dy_ops_pair)

        # F generator for unpaired data
        F_op_unpair = self.optimizer(loss=self.F_loss_unpair,
                                     variables=self.F_gen.variables,
                                     name='Adam_F_unpair')
        F_ops_unpair = [F_op_unpair] + self._F_gen_train_ops
        F_optim_unpair = tf.group(*F_ops_unpair)

        # F generator for paired data
        F_op_pair = self.optimizer(loss=self.F_loss_pair,
                                   variables=self.F_gen.variables,
                                   name='Adam_F_pair')
        F_ops_pair = [F_op_pair] + self._F_gen_train_ops
        self.F_optim_pair = tf.group(*F_ops_pair)

        # Dx discriminator for unpaired data
        Dx_op_unpair = self.optimizer(loss=self.Dx_dis_loss_unpair,
                                      variables=[
                                          self.Dx_dis.share_variables,
                                          self.Dx_dis.unpair_variables
                                      ],
                                      name='Adam_Dx_unpair')
        Dx_ops_unpair = [Dx_op_unpair] + self._Dx_dis_train_ops
        Dx_optim_unpair = tf.group(*Dx_ops_unpair)

        # Dx discriminator for paired data
        Dx_op_pair = self.optimizer(loss=self.Dx_dis_loss_pair,
                                    variables=[
                                        self.Dx_dis.share_variables,
                                        self.Dx_dis.pair_variables
                                    ],
                                    name='Adam_Dx_pair')
        Dx_ops_pair = [Dx_op_pair] + self._Dx_dis_train_ops
        self.Dx_optim_pair = tf.group(*Dx_ops_pair)

        self.optims_unpair = tf.group(
            [G_optim_unpair, Dy_optim_unpair, F_optim_unpair, Dx_optim_unpair])
        self.optims_pair = tf.group([
            self.G_optim_pair, self.Dy_optim_pair, self.F_optim_pair,
            self.Dx_optim_pair
        ])
        self.loss_collections = [
            self.G_loss_unpair, self.Dy_dis_loss_unpair, self.F_loss_unpair,
            self.Dx_dis_loss_unpair, self.G_loss_pair, self.Dy_dis_loss_pair,
            self.F_loss_pair, self.Dx_dis_loss_pair
        ]
Example #20
    def train(self, args):
        # Obtain dataloaders
        loader = self.get_dataloader(args)

        # Generated image pools
        imagepool_a = utils.ImagePool()
        imagepool_b = utils.ImagePool()

        lambda_coef = args.lamda
        lambda_idt = args.idt_coef

        # Initialize Weights
        utils.init_weights(self.G_BA)
        utils.init_weights(self.G_AB)
        utils.init_weights(self.D_A)
        utils.init_weights(self.D_B)

        step = 0

        self.load_checkpoint(args)

        # Terrible hack: resync the LR schedulers after loading a checkpoint
        self.gen_scheduler.last_epoch = self.curr_epoch - 1
        self.dis_scheduler.last_epoch = self.curr_epoch - 1

        self.G_BA.train()
        self.G_AB.train()

        for epoch in range(self.curr_epoch, args.epochs):

            for a_real, b_real in loader:
                # Send data to (ideally) GPU
                a_real = a_real.to(self.device)
                b_real = b_real.to(self.device)

                # Real/fake label tensors are created with ones_like/zeros_like
                # below, so their shapes always match the discriminator outputs

                # Generator forward passes
                a_fake = self.G_BA(b_real)
                b_fake = self.G_AB(a_real)

                a_reconstruct = self.G_BA(b_fake)
                b_reconstruct = self.G_AB(a_fake)

                a_identity = self.G_BA(a_real)
                b_identity = self.G_AB(b_real)

                # Identity Loss
                a_idt_loss = self.L1(a_identity,
                                     a_real) * lambda_coef * lambda_idt
                b_idt_loss = self.L1(b_identity,
                                     b_real) * lambda_coef * lambda_idt

                # GAN Loss
                a_fake_dis = self.D_A(a_fake)
                b_fake_dis = self.D_B(b_fake)

                positive_labels = torch.ones_like(a_fake_dis)

                a_gan_loss = self.MSE(a_fake_dis, positive_labels)
                b_gan_loss = self.MSE(b_fake_dis, positive_labels)

                # Cycle Loss
                a_cycle_loss = self.L1(a_reconstruct, a_real) * lambda_coef
                b_cycle_loss = self.L1(b_reconstruct, b_real) * lambda_coef

                # Total Loss
                total_gan_loss = a_idt_loss + b_idt_loss + a_gan_loss + b_gan_loss + a_cycle_loss + b_cycle_loss

                # Sample previously generated images for discriminator forward pass
                a_fake = torch.Tensor(
                    imagepool_a(a_fake.detach().cpu().clone().numpy())
                )  # a_fake first dim might be batch entry
                b_fake = torch.Tensor(
                    imagepool_b(b_fake.detach().cpu().clone().numpy()))

                a_fake = a_fake.to(self.device)
                b_fake = b_fake.to(self.device)

                # Discriminator forward pass
                a_real_dis = self.D_A(a_real)
                a_fake_dis = self.D_A(a_fake)
                b_real_dis = self.D_B(b_real)
                b_fake_dis = self.D_B(b_fake)

                # Discriminator Losses
                positive_labels = torch.ones_like(a_fake_dis)
                negative_labels = torch.zeros_like(a_fake_dis)

                a_dis_real_loss = self.MSE(a_real_dis, positive_labels)
                a_dis_fake_loss = self.MSE(a_fake_dis, negative_labels)
                b_dis_real_loss = self.MSE(b_real_dis, positive_labels)
                b_dis_fake_loss = self.MSE(b_fake_dis, negative_labels)

                a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
                b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

                # Step
                self.gen_optimizer.zero_grad()
                total_gan_loss.backward()
                self.gen_optimizer.step()

                self.dis_optimizer.zero_grad()
                a_dis_loss.backward()
                b_dis_loss.backward()
                self.dis_optimizer.step()

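                # Cap Adam's internal step counter at 962 (presumably to keep
                # the bias-correction / effective LR from drifting further)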
                for group in self.dis_optimizer.param_groups:
                    for p in group['params']:
                        state = self.dis_optimizer.state[p]
                        if state['step'] >= 962:
                            state['step'] = 962

                for group in self.gen_optimizer.param_groups:
                    for p in group['params']:
                        state = self.gen_optimizer.state[p]
                        if state['step'] >= 962:
                            state['step'] = 962

                if (step + 1) % 5 == 0:
                    print(
                        "Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e"
                        % (epoch, step + 1, len(loader), total_gan_loss,
                           a_dis_loss + b_dis_loss))

                step += 1
            self.save_checkpoint(epoch + 1, args)
            self.gen_scheduler.step()
            self.dis_scheduler.step()
            step = 0