Пример #1
0
class Trainer(object):
    def __init__(self, style_data_loader, content_data_loader, config):

        self.log_file           = os.path.join(config.log_path, config.version,config.version+"_log.log")
        self.report_file        = os.path.join(config.log_path, config.version,config.version+"_report.log")
        logging.basicConfig(filename=self.report_file,
            format='[%(asctime)s-%(levelname)s:%(message)s]', 
                level = logging.DEBUG,filemode='w',
                    datefmt='%Y-%m-%d%I:%M:%S %p')

        self.Experiment_description = config.experiment_description
        logging.info("Experiment description: \n%s"%self.Experiment_description)
        # Data loader
        self.style_data_loader = style_data_loader
        self.content_data_loader = content_data_loader

        # exact loss
        self.adv_loss = config.adv_loss
        logging.info("loss: %s"%self.adv_loss)

        # Model hyper-parameters
        self.imsize         = config.imsize
        logging.info("image size: %d"%self.imsize)
        self.batch_size     = config.batch_size
        logging.info("Batch size: %d"%self.batch_size)

        logging.info("Is shuffle: {}".format(config.is_shuffle))
        logging.info("Image center crop size: {}".format(config.center_crop))

        self.res_num        = config.res_num
        logging.info("resblock number: %d"%self.res_num)
        self.g_conv_dim     = config.g_conv_dim
        logging.info("generator convolution initial channel: %d"%self.g_conv_dim)
        self.d_conv_dim     = config.d_conv_dim
        logging.info("discriminator convolution initial channel: %d"%self.d_conv_dim)
        self.parallel       = config.parallel
        logging.info("Is multi-GPU parallel: %s"%str(self.parallel))
        self.gpus           = config.gpus
        logging.info("GPU number: %s"%self.gpus)
        self.total_step     = config.total_step
        logging.info("Total step: %d"%self.total_step)
        self.d_iters        = config.d_iters
        self.g_iters        = config.g_iters
        self.total_iters_ratio=config.total_iters_ratio
        
        self.num_workers    = config.num_workers

        self.g_lr           = config.g_lr
        logging.info("Generator learning rate: %f"%self.g_lr)
        self.d_lr           = config.d_lr
        logging.info("Discriminator learning rate: %f"%self.d_lr)
        self.lr_decay       = config.lr_decay
        logging.info("Learning rate decay: %f"%self.lr_decay)
        self.beta1          = config.beta1
        logging.info("Adam opitimizer beta1: %f"%self.beta1)
        self.beta2          = config.beta2
        logging.info("Adam opitimizer beta2: %f"%self.beta2)

        self.pretrained_model   = config.pretrained_model
        self.use_pretrained_model = config.use_pretrained_model
        logging.info("Use pretrained model: %s"%str(self.pretrained_model))

        self.use_tensorboard    = config.use_tensorboard
        logging.info("Use tensorboard: %s"%str(self.use_tensorboard))

        self.check_point_path   = config.check_point_path
        self.sample_path        = config.sample_path
        self.summary_path       = config.summary_path
        self.validation_path    = config.validation
        # val_dataloader          = Validation_Data_Loader(self.validation_path,self.imsize)
        # self.validation_data    = val_dataloader.load_validation_images()
        # valres_path = os.path.join(config.log_path, config.version, "valres")
        # if not os.path.exists(valres_path):
        #     os.makedirs(valres_path)
        # self.valres_path = valres_path

        self.log_step           = config.log_step
        self.sample_step        = config.sample_step
        self.model_save_step    = config.model_save_step
        self.prep_weights       = [1.0, 1.0, 1.0, 1.0, 1.0]
        self.transform_loss_w   = config.transform_loss_w
        logging.info("transform loss weight: %f"%self.transform_loss_w)
        self.feature_loss_w     = config.feature_loss_w
        logging.info("feature loss weight: %f"%self.feature_loss_w)
        self.style_class        = config.style_class
        self.real_prep_threshold= config.real_prep_threshold
        logging.info("real label threshold: %f"%self.real_prep_threshold)
        # self.TVLossWeight       = config.TV_loss_weight
        # logging.info("TV loss weight: %f"%self.TVLossWeight)


        self.discr_success_rate = config.discr_success_rate
        logging.info("discriminator success rate: %f"%self.discr_success_rate)

        logging.info("Is conditional generating: %s"%str(config.condition_model))

        self.device = torch.device('cuda:%s'%config.default_GPU if torch.cuda.is_available() else 'cpu')

        print('build_model...')
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.use_pretrained_model:
            print('load_pretrained_model...')

    def train(self):

        # Data iterator
        style_iter      = iter(self.style_data_loader)
        content_iter    = iter(self.content_data_loader)

        step_per_epoch  = len(self.style_data_loader)
        model_save_step = int(self.model_save_step)

        # Fixed input for debugging

        # Start with trained model
        if self.use_pretrained_model:
            start = self.pretrained_model + 1
        else:
            start = 0
        alternately_iter     = 0
        self.d_iters         = self.d_iters * self.total_iters_ratio
        max_alternately_iter = self.d_iters + self.total_iters_ratio * self.g_iters
        d_acc         = 0
        real_acc      = 0
        photo_acc     = 0
        fake_acc      = 0
        win_rate      = self.discr_success_rate
        discr_success = self.discr_success_rate
        alpha         = 0.05


        real_labels = []
        fake_labels = []
        # size = [[self.batch_size,122*122],[self.batch_size,58*58],[self.batch_size,10*10],[self.batch_size,2*2],[self.batch_size,2*2]]
        size = [[self.batch_size,1,760,760],[self.batch_size,1,371,371],[self.batch_size,1,83,83],[self.batch_size,1,11,11],[self.batch_size,1,6,6]]
        for i in range(5):
            real_label = torch.ones(size[i], device=self.device)
            fake_label = torch.zeros(size[i], device=self.device)
            # threshold = torch.zeros(size[i], device=self.device)
            real_labels.append(real_label)
            fake_labels.append(fake_label)

        # Start time
        print('Start   ======  training...')
        start_time = time.time()
        for step in range(start, self.total_step):
            self.Discriminator.train()
            self.Generator.train()
            # self.Decoder.train()
            try:
                content_images =next(content_iter)
                style_images = next(style_iter)
            except:
                style_iter      = iter(self.style_data_loader)
                content_iter    = iter(self.content_data_loader)
                style_images = next(style_iter)
                content_images = next(content_iter)
            style_images    = style_images.to(self.device)
            content_images  = content_images.to(self.device)
            # ================== Train D ================== #
            # Compute loss with real images
            if discr_success < win_rate:
                real_out = self.Discriminator(style_images)
                d_loss_real = 0
                real_acc = 0
                for i in range(len(real_out)):
                    temp = self.C_loss(real_out[i],real_labels[i]).mean()
                    real_acc +=  torch.gt(real_out[i],0).type(torch.float).mean()
                    temp *= self.prep_weights[i]
                    d_loss_real += temp
                real_acc /= len(real_out)

                d_loss_photo = 0
                photo_out = self.Discriminator(content_images)
                photo_acc = 0
                for i in range(len(photo_out)):
                    temp = self.C_loss(photo_out[i],fake_labels[i])
                    photo_acc +=  torch.lt(photo_out[i],0).type(torch.float).mean()
                    temp *= self.prep_weights[i]
                    d_loss_photo += temp
                photo_acc /= len(photo_out) 

                fake_image,_ = self.Generator(content_images)
                fake_out = self.Discriminator(fake_image.detach())
                d_loss_fake = 0
                fake_acc = 0
                for i in range(len(fake_out)):
                    temp = self.C_loss(fake_out[i],fake_labels[i]).mean()
                    fake_acc +=  torch.lt(fake_out[i],0).type(torch.float).mean()
                    temp *= self.prep_weights[i]
                    d_loss_fake += temp
                fake_acc /= len(fake_out) 
                d_acc = ((real_acc + photo_acc + fake_acc)/3).item()
                discr_success = discr_success * (1. - alpha) + alpha * d_acc
                # Backward + Optimize
                d_loss = d_loss_real + d_loss_photo + d_loss_fake
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()
            else:
                # ================== Train G ================== #   
                #      
                fake_image, real_feature= self.Generator(content_images)
                fake_feature            = self.Generator(fake_image, get_feature = True)
                fake_out                = self.Discriminator(fake_image)
                g_feature_loss          = self.L1_loss(fake_feature,real_feature)
                g_transform_loss        = self.MSE_loss(self.Transform(content_images),self.Transform(fake_image))
                g_loss_fake = 0
                g_acc = 0
                for i in range(len(fake_out)):
                    temp = self.C_loss(fake_out[i],real_labels[i]).mean()
                    g_acc +=  torch.gt(fake_out[i],0).type(torch.float).mean()
                    temp *= self.prep_weights[i]
                    g_loss_fake += temp
                g_acc /= len(fake_out)
                g_loss_fake = g_loss_fake + g_feature_loss*self.feature_loss_w + \
                                        g_transform_loss*self.transform_loss_w
                discr_success = discr_success * (1. - alpha) + alpha * (1.0 - g_acc)
                self.reset_grad()
                g_loss_fake.backward()
                self.g_optimizer.step()
                # self.decoder_optimizer.step()
            

            # Print out log info
            if (step + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                print("Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_out_real: {:.4f}, d_out_fake: {:.4f}, g_loss_fake: {:.4f}".
                      format(elapsed, step + 1, self.total_step, (step + 1),
                             self.total_step , d_loss_real.item(), d_loss_fake.item(), g_loss_fake.item()))
                
                if self.use_tensorboard:
                    self.writer.add_scalar('data/d_loss_real', d_loss_real.item(),(step + 1))
                    self.writer.add_scalar('data/d_loss_fake', d_loss_fake.item(),(step + 1))
                    self.writer.add_scalar('data/d_loss', d_loss.item(), (step + 1))
                    self.writer.add_scalar('data/g_loss', g_loss_fake.item(), (step + 1))
                    self.writer.add_scalar('data/g_feature_loss', g_feature_loss, (step + 1))
                    self.writer.add_scalar('data/g_transform_loss', g_transform_loss, (step + 1))
                    # self.writer.add_scalar('data/g_tv_loss', g_tv_loss, (step + 1))
                    self.writer.add_scalar('acc/real_acc', real_acc.item(), (step + 1))
                    self.writer.add_scalar('acc/photo_acc', photo_acc.item(), (step + 1))
                    self.writer.add_scalar('acc/fake_acc', fake_acc.item(), (step + 1))
                    self.writer.add_scalar('acc/disc_acc', d_acc, (step + 1))
                    self.writer.add_scalar('acc/g_acc', g_acc, (step + 1))
                    self.writer.add_scalar("acc/discr_success",discr_success,(step+1))
                    

            # Sample images
            if (step + 1) % self.sample_step == 0:
                print('Sample images {}_fake.png'.format(step + 1))
                fake_images,_ = self.Generator(content_images)
                saved_image1 = torch.cat([denorm(content_images),denorm(fake_images.data)],3)
                saved_image2 = torch.cat([denorm(style_images),denorm(fake_images.data)],3)
                wocao        = torch.cat([saved_image1,saved_image2],2)
                save_image(wocao,
                           os.path.join(self.sample_path, '{}_fake.jpg'.format(step + 1)))
                # print("Transfer validation images")
                # num = 1
                # for val_img in self.validation_data:
                #     print("testing no.%d img"%num)
                #     val_img = val_img.to(self.device)
                #     fake_images,_ = self.Generator(val_img)
                #     saved_val_image = torch.cat([denorm(val_img),denorm(fake_images)],3)
                #     save_image(saved_val_image,
                #            os.path.join(self.valres_path, '%d_%d.jpg'%((step+1),num)))
                #     num +=1
                # save_image(denorm(displaymask.data),os.path.join(self.sample_path, '{}_mask.png'.format(step + 1)))

            if (step+1) % model_save_step==0:
                torch.save(self.Generator.state_dict(),
                           os.path.join(self.check_point_path , '{}_Generator.pth'.format(step + 1)))
                torch.save(self.Discriminator.state_dict(),
                           os.path.join(self.check_point_path , '{}_Discriminator.pth'.format(step + 1)))
            # alternately_iter += 1
            # alternately_iter %= max_alternately_iter
            
            

    def build_model(self):
        """Create the networks, the loss modules and the two Adam optimizers.

        ``Generator``, ``Discriminator`` and ``Transform`` are instantiated and
        moved to ``self.device``. When ``self.parallel`` is set, each is wrapped
        in ``nn.DataParallel`` over the device ids parsed from the
        comma-separated ``self.gpus`` string. The optimizers only receive
        parameters with ``requires_grad`` set. Both network structures are
        written to the log.
        """
        generator = Generator(chn=self.g_conv_dim, k_size=3, res_num=self.res_num)
        self.Generator = generator.to(self.device)
        discriminator = Discriminator(chn=self.d_conv_dim, k_size=3)
        self.Discriminator = discriminator.to(self.device)
        self.Transform = Transform_block().to(self.device)

        if self.parallel:
            print('use parallel...')
            print('gpuids ', self.gpus)
            device_ids = list(map(int, self.gpus.split(',')))
            for attr in ('Generator', 'Discriminator', 'Transform'):
                wrapped = nn.DataParallel(getattr(self, attr), device_ids=device_ids)
                setattr(self, attr, wrapped)

        # Criteria used by the training loop
        self.MSE_loss = torch.nn.MSELoss()
        self.L1_loss = torch.nn.SmoothL1Loss()
        self.C_loss = torch.nn.BCEWithLogitsLoss()

        # One optimizer per network, restricted to trainable parameters
        g_trainable = (p for p in self.Generator.parameters() if p.requires_grad)
        self.g_optimizer = torch.optim.Adam(g_trainable, self.g_lr, [self.beta1, self.beta2])
        d_trainable = (p for p in self.Discriminator.parameters() if p.requires_grad)
        self.d_optimizer = torch.optim.Adam(d_trainable, self.d_lr, [self.beta1, self.beta2])

        # Record the architectures in the report log
        logging.info("Generator structure:")
        logging.info(self.Generator)
        logging.info("Discriminator structure:")
        logging.info(self.Discriminator)

    def build_tensorboard(self):
        """Create ``self.writer``, a tensorboardX writer logging to ``self.summary_path``.

        tensorboardX is imported here rather than at module level, so it is only
        needed when this method is actually called.
        """
        import tensorboardX

        self.writer = tensorboardX.SummaryWriter(log_dir=self.summary_path)


    def load_pretrained_model(self):
        self.Generator.load_state_dict(torch.load(os.path.join(
            self.check_point_path , '{}_Generator.pth'.format(self.pretrained_model))))
        self.Discriminator.load_state_dict(torch.load(os.path.join(
            self.check_point_path , '{}_Discriminator.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(self.pretrained_model))

    def reset_grad(self):
        self.g_optimizer.zero_grad()
        # self.decoder_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def save_sample(self, data_iter):
        """Save the next batch of ``data_iter`` as ``real.png`` in ``self.sample_path``.

        Args:
            data_iter: Iterator whose items unpack into two values; the first is
                passed through ``denorm`` and saved, the second is ignored.
        """
        batch = next(data_iter)
        real_images, _ = batch
        restored = denorm(real_images)
        target = os.path.join(self.sample_path, 'real.png')
        save_image(restored, target)
Пример #2
0
def main(args):
    """Train the two-generator / two-discriminator model described by ``args``.

    The function selects the visible GPUs, seeds the RNGs, builds the train and
    test loaders, the networks, losses, optimizers and linear-decay schedulers,
    optionally resumes from the rolling checkpoint, and then runs the epoch
    loop. After every epoch the loss history is pickled and a rolling
    checkpoint is written; every 10th epoch ``val`` is run and an additional
    epoch-numbered checkpoint is kept.

    Args:
        args: Parsed options. Attributes read here: ``gpu_ids``, ``model``,
            ``dataset``, ``loss``, ``epochs``, ``flag``, ``num_workers``,
            ``lr``, ``buffer_size``, ``resume``, ``batch_size``,
            ``lambda_cycle``, ``lambda_identity``, ``lambda_feature`` and
            ``attention``.
    """
    # Step0 ====================================================================
    # Set GPU ids
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids

    # Common stem of every file written by this run
    FILE_NAME_FORMAT = "{0}_{1}_{2}_{3:d}{4}".format(args.model, args.dataset,
                                                     args.loss, args.epochs,
                                                     args.flag)
    RESULT_FILE_PATH = os.path.join(RESULT_PATH, FILE_NAME_FORMAT + '_results.pkl')
    CHECKPOINT_FILE_PATH = os.path.join(CHECKPOINT_PATH, FILE_NAME_FORMAT + '.ckpt')

    # Fixed seeds and deterministic cuDNN for reproducibility
    random.seed(190811)
    torch.manual_seed(190811)
    torch.cuda.manual_seed_all(190811)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Step1 ====================================================================
    # Load dataset
    train_dataloader = CycleGAN_Dataloader(name=args.dataset,
                                           num_workers=args.num_workers)
    test_dataloader = CycleGAN_Dataloader(name=args.dataset,
                                          train=False,
                                          num_workers=args.num_workers)
    print('==> DataLoader ready.')

    # Step2 ====================================================================
    # Make the model: cityscapes uses 6 residual blocks, everything else 9.
    num_resblock = 6 if args.dataset == 'cityscapes' else 9
    A_generator = Generator(num_resblock=num_resblock)
    B_generator = Generator(num_resblock=num_resblock)
    A_discriminator = Discriminator()
    B_discriminator = Discriminator()

    # Check DataParallel available
    if torch.cuda.device_count() > 1:
        A_generator = nn.DataParallel(A_generator)
        B_generator = nn.DataParallel(B_generator)
        A_discriminator = nn.DataParallel(A_discriminator)
        B_discriminator = nn.DataParallel(B_discriminator)

    # Check CUDA available
    if torch.cuda.is_available():
        A_generator.cuda()
        B_generator.cuda()
        A_discriminator.cuda()
        B_discriminator.cuda()
    print('==> Model ready.')

    # Step3 ====================================================================
    # Set each loss function
    criterion_GAN = nn.MSELoss()
    criterion_cycle = nn.L1Loss()
    criterion_identity = nn.L1Loss()
    criterion_feature = nn.L1Loss()

    # One optimizer for both generators, one for both discriminators
    optimizer_G = optim.Adam(itertools.chain(A_generator.parameters(),
                                             B_generator.parameters()),
                             lr=args.lr,
                             betas=(0.5, 0.999))
    optimizer_D = optim.Adam(itertools.chain(A_discriminator.parameters(),
                                             B_discriminator.parameters()),
                             lr=args.lr,
                             betas=(0.5, 0.999))

    def lambda_rule(epoch):
        """Return the LR multiplier: 1.0 for the first half, then a linear decay."""
        epoch_decay = args.epochs / 2
        return 1.0 - max(0, epoch + 1 - epoch_decay) / float(epoch_decay + 1)

    scheduler_G = lr_scheduler.LambdaLR(optimizer_G, lr_lambda=lambda_rule)
    scheduler_D = lr_scheduler.LambdaLR(optimizer_D, lr_lambda=lambda_rule)
    print('==> Criterion and optimizer ready.')

    # Step4 ====================================================================
    # Train and validate the model
    start_epoch = 0
    # NOTE: best_metric is carried through the checkpoints but never updated
    # in this function.
    best_metric = float("inf")

    # Per-epoch loss history
    train_loss_G = []
    train_loss_D_A = []
    train_loss_D_B = []

    # Set image buffer
    A_buffer = ImageBuffer(args.buffer_size)
    B_buffer = ImageBuffer(args.buffer_size)

    if args.resume:
        assert os.path.exists(CHECKPOINT_FILE_PATH), 'No checkpoint file!'
        checkpoint = torch.load(CHECKPOINT_FILE_PATH)
        A_generator.load_state_dict(checkpoint['A_generator_state_dict'])
        B_generator.load_state_dict(checkpoint['B_generator_state_dict'])
        A_discriminator.load_state_dict(
            checkpoint['A_discriminator_state_dict'])
        B_discriminator.load_state_dict(
            checkpoint['B_discriminator_state_dict'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
        scheduler_G.load_state_dict(checkpoint['scheduler_G_state_dict'])
        scheduler_D.load_state_dict(checkpoint['scheduler_D_state_dict'])
        start_epoch = checkpoint['epoch']
        train_loss_G = checkpoint['train_loss_G']
        train_loss_D_A = checkpoint['train_loss_D_A']
        train_loss_D_B = checkpoint['train_loss_D_B']
        best_metric = checkpoint['best_metric']

    # Save the training information
    result_data = {
        'model': args.model,
        'dataset': args.dataset,
        'loss': args.loss,
        'target_epoch': args.epochs,
        'batch_size': args.batch_size,
    }

    # Make sure the output directories exist (exist_ok avoids a
    # check-then-create race with other processes).
    os.makedirs(os.path.dirname(RESULT_FILE_PATH), exist_ok=True)
    os.makedirs(os.path.dirname(CHECKPOINT_FILE_PATH), exist_ok=True)
    print('==> Train ready.')

    def build_checkpoint(finished_epochs):
        """Collect everything needed to resume after ``finished_epochs`` epochs."""
        return {
            'epoch': finished_epochs,
            'A_generator_state_dict': A_generator.state_dict(),
            'B_generator_state_dict': B_generator.state_dict(),
            'A_discriminator_state_dict': A_discriminator.state_dict(),
            'B_discriminator_state_dict': B_discriminator.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict(),
            'optimizer_D_state_dict': optimizer_D.state_dict(),
            'scheduler_G_state_dict': scheduler_G.state_dict(),
            'scheduler_D_state_dict': scheduler_D.state_dict(),
            'train_loss_G': train_loss_G,
            'train_loss_D_A': train_loss_D_A,
            'train_loss_D_B': train_loss_D_B,
            'best_metric': best_metric,
        }

    # Reference point for the "Current elapsed time" report; this name was
    # previously read without ever being assigned in this function.
    start_time = time.time()

    # Start right after the checkpoint epoch when resuming
    for epoch in range(start_epoch, args.epochs):
        print("\n[Epoch: {:3d}/{:3d}]".format(epoch + 1, args.epochs))
        epoch_time = time.time()
        #=======================================================================
        # train and validate the model
        tloss_G, tloss_D = train(
            train_dataloader, A_generator, B_generator, A_discriminator,
            B_discriminator, criterion_GAN, criterion_cycle,
            criterion_identity, optimizer_G, optimizer_D, A_buffer, B_buffer,
            args.loss, args.lambda_cycle, args.lambda_identity,
            criterion_feature, args.lambda_feature, args.attention)
        train_loss_G.append(tloss_G)
        train_loss_D_A.append(tloss_D['A'])
        train_loss_D_B.append(tloss_D['B'])

        if (epoch + 1) % 10 == 0:
            val(test_dataloader, A_generator, B_generator, A_discriminator,
                B_discriminator, epoch + 1, FILE_NAME_FORMAT, args.attention)

        # Read the LR used for this epoch before the schedulers advance it
        current_lr = optimizer_G.param_groups[0]['lr']
        scheduler_G.step()
        scheduler_D.step()
        #=======================================================================
        current = time.time()

        # Save the current result
        result_data['current_epoch'] = epoch
        result_data['train_loss_G'] = train_loss_G
        result_data['train_loss_D_A'] = train_loss_D_A
        result_data['train_loss_D_B'] = train_loss_D_B

        # Save result_data as pkl file
        with open(RESULT_FILE_PATH, 'wb') as pkl_file:
            pickle.dump(result_data,
                        pkl_file,
                        protocol=pickle.HIGHEST_PROTOCOL)

        # Save the current (rolling) checkpoint
        checkpoint_state = build_checkpoint(epoch + 1)
        torch.save(checkpoint_state, CHECKPOINT_FILE_PATH)

        if (epoch + 1) % 10 == 0:
            # Keep a separate snapshot per 10th epoch. The epoch number is now
            # substituted into the name; before, the literal '{0}' stayed in
            # it and every snapshot overwrote the previous one.
            CHECKPOINT_FILE_PATH_epoch = os.path.join(
                CHECKPOINT_PATH, FILE_NAME_FORMAT,
                FILE_NAME_FORMAT + '_{0}.ckpt'.format(epoch + 1))
            os.makedirs(os.path.dirname(CHECKPOINT_FILE_PATH_epoch), exist_ok=True)
            torch.save(checkpoint_state, CHECKPOINT_FILE_PATH_epoch)

        # Print the information on the console
        print("model                : {}".format(args.model))
        print("dataset              : {}".format(args.dataset))
        print("loss                 : {}".format(args.loss))
        print("batch_size           : {}".format(args.batch_size))
        print("current lrate        : {:f}".format(current_lr))
        print("G loss               : {:f}".format(tloss_G))
        print("D A/B loss           : {:f}/{:f}".format(
            tloss_D['A'], tloss_D['B']))
        print("epoch time           : {0:.3f} sec".format(current -
                                                          epoch_time))
        print("Current elapsed time : {0:.3f} sec".format(current - start_time))
    print('==> Train done.')

    print(' '.join(['Results have been saved at', RESULT_FILE_PATH]))
    print(' '.join(['Checkpoints have been saved at', CHECKPOINT_FILE_PATH]))
class Model:
    """Training/evaluation harness around one generator and one discriminator.

    The class owns both networks, their Adam optimizers and learning-rate
    schedulers, and knows how to train them adversarially, evaluate the
    generator, dump sample images and save/restore checkpoints under
    ``base_path``.
    """

    def __init__(self, base_path='', epochs=10, learning_rate=0.0002, image_size=256, leaky_relu=0.2,
                 betas=(0.5, 0.999), lamda=100, image_format='png'):
        """Store hyper-parameters, pick the device and create the output folders.

        Args:
            base_path: Root directory for checkpoints and generated images.
            epochs: Number of epochs run by ``train_model``.
            learning_rate: Adam learning rate used for both networks.
            image_size: Forwarded to the discriminator and the U-Net generators.
            leaky_relu: Forwarded to the discriminator as ``leaky_relu``.
            betas: Adam betas used for both optimizers.
            lamda: Weight applied to the L1 terms of the generator loss.
            image_format: File extension used for every saved image.
        """
        self.image_size = image_size
        self.leaky_relu_threshold = leaky_relu

        self.epochs = epochs
        self.lr = learning_rate
        self.betas = betas
        self.lamda = lamda
        self.base_path = base_path
        self.image_format = image_format
        # Running index used to number the files written by evaluate_model.
        self.count = 1

        # Everything below is filled in by initialize_model / load_checkpoint.
        self.gen = None
        self.dis = None
        self.gen_optim = None
        self.dis_optim = None
        self.model_type = None
        self.residual_blocks = 9
        self.layer_size = 64
        self.lr_policy = None
        self.lr_schedule_gen = None
        self.lr_schedule_dis = None

        self.device = self.get_device()
        self.create_folder_structure()

    def create_folder_structure(self):
        """Create the four output folders under ``base_path`` if they are missing."""
        for folder in ('checkpoints', 'Loss_Checkpoints', 'Training Images', 'Test Images'):
            # Plain concatenation on purpose: every other path in this class is
            # built as '<base_path>/<folder>/...', so the two must agree.
            os.makedirs(self.base_path + '/' + folder, exist_ok=True)

    def get_device(self):
        """Print device information and return the CUDA device, if any.

        Returns:
            ``torch.device('cuda')`` when CUDA is available, otherwise ``None``.
            The rest of the class tests ``self.device is not None`` to decide
            whether tensors and networks must be moved with ``.cuda()``.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print('Using device:', device)

        if device.type != 'cuda':
            # Querying torch.cuda on a CPU-only machine raises, so stop here.
            return None

        print(torch.cuda.get_device_name(0))
        print('Memory Usage -')
        print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1), 'GB')
        # memory_cached was renamed memory_reserved; fall back on old PyTorch.
        reserved = getattr(torch.cuda, 'memory_reserved', None) or torch.cuda.memory_cached
        print('Cached:   ', round(reserved(0) / 1024 ** 3, 1), 'GB')
        return device

    def initialize_model(self, lr_schedular_options, model_type='unet', residual_blocks=9, layer_size=64):
        """Build the networks, their optimizers and their LR schedulers.

        Args:
            lr_schedular_options: Dict consumed by ``get_learning_schedule``
                (must at least contain ``'lr_policy'``).
            model_type: One of 'unet', 'resnet', 'inception', 'unet2',
                'unet_large', 'unet_fusion'.
            residual_blocks: Number of residual blocks (only used by 'resnet').
            layer_size: Passed to the generator as ``ngf``.

        Raises:
            Exception: If ``model_type`` or the LR policy is unknown.
        """
        # Lazy constructors so only the requested generator is instantiated.
        generator_builders = {
            'unet': lambda: Generator_Unet(image_size=self.image_size, ngf=layer_size),
            'resnet': lambda: Generator_RESNET(residual_blocks=residual_blocks, ngf=layer_size),
            'inception': lambda: Generator_InceptionNet(ngf=layer_size),
            'unet2': lambda: Generator_Unet_2(image_size=self.image_size, ngf=layer_size),
            'unet_large': lambda: Generator_Unet_Large(image_size=self.image_size, ngf=layer_size),
            'unet_fusion': lambda: Generator_Unet_Fusion(image_size=self.image_size, ngf=layer_size),
        }
        if model_type not in generator_builders:
            raise Exception('This model type is not available!')

        self.dis = Discriminator(image_size=self.image_size, leaky_relu=self.leaky_relu_threshold)
        self.gen = generator_builders[model_type]()

        if self.device is not None:
            self.gen.cuda()
            self.dis.cuda()

        self.gen_optim = optim.Adam(self.gen.parameters(), lr=self.lr, betas=self.betas)
        self.dis_optim = optim.Adam(self.dis.parameters(), lr=self.lr, betas=self.betas)

        # Each scheduler must drive the optimizer of its own network.
        self.lr_schedule_gen = self.get_learning_schedule(self.gen_optim, lr_schedular_options)
        self.lr_schedule_dis = self.get_learning_schedule(self.dis_optim, lr_schedular_options)

        self.model_type = model_type
        self.layer_size = layer_size
        self.residual_blocks = residual_blocks
        self.lr_policy = lr_schedular_options
        print('Model Initialized !\nGenerator Model Type : {} and Layer Size : {}'.format(model_type, layer_size))
        print('Model Parameters are:\nEpochs : {}\nLearning rate : {}\nLeaky Relu Threshold : {}\nLamda : {}\nBeta : {}'
              .format(self.epochs, self.lr, self.leaky_relu_threshold, self.lamda, self.betas))

    def _step_schedulers(self, gen_metric, dis_metric):
        """Advance both LR schedulers by one epoch.

        ReduceLROnPlateau requires the monitored value, every other scheduler
        is stepped without arguments.
        """
        for scheduler, metric in ((self.lr_schedule_gen, gen_metric), (self.lr_schedule_dis, dis_metric)):
            if isinstance(scheduler, lr_schedular.ReduceLROnPlateau):
                scheduler.step(metric)
            else:
                scheduler.step()

    def _rebuild_optimizers(self):
        """Recreate both Adam optimizers (and their schedulers) from the current lr/betas."""
        if self.gen is None or self.dis is None:
            # Nothing to optimize yet; initialize_model will pick up self.lr / self.betas.
            return
        self.gen_optim = optim.Adam(self.gen.parameters(), lr=self.lr, betas=self.betas)
        self.dis_optim = optim.Adam(self.dis.parameters(), lr=self.lr, betas=self.betas)
        if self.lr_policy is not None:
            # The old schedulers reference the discarded optimizers; rebind them.
            self.lr_schedule_gen = self.get_learning_schedule(self.gen_optim, self.lr_policy)
            self.lr_schedule_dis = self.get_learning_schedule(self.dis_optim, self.lr_policy)

    def train_model(self, trainloader, average_loss, eval=(False, None, None), save_model=(False, 25),
                    display_test_image=(False, None, 25)):
        """Run adversarial training for ``self.epochs`` epochs.

        Args:
            trainloader: Iterable of ``(gray_img, real_img)`` batches with a ``len``.
            average_loss: Loss recorder exposing ``add_loss(tuple)`` and
                ``save(name, save_index=...)``.
            eval: ``(enabled, test_loader, every_n_epochs)``; when enabled the L1
                loss over the test and train loaders is printed periodically.
            save_model: ``(enabled, every_n_epochs)`` for intermediate checkpoints.
            display_test_image: ``(enabled, loader, every_n_epochs)``; when enabled
                the first sample of the loader's first batch is rendered periodically.
        """
        print('We will be using L1 loss with perpetual loss (L1)!')
        mean_loss = nn.BCELoss()
        l1_loss = nn.L1Loss()
        vgg16 = models.vgg16()
        # NOTE(review): torchvision's VGG appears to expose at most three top-level
        # children (features, avgpool, classifier), so [:-3] seems to leave an empty
        # Sequential (an identity) and vgg16() is created without pretrained weights.
        # If so, the "perceptual" term below is just a second L1 term. Left untouched
        # to preserve training behaviour - confirm the intent.
        vgg16_conv = nn.Sequential(*list(vgg16.children())[:-3])

        self.gen.train()
        self.dis.train()

        batches = len(trainloader)
        print('Total number of batches in an epoch are : {}'.format(batches))

        sample_img_test = None
        if display_test_image[0]:
            sample_img_test, rgb_test_images = next(iter(display_test_image[1]))
            # Images are mapped from [-1, 1] back to [0, 1] before saving.
            save_image((rgb_test_images[0].detach().cpu() + 1) / 2,
                       '{}/Training Images/real_img.{}'.format(self.base_path, self.image_format))
            if self.device is not None:
                sample_img_test = sample_img_test.cuda()

        for i in range(self.epochs):

            if eval[0] and (i % eval[2] == 0):
                self.evaluate_L1_loss_dataset(eval[1], train=False)
                self.evaluate_L1_loss_dataset(trainloader, train=True)
                # The evaluation helper leaves the generator in eval mode.
                self.gen.train()

            running_gen_loss = 0
            running_dis_loss = 0

            for gray_img, real_img in trainloader:

                batch_size = len(gray_img)
                zero_label = torch.zeros(batch_size)
                one_label = torch.ones(batch_size)

                if self.device is not None:
                    gray_img = gray_img.cuda()
                    real_img = real_img.cuda()
                    zero_label = zero_label.cuda()
                    one_label = one_label.cuda()

                # Discriminator loss
                self.dis_optim.zero_grad()
                fake_img = self.gen(gray_img)

                dis_real_loss = mean_loss(self.dis(real_img), one_label)
                # detach(): the discriminator update needs no generator gradients,
                # so avoid back-propagating through the generator here.
                dis_fake_loss = mean_loss(self.dis(fake_img.detach()), zero_label)

                total_dis_loss = dis_fake_loss + dis_real_loss
                total_dis_loss.backward()
                self.dis_optim.step()

                # Generator loss (fresh forward pass against the updated discriminator)
                self.gen_optim.zero_grad()

                fake_img = self.gen(gray_img)
                gen_adv_loss = mean_loss(self.dis(fake_img), one_label)
                gen_l1_loss = l1_loss(fake_img.view(batch_size, -1), real_img.view(batch_size, -1))
                gen_pre_train = l1_loss(vgg16_conv(fake_img), vgg16_conv(real_img))
                total_gen_loss = gen_adv_loss + self.lamda * gen_l1_loss + self.lamda * gen_pre_train
                total_gen_loss.backward()
                self.gen_optim.step()

                running_dis_loss += total_dis_loss.item()
                running_gen_loss += total_gen_loss.item()

            # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
            running_dis_loss /= (max(batches, 1) * 1.0)
            running_gen_loss /= (max(batches, 1) * 1.0)
            print('Epoch : {}, Generator Loss : {} and Discriminator Loss : {}'.format(i + 1, running_gen_loss,
                                                                                       running_dis_loss))
            if display_test_image[0] and i % display_test_image[2] == 0:
                self.gen.eval()
                with torch.no_grad():
                    out_result = self.gen(sample_img_test)
                out_result = (out_result[0].detach().cpu() + 1) / 2
                save_image(out_result, '{}/Training Images/epoch_{}.{}'.format(self.base_path, i,
                                                                               self.image_format))
                self.gen.train()

            save_tuple = ([running_gen_loss], [running_dis_loss])
            average_loss.add_loss(save_tuple)

            if save_model[0] and i % save_model[1] == 0:
                self.save_checkpoint('checkpoint_epoch_{}'.format(i), self.model_type)
                average_loss.save('checkpoint_avg_loss', save_index=0)

            self._step_schedulers(running_gen_loss, running_dis_loss)
            for param_grp in self.dis_optim.param_groups:
                print('Learning rate after {} epochs is : {}'.format(i + 1, param_grp['lr']))

        self.save_checkpoint('checkpoint_train_final', self.model_type)
        average_loss.save('checkpoint_avg_loss_final', save_index=0)

    def get_learning_schedule(self, optimizer, option):
        """Create the LR scheduler described by ``option`` for ``optimizer``.

        Args:
            optimizer: Optimizer the scheduler will drive.
            option: Dict with ``'lr_policy'`` in {'linear', 'plateau', 'step',
                'cosine'} plus the keys that policy reads (``'n_epochs'`` and
                ``'n_epoch_decay'`` for linear, ``'step_size'`` for step,
                ``'n_epochs'`` for cosine).

        Raises:
            Exception: If the policy is not one of the four above.
        """
        policy = option['lr_policy']

        if policy == 'linear':
            def lambda_rule(epoch):
                # Factor 1.0 for the first n_epochs, then a linear decay.
                return 1.0 - max(0, epoch - option['n_epochs']) / float(option['n_epoch_decay'] + 1)

            return lr_schedular.LambdaLR(optimizer, lr_lambda=lambda_rule)
        if policy == 'plateau':
            return lr_schedular.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
        if policy == 'step':
            return lr_schedular.StepLR(optimizer, step_size=option['step_size'], gamma=0.1)
        if policy == 'cosine':
            return lr_schedular.CosineAnnealingLR(optimizer, T_max=option['n_epochs'], eta_min=0)

        raise Exception('LR Policy not implemented!')

    def evaluate_model(self, loader, save_filename, no_of_images=1):
        """Save generated / real / input images for ``no_of_images`` batches.

        Only the first sample of each batch is written (the code was written
        with a test batch size of 1 in mind). Files go to
        ``<base_path>/Test Images`` and are numbered with ``self.count``.

        Raises:
            Exception: If the model has not been initialized.
        """
        if self.gen is None or self.dis is None:
            raise Exception('Model has not been initialized and hence cannot be evaluated!')

        self.gen.eval()
        # One iterator for the whole call so consecutive images come from
        # consecutive batches instead of re-reading the first batch every time.
        batch_iter = iter(loader)
        counter_images_generated = 0
        while counter_images_generated < no_of_images:
            try:
                gray, rgb = next(batch_iter)
            except StopIteration:
                # Loader exhausted before enough images were produced: wrap around.
                batch_iter = iter(loader)
                gray, rgb = next(batch_iter)

            if self.device is not None:
                gray = gray.cuda()

            filename = '{}/Test Images/{}_{}.{}'.format(self.base_path, save_filename, self.count, self.image_format)
            real_filename = '{}/Test Images/{}_{}_real.{}'.format(self.base_path, save_filename, self.count,
                                                                  self.image_format)
            real_gray_filename = '{}/Test Images/{}_{}_real_gray.{}'.format(self.base_path, save_filename, self.count,
                                                                            self.image_format)
            self.count += 1

            with torch.no_grad():
                out = self.gen(gray)
            out = (out[0].detach().cpu() + 1) / 2
            save_image(out, filename)

            # The input is written as-is, without the [-1, 1] -> [0, 1] mapping.
            save_image(gray[0].detach().cpu(), real_gray_filename)

            save_image((rgb[0].detach().cpu() + 1) / 2, real_filename)

            counter_images_generated += 1

    def evaluate_L1_loss_dataset(self, loader, train=False):
        """Return the mean per-batch L1 loss of the generator over ``loader``.

        The generator is switched to eval mode and left there. ``train`` only
        selects the label ('train' or 'test') used in the printed message.
        An empty loader yields 0.0.

        Raises:
            Exception: If the model has not been initialized.
        """
        if self.gen is None or self.dis is None:
            raise Exception('Model has not been initialized and hence cannot be evaluated!')

        loss_function = nn.L1Loss()
        self.gen.eval()
        total_loss = 0.0
        iterations = 0
        # Pure evaluation: no autograd graph is needed.
        with torch.no_grad():
            for gray, real in loader:
                iterations += 1
                if self.device is not None:
                    gray = gray.cuda()
                    real = real.cuda()

                total_loss += loss_function(self.gen(gray), real).item()

        if iterations:
            total_loss = total_loss / (iterations * 1.0)
        train_test = 'train' if train else 'test'
        print('Total L1 loss over {} set is : {}'.format(train_test, total_loss))
        return total_loss

    def change_params(self, epochs=None, learning_rate=None, leaky_relu=None, betas=None, lamda=None):
        """Update any subset of the hyper-parameters; ``None`` leaves a value unchanged.

        Changing the learning rate or the betas rebuilds both optimizers (and
        their schedulers) when the networks already exist.
        NOTE: changing ``leaky_relu`` only updates the stored value; an already
        built discriminator is not recreated.
        """
        if epochs is not None:
            self.epochs = epochs
            print('Changed the number of epochs to {}!'.format(self.epochs))
        if learning_rate is not None:
            self.lr = learning_rate
            print('Changed the learning rate to {}!'.format(self.lr))
        if leaky_relu is not None:
            self.leaky_relu_threshold = leaky_relu
            print('Changed the threshold for leaky relu to {}!'.format(self.leaky_relu_threshold))
        if betas is not None:
            self.betas = betas
            print('Changed the betas for Adams Optimizer!')
        if betas is not None or learning_rate is not None:
            self._rebuild_optimizers()

        if lamda is not None:
            self.lamda = lamda
            print('Lamda value has been changed to {}!'.format(self.lamda))

    def set_all_params(self, epochs, lr, leaky_thresh, lamda, beta):
        """Overwrite every hyper-parameter and rebuild the optimizers if the networks exist."""
        self.epochs = epochs
        self.lr = lr
        self.leaky_relu_threshold = leaky_thresh
        self.lamda = lamda
        self.betas = beta
        self._rebuild_optimizers()

        print('Model Parameters are:\nEpochs : {}\nLearning rate : {}\nLeaky Relu Threshold : {}\nLamda : {}\nBeta : {}'
              .format(self.epochs, self.lr, self.leaky_relu_threshold, self.lamda, self.betas))

    def run_model_on_dataset(self, loader, save_folder, save_path=None):
        """Run the generator over ``loader`` and save one image per batch.

        Files are written as ``<save_path>/<save_folder>/<index>.<image_format>``
        with ``index`` starting at 1; ``save_path`` defaults to ``base_path``.
        Only the first sample of each batch is saved.

        Raises:
            Exception: If the model has not been initialized.
        """
        if self.gen is None or self.dis is None:
            raise Exception('Model has not been initialized and hence cannot be run!')
        if save_path is None:
            save_path = self.base_path

        # The target folder is caller-chosen, so it may not exist yet.
        os.makedirs('{}/{}'.format(save_path, save_folder), exist_ok=True)

        self.gen.eval()
        with torch.no_grad():
            for index, (gray, dummy) in enumerate(loader, start=1):
                if self.device is not None:
                    gray = gray.cuda()

                filename = '{}/{}/{}.{}'.format(save_path, save_folder, index, self.image_format)

                out = self.gen(gray)
                out = (out[0].detach().cpu() + 1) / 2
                save_image(out, filename)

    def save_checkpoint(self, filename, model_type='unet'):
        """Write network weights and hyper-parameters to ``<base_path>/checkpoints/<filename>.pth``.

        Optimizer and scheduler state are not part of the checkpoint.

        Raises:
            Exception: If the model has not been initialized.
        """
        if self.gen is None or self.dis is None:
            raise Exception('The model has not been initialized and hence cannot be saved !')

        filename = '{}/checkpoints/{}.pth'.format(self.base_path, filename)
        # base_path may have been replaced by load_checkpoint; make sure it exists.
        os.makedirs(os.path.dirname(filename), exist_ok=True)

        save_dict = {
            'model_type': model_type,
            'dis_dict': self.dis.state_dict(),
            'gen_dict': self.gen.state_dict(),
            'lr': self.lr,
            'epochs': self.epochs,
            'betas': self.betas,
            'image_size': self.image_size,
            'leaky_relu_thresh': self.leaky_relu_threshold,
            'lamda': self.lamda,
            'base_path': self.base_path,
            'count': self.count,
            'image_format': self.image_format,
            'device': self.device,
            'residual_blocks': self.residual_blocks,
            'layer_size': self.layer_size,
            'lr_policy': self.lr_policy,
        }

        torch.save(save_dict, filename)

        print('The model checkpoint has been saved !')

    def load_checkpoint(self, filename):
        """Restore hyper-parameters and weights from ``<base_path>/checkpoints/<filename>.pth``.

        All stored hyper-parameters (including ``base_path``) overwrite the
        current ones, then the networks are rebuilt and their weights loaded.
        A checkpoint written on CPU keeps running on CPU even when CUDA is
        available; a checkpoint written on GPU cannot be loaded without CUDA.

        Raises:
            Exception: If the file does not exist, or a GPU checkpoint is
                loaded on a machine without CUDA.
        """
        filename = '{}/checkpoints/{}.pth'.format(self.base_path, filename)
        if not pathlib.Path(filename).exists():
            raise Exception('This checkpoint does not exist!')

        self.gen = None
        self.dis = None

        # SECURITY: torch.load unpickles the file - only load checkpoints you trust.
        # map_location='cpu' lets a GPU checkpoint be read on a CPU-only machine so
        # the explicit device check below is reached; load_state_dict later copies
        # the weights onto whatever device initialize_model placed the networks on.
        save_dict = torch.load(filename, map_location='cpu')

        self.betas = save_dict['betas']
        self.image_size = save_dict['image_size']
        self.epochs = save_dict['epochs']
        self.leaky_relu_threshold = save_dict['leaky_relu_thresh']
        self.lamda = save_dict['lamda']
        self.lr = save_dict['lr']
        self.base_path = save_dict['base_path']
        self.count = save_dict['count']
        self.image_format = save_dict['image_format']
        self.device = save_dict['device']
        self.residual_blocks = save_dict['residual_blocks']
        self.layer_size = save_dict['layer_size']
        self.lr_policy = save_dict['lr_policy']

        current_device = self.get_device()
        if current_device != self.device:
            if self.device is None:
                # Saved on CPU, CUDA available now: keep going on CPU, but say so.
                print('The model was trained on CPU and will therefore be continued on CPU only!')
            else:
                raise Exception('The model was trained on GPU and cannot be loaded on a CPU machine!')

        self.initialize_model(model_type=save_dict['model_type'], residual_blocks=self.residual_blocks,
                              layer_size=self.layer_size, lr_schedular_options=self.lr_policy)

        self.gen.load_state_dict(save_dict['gen_dict'])
        self.dis.load_state_dict(save_dict['dis_dict'])

        print('The model checkpoint has been restored!')
Пример #4
0
def train(FLAGS):
    """Train a generator/discriminator pair on an image-folder dataset.

    Args:
        FLAGS: Namespace providing ``p_every`` (print every N batches),
            ``s_every`` (checkpoint every N epochs), ``epochs``, ``dlr``/``glr``
            (discriminator/generator learning rates), ``beta1``/``beta2``,
            ``zsize`` (latent size), ``batch_size``, ``resize_height``,
            ``resize_width``, ``dataset_path`` and ``dataset_type``.

    Returns:
        List with one ``[summed_d_loss, summed_g_loss]`` entry per epoch.

    Raises:
        ValueError: If ``dataset_path`` is None and ``dataset_type`` is not one
            of 'cars', 'flowers' or 'dogs'.

    Side effects: may run ``./datasets/dload.sh``, writes ``d-nm-<epoch>.pth`` /
    ``g-nm-<epoch>.pth`` checkpoints and a ``fake_<epoch>.png`` sample per epoch
    into the current working directory.
    """
    # Define the hyperparameters
    p_every = FLAGS.p_every
    s_every = FLAGS.s_every
    epochs = FLAGS.epochs
    dlr = FLAGS.dlr
    glr = FLAGS.glr
    beta1 = FLAGS.beta1
    beta2 = FLAGS.beta2
    z_size = FLAGS.zsize
    batch_size = FLAGS.batch_size
    rh = FLAGS.resize_height
    rw = FLAGS.resize_width
    d_path = FLAGS.dataset_path
    d_type = FLAGS.dataset_type

    # Preprocessing Data: resize, then scale every channel to [-1, 1]
    transform = transforms.Compose([
        transforms.Resize((rh, rw)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    if d_path is None:
        # dataset type -> (path whose existence is checked, dataset root to use)
        known_datasets = {
            'cars': ('./datasets/cars_train', './datasets/cars_train/'),
            'flowers': ('./datasets/flowers/', './datasets/flowers/'),
            'dogs': ('./datasets/jpg', './datasets/jpg/'),
        }
        if d_type not in known_datasets:
            # Without this the failure would surface later as an obscure error
            # from ImageFolder(None, ...).
            raise ValueError(
                'dataset_path is not set and dataset_type {!r} is unknown'.format(d_type))
        check_path, d_path = known_datasets[d_type]
        if not os.path.exists(check_path):
            os.system('sh ./datasets/dload.sh {}'.format(d_type))

    train_data = datasets.ImageFolder(d_path, transform=transform)
    trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

    # Define the D and G
    dis = Discriminator(64)
    gen = Generator()

    # Apply weight initialization
    dis.apply(init_weight)
    gen.apply(init_weight)

    # Define the loss function
    criterion = nn.BCELoss()

    # Optimizers
    d_opt = optim.Adam(dis.parameters(), lr=dlr, betas=(beta1, beta2))
    g_opt = optim.Adam(gen.parameters(), lr=glr, betas=(beta1, beta2))

    # Train loop
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    train_losses = []

    dis.to(device)
    gen.to(device)

    # Floats: BCELoss needs floating-point targets.
    real_label = 1.0
    fake_label = 0.0

    for e in range(epochs):

        td_loss = 0.0
        tg_loss = 0.0
        fake_images = None

        for batch_i, (real_images, _) in enumerate(trainloader):

            real_images = real_images.to(device)

            # The last batch may be smaller than the configured batch size.
            n_real = real_images.size(0)

            #### Train the Discriminator ####

            d_opt.zero_grad()

            d_real = dis(real_images)

            # Explicit dtype: an integer fill value would otherwise produce an
            # integer tensor on recent PyTorch, which BCELoss rejects.
            label = torch.full((n_real, ), real_label, dtype=torch.float, device=device)
            r_loss = criterion(d_real, label)
            r_loss.backward()

            z = torch.randn(n_real, z_size, 1, 1, device=device)

            fake_images = gen(z)

            label.fill_(fake_label)

            # detach(): the discriminator step must not push gradients into G.
            d_fake = dis(fake_images.detach())

            f_loss = criterion(d_fake, label)
            f_loss.backward()

            d_loss = r_loss + f_loss

            d_opt.step()

            #### Train the Generator ####
            g_opt.zero_grad()

            # G is rewarded when D labels its samples as real.
            label.fill_(real_label)
            d_fake2 = dis(fake_images)

            g_loss = criterion(d_fake2, label)
            g_loss.backward()

            g_opt.step()

            d_loss_value = d_loss.item()
            g_loss_value = g_loss.item()
            td_loss += d_loss_value
            tg_loss += g_loss_value

            if batch_i % p_every == 0:
                print('Epoch [{:5d} / {:5d}] | d_loss: {:6.4f} | g_loss: {:6.4f}'.format(
                    e + 1, epochs, d_loss_value, g_loss_value))

        train_losses.append([td_loss, tg_loss])

        if e % s_every == 0:
            d_ckpt = {
                'model_state_dict': dis.state_dict(),
                'opt_state_dict': d_opt.state_dict()
            }

            g_ckpt = {
                'model_state_dict': gen.state_dict(),
                'opt_state_dict': g_opt.state_dict()
            }

            torch.save(d_ckpt, 'd-nm-{}.pth'.format(e))
            torch.save(g_ckpt, 'g-nm-{}.pth'.format(e))

        # Nothing to save when the loader produced no batch this epoch.
        if fake_images is not None:
            utils.save_image(fake_images.detach(),
                             'fake_{}.png'.format(e),
                             normalize=True)

    print('[INFO] Training Completed successfully!')
    return train_losses
Пример #5
0
class Pix2PixMain(object):
    """End-to-end pix2pix training driver configured entirely through ``Settings``.

    Construction seeds the RNGs, builds the generator/discriminator, the data
    loaders, the optimizers, the loss functions and the output directories.
    Calling the instance runs the training loop.
    """

    def __init__(self):
        """Build models, data loaders, optimizers, criteria and output folders."""

        # -----------------------------------
        # global
        # -----------------------------------
        np.random.seed(Settings.SEED)
        torch.manual_seed(Settings.SEED)
        random.seed(Settings.SEED)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(Settings.SEED)
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        # -----------------------------------
        # model
        # -----------------------------------
        self.generator = Generator(in_c=Settings.IN_CHANNEL,
                                   out_c=Settings.OUT_CHANNEL,
                                   ngf=Settings.NGF).to(self.device)
        self.generator.apply(self.generator.weights_init)
        self.discriminator = Discriminator(
            in_c=Settings.IN_CHANNEL,
            out_c=Settings.OUT_CHANNEL,
            ndf=Settings.NDF,
            n_layers=Settings.DISCRIMINATOR_LAYER).to(self.device)
        self.discriminator.apply(self.discriminator.weights_init)
        print("model init done")

        # -----------------------------------
        # data
        # -----------------------------------
        train_transforms = transforms.Compose([
            transforms.Resize((Settings.INPUT_SIZE, Settings.INPUT_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

        data_prepare = get_dataloader(
            dataset_name=Settings.DATASET,
            batch_size=Settings.BATCH_SIZE,
            data_root=Settings.DATASET_ROOT,
            train_num_workers=Settings.TRAIN_NUM_WORKERS,
            transforms=train_transforms,
            val_num_workers=Settings.TEST_NUM_WORKERS)
        self.train_dataloader = data_prepare.train_dataloader
        self.test_dataloader = data_prepare.test_dataloader
        print("data init done.....")

        # -----------------------------------
        # optimizer and criterion
        # -----------------------------------
        self.optimG = optim.Adam([{"params": self.generator.parameters()}],
                                 lr=Settings.G_LR,
                                 betas=Settings.G_BETAS)
        self.optimD = optim.Adam([{"params": self.discriminator.parameters()}],
                                 lr=Settings.D_LR,
                                 betas=Settings.D_BETAS)

        self.criterion_l1loss = nn.L1Loss()
        self.criterion_BCE = nn.BCELoss()
        print("optimizer and criterion init done.....")

        # -----------------------------------
        # recorder
        # -----------------------------------
        # One float is appended to every list per training step, so all lists
        # always have the same length.
        self.recorder = {
            "errD_fake": list(),
            "errD_real": list(),
            "errG_l1loss": list(),
            "errG_bce": list(),
            "errG": list(),
            "accD": list()
        }

        # Each run gets its own time-stamped output directory.
        output_file = time.strftime(
            "{}_{}_%Y_%m_%d_%H_%M_%S".format("pix2pix", Settings.DATASET),
            time.localtime())
        self.output_root = os.path.join(Settings.OUTPUT_ROOT, output_file)
        for sub_dir in (Settings.OUTPUT_MODEL_KEY, Settings.OUTPUT_LOG_KEY,
                        Settings.OUTPUT_IMAGE_KEY):
            os.makedirs(os.path.join(self.output_root, sub_dir), exist_ok=True)
        print("recorder init done.....")

    def __call__(self):
        """Run the full training loop for ``Settings.EPOCHS`` epochs.

        Printing and validation frequencies are fractions of the number of
        batches per epoch; validation is only attempted on epochs that are a
        multiple of ``EPOCHS * BATCH_FREQUENT``. The log is rewritten after
        every epoch.
        NOTE(review): ``learning_rate_decay_module`` is never invoked here.
        """
        print_steps = max(
            1, int(len(self.train_dataloader) * Settings.PRINT_FREQUENT))
        eval_steps = max(
            1, int(len(self.train_dataloader) * Settings.EVAL_FREQUENT))
        batch_steps = max(1, int(Settings.EPOCHS * Settings.BATCH_FREQUENT))

        print("begin train.....")
        for epoch in range(1, Settings.EPOCHS + 1):
            for step, batch in enumerate(self.train_dataloader):

                # train
                self.train_module(batch)

                # print
                self.print_module(epoch, step, print_steps)

                if epoch % batch_steps == 0:
                    # val
                    self.val_module(epoch, step, eval_steps)

            # save log
            self.log_save_module()

    def train_module(self, batch):
        """Run one discriminator update and one generator update on ``batch``.

        Args:
            batch: Mapping with ``"edge_images"`` (input) and ``"color_images"``
                (target) tensors.

        Raises:
            KeyError: If ``Settings.DATASET`` is not a supported dataset.
        """
        self.generator.train()
        self.discriminator.train()

        # Both supported datasets share the same batch layout.
        if Settings.DATASET not in ("edge2shoes", "Mogaoku"):
            raise KeyError("DataSet {} doesn't exist".format(Settings.DATASET))
        input_images = batch["edge_images"].to(self.device)
        target_images = batch["color_images"].to(self.device)

        # Discriminator iteration
        self.optimD.zero_grad()
        true_image_d_pred = self.discriminator(input_images, target_images)
        true_images_label = torch.full(true_image_d_pred.shape,
                                       Settings.REAL_LABEL,
                                       dtype=torch.float32,
                                       device=self.device)
        errD_real_bce = self.criterion_BCE(true_image_d_pred,
                                           true_images_label)
        errD_real_bce.backward()

        fake_images = self.generator(input_images)
        # detach(): the discriminator update must not reach the generator.
        fake_images_d_pred = self.discriminator(input_images,
                                                fake_images.detach())
        fake_images_label = torch.full(fake_images_d_pred.shape,
                                       Settings.FAKE_LABEL,
                                       dtype=torch.float32,
                                       device=self.device)
        errD_fake_bce = self.criterion_BCE(fake_images_d_pred,
                                           fake_images_label)
        errD_fake_bce.backward()
        self.optimD.step()

        # Discriminator accuracy: a prediction counts as correct when it falls
        # on the same side of 0.5 as its label (identical to an exact equality
        # test for 0/1 labels, and still meaningful for smoothed labels).
        real_image_pred_true_num = ((true_image_d_pred > 0.5) ==
                                    (true_images_label > 0.5)).sum().float()
        fake_image_pred_true_num = ((fake_images_d_pred > 0.5) ==
                                    (fake_images_label > 0.5)).sum().float()

        accD = (real_image_pred_true_num + fake_image_pred_true_num) / \
               (true_images_label.numel() + fake_images_label.numel())

        # Generator iteration
        self.optimG.zero_grad()
        fake_images_d_pred = self.discriminator(input_images, fake_images)
        true_images_label = torch.full(fake_images_d_pred.shape,
                                       Settings.REAL_LABEL,
                                       dtype=torch.float32,
                                       device=self.device)
        errG_bce = self.criterion_BCE(fake_images_d_pred, true_images_label)
        errG_l1loss = self.criterion_l1loss(fake_images, target_images)

        errG = errG_bce + errG_l1loss * Settings.L1_LOSS_LAMUDA
        errG.backward()
        self.optimG.step()

        # recorder (plain floats so the values can be logged and plotted)
        self.recorder["errD_real"].append(errD_real_bce.item())
        self.recorder["errD_fake"].append(errD_fake_bce.item())
        self.recorder["errG_l1loss"].append(errG_l1loss.item())
        self.recorder["errG_bce"].append(errG_bce.item())
        self.recorder["errG"].append(errG.item())
        self.recorder["accD"].append(accD.item())

    def val_module(self, epoch, step, eval_steps):
        """Every ``eval_steps`` steps: render test predictions, save models and logs.

        Input, target and prediction are concatenated along the width axis and
        written out as image grids holding ``len(test_dataloader) / 4`` rows of
        samples each; leftover samples that do not fill a grid are not saved.
        """
        def apply_dropout(m):
            # Re-enable dropout layers after the global eval() switch.
            if isinstance(m, nn.Dropout):
                m.train()

        if (step + 1) % eval_steps != 0:
            return

        self.generator.eval()
        self.discriminator.eval()

        # Enable dropout during evaluation if requested
        if Settings.USING_DROPOUT_DURING_EVAL:
            self.generator.apply(apply_dropout)
            self.discriminator.apply(apply_dropout)

        images_per_grid = int(len(self.test_dataloader) / 4)
        output_images = None
        output_count = 0

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for eval_batch in self.test_dataloader:

                input_images = eval_batch["edge_images"].to(self.device)
                target_images = eval_batch["color_images"]

                pred_images = self.generator(input_images).detach().cpu()

                output_image = torch.cat(
                    [input_images.cpu(), target_images, pred_images], dim=3)

                if output_images is None:
                    output_images = output_image
                else:
                    output_images = torch.cat([output_images, output_image],
                                              dim=0)

                if output_images.shape[0] == images_per_grid:
                    grid = make_grid(
                        output_images,
                        padding=2,
                        normalize=True,
                        nrow=Settings.CONSTANT_FEATURE_DIS_LEN).numpy()
                    # CHW float in [0, 1] -> HWC uint8 for PIL
                    grid = np.array(np.transpose(grid, (1, 2, 0)) * 255,
                                    dtype=np.uint8)
                    Image.fromarray(grid).save(
                        os.path.join(
                            self.output_root, Settings.OUTPUT_IMAGE_KEY,
                            "epoch_{}_step_{}_count_{}.jpg".format(
                                epoch, step, output_count)))

                    output_count += 1
                    output_images = None

        self.model_save_module(epoch, step)
        self.log_save_module()

    def print_module(self, epoch, step, print_steps):
        """Every ``print_steps`` steps, print the latest value of every recorded metric."""
        if (step + 1) % print_steps == 0:
            print("[{}/{}]\t [{}/{}]\t ".format(epoch, Settings.EPOCHS,
                                                step + 1,
                                                len(self.train_dataloader)),
                  end=" ")

            for key in self.recorder:
                print("[{}:{}]\t".format(key, self.recorder[key][-1]), end=" ")

            print(" ")

    def model_save_module(self, epoch, step):
        """Save generator and discriminator weights, tagged with epoch and step."""
        model_dir = os.path.join(self.output_root, Settings.OUTPUT_MODEL_KEY)
        torch.save(
            self.generator.state_dict(),
            os.path.join(
                model_dir,
                "pix2pix_generator_epoch_{}_step_{}.pth".format(epoch, step)))
        torch.save(
            self.discriminator.state_dict(),
            os.path.join(
                model_dir,
                "pix2pix_discriminator_epoch_{}_step_{}.pth".format(
                    epoch, step)))

    def log_save_module(self):
        """Rewrite ``log.txt`` with every recorded step and re-plot one chart per metric."""
        log_dir = os.path.join(self.output_root, Settings.OUTPUT_LOG_KEY)

        # Save the records: one line per training step, "key:value" tab-separated
        with open(os.path.join(log_dir, "log.txt"), "w") as f:
            for item_ in range(len(self.recorder["accD"])):
                for key in self.recorder:
                    f.write("{}:{}\t".format(key, self.recorder[key][item_]))
                f.write("\n")

        # Save the charts
        for key in self.recorder:
            plt.figure(figsize=(10, 5))
            plt.title("{} During Training".format(key))
            plt.plot(self.recorder[key], label=key)
            plt.xlabel("iterations")
            plt.ylabel("value")
            plt.legend()
            if "acc" in key:
                plt.yticks(np.arange(0, 1, 0.5))
            plt.savefig(os.path.join(log_dir, "{}.jpg".format(key)))

        plt.close("all")

    def learning_rate_decay_module(self, epoch):
        """Multiply both learning rates by 0.2 on every ``LR_DECAY_EPOCHS``-th epoch."""
        if epoch % Settings.LR_DECAY_EPOCHS == 0:
            for param_group in self.optimD.param_groups:
                param_group["lr"] *= 0.2
            for param_group in self.optimG.param_groups:
                param_group["lr"] *= 0.2