示例#1
0
def train(opt):
    #### device
    device = torch.device('cuda:{}'.format(opt.gpu_id)
                          if opt.gpu_id >= 0 else torch.device('cpu'))

    #### dataset
    data_loader = UnAlignedDataLoader()
    data_loader.initialize(opt)
    data_set = data_loader.load_data()
    print("The number of training images = %d." % len(data_set))

    #### initialize models
    ## declaration
    E_a2Zb = Encoder(input_nc=opt.input_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type,
                     use_dropout=not opt.no_dropout,
                     n_blocks=9)
    G_Zb2b = Decoder(output_nc=opt.output_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type)
    T_Zb2Za = LatentTranslator(n_channels=256,
                               norm_type=opt.norm_type,
                               use_dropout=not opt.no_dropout)
    D_b = Discriminator(input_nc=opt.input_nc,
                        ndf=opt.ndf,
                        n_layers=opt.n_layers,
                        norm_type=opt.norm_type)

    E_b2Za = Encoder(input_nc=opt.input_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type,
                     use_dropout=not opt.no_dropout,
                     n_blocks=9)
    G_Za2a = Decoder(output_nc=opt.output_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type)
    T_Za2Zb = LatentTranslator(n_channels=256,
                               norm_type=opt.norm_type,
                               use_dropout=not opt.no_dropout)
    D_a = Discriminator(input_nc=opt.input_nc,
                        ndf=opt.ndf,
                        n_layers=opt.n_layers,
                        norm_type=opt.norm_type)

    ## initialization
    E_a2Zb = init_net(E_a2Zb, init_type=opt.init_type).to(device)
    G_Zb2b = init_net(G_Zb2b, init_type=opt.init_type).to(device)
    T_Zb2Za = init_net(T_Zb2Za, init_type=opt.init_type).to(device)
    D_b = init_net(D_b, init_type=opt.init_type).to(device)

    E_b2Za = init_net(E_b2Za, init_type=opt.init_type).to(device)
    G_Za2a = init_net(G_Za2a, init_type=opt.init_type).to(device)
    T_Za2Zb = init_net(T_Za2Zb, init_type=opt.init_type).to(device)
    D_a = init_net(D_a, init_type=opt.init_type).to(device)
    print(
        "+------------------------------------------------------+\nFinish initializing networks."
    )

    #### optimizer and criterion
    ## criterion
    criterionGAN = GANLoss(opt.gan_mode).to(device)
    criterionZId = nn.L1Loss()
    criterionIdt = nn.L1Loss()
    criterionCTC = nn.L1Loss()
    criterionZCyc = nn.L1Loss()

    ## optimizer
    optimizer_G = torch.optim.Adam(itertools.chain(E_a2Zb.parameters(),
                                                   G_Zb2b.parameters(),
                                                   T_Zb2Za.parameters(),
                                                   E_b2Za.parameters(),
                                                   G_Za2a.parameters(),
                                                   T_Za2Zb.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.beta1, opt.beta2))
    optimizer_D = torch.optim.Adam(itertools.chain(D_a.parameters(),
                                                   D_b.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.beta1, opt.beta2))

    ## scheduler
    scheduler = [
        get_scheduler(optimizer_G, opt),
        get_scheduler(optimizer_D, opt)
    ]

    print(
        "+------------------------------------------------------+\nFinish initializing the optimizers and criterions."
    )

    #### global variables
    checkpoints_pth = os.path.join(opt.checkpoints, opt.name)
    if os.path.exists(checkpoints_pth) is not True:
        os.mkdir(checkpoints_pth)
        os.mkdir(os.path.join(checkpoints_pth, 'images'))
    record_fh = open(os.path.join(checkpoints_pth, 'records.txt'),
                     'w',
                     encoding='utf-8')
    loss_names = [
        'GAN_A', 'Adv_A', 'Idt_A', 'CTC_A', 'ZId_A', 'ZCyc_A', 'GAN_B',
        'Adv_B', 'Idt_B', 'CTC_B', 'ZId_B', 'ZCyc_B'
    ]

    fake_A_pool = ImagePool(
        opt.pool_size
    )  # create image buffer to store previously generated images
    fake_B_pool = ImagePool(
        opt.pool_size
    )  # create image buffer to store previously generated images

    print(
        "+------------------------------------------------------+\nFinish preparing the other works."
    )
    print(
        "+------------------------------------------------------+\nNow training is beginning .."
    )
    #### training
    cur_iter = 0
    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()  # timer for entire epoch

        for i, data in enumerate(data_set):
            ## setup inputs
            real_A = data['A'].to(device)
            real_B = data['B'].to(device)

            ## forward
            # image cycle / GAN
            latent_B = E_a2Zb(real_A)  #-> a -> Zb     : E_a2b(a)
            fake_B = G_Zb2b(latent_B)  #-> Zb -> b'    : G_b(E_a2b(a))
            latent_A = E_b2Za(real_B)  #-> b -> Za     : E_b2a(b)
            fake_A = G_Za2a(latent_A)  #-> Za -> a'    : G_a(E_b2a(b))

            # Idt
            '''
            rec_A = G_Za2a(E_b2Za(fake_B))          #-> b' -> Za' -> rec_a  : G_a(E_b2a(fake_b))
            rec_B = G_Zb2b(E_a2Zb(fake_A))          #-> a' -> Zb' -> rec_b  : G_b(E_a2b(fake_a))
            '''
            idt_latent_A = E_b2Za(real_A)  #-> a -> Za        : E_b2a(a)
            idt_A = G_Za2a(idt_latent_A)  #-> Za -> idt_a    : G_a(E_b2a(a))
            idt_latent_B = E_a2Zb(real_B)  #-> b -> Zb        : E_a2b(b)
            idt_B = G_Zb2b(idt_latent_B)  #-> Zb -> idt_b    : G_b(E_a2b(b))

            # ZIdt
            T_latent_A = T_Zb2Za(latent_B)  #-> Zb -> Za''  : T_b2a(E_a2b(a))
            T_rec_A = G_Za2a(
                T_latent_A)  #-> Za'' -> a'' : G_a(T_b2a(E_a2b(a)))
            T_latent_B = T_Za2Zb(latent_A)  #-> Za -> Zb''  : T_a2b(E_b2a(b))
            T_rec_B = G_Zb2b(
                T_latent_B)  #-> Zb'' -> b'' : G_b(T_a2b(E_b2a(b)))

            # CTC
            T_idt_latent_B = T_Za2Zb(idt_latent_A)  #-> a -> T_a2b(E_b2a(a))
            T_idt_latent_A = T_Zb2Za(idt_latent_B)  #-> b -> T_b2a(E_a2b(b))

            # ZCyc
            TT_latent_B = T_Za2Zb(T_latent_A)  #-> T_a2b(T_b2a(E_a2b(a)))
            TT_latent_A = T_Zb2Za(T_latent_B)  #-> T_b2a(T_a2b(E_b2a(b)))

            ### optimize parameters
            ## Generator updating
            set_requires_grad(
                [D_b, D_a],
                False)  #-> set Discriminator to require no gradient
            optimizer_G.zero_grad()
            # GAN loss
            loss_G_A = criterionGAN(D_b(fake_B), True)
            loss_G_B = criterionGAN(D_a(fake_A), True)
            loss_GAN = loss_G_A + loss_G_B
            # Idt loss
            loss_idt_A = criterionIdt(idt_A, real_A)
            loss_idt_B = criterionIdt(idt_B, real_B)
            loss_Idt = loss_idt_A + loss_idt_B
            # Latent cross-identity loss
            loss_Zid_A = criterionZId(T_rec_A, real_A)
            loss_Zid_B = criterionZId(T_rec_B, real_B)
            loss_Zid = loss_Zid_A + loss_Zid_B
            # Latent cross-translation consistency
            loss_CTC_A = criterionCTC(T_idt_latent_A, latent_A)
            loss_CTC_B = criterionCTC(T_idt_latent_B, latent_B)
            loss_CTC = loss_CTC_B + loss_CTC_A
            # Latent cycle consistency
            loss_ZCyc_A = criterionZCyc(TT_latent_A, latent_A)
            loss_ZCyc_B = criterionZCyc(TT_latent_B, latent_B)
            loss_ZCyc = loss_ZCyc_B + loss_ZCyc_A

            loss_G = opt.lambda_gan * loss_GAN + opt.lambda_idt * loss_Idt + opt.lambda_zid * loss_Zid + opt.lambda_ctc * loss_CTC + opt.lambda_zcyc * loss_ZCyc

            # backward and gradient updating
            loss_G.backward()
            optimizer_G.step()

            ## Discriminator updating
            set_requires_grad([D_b, D_a],
                              True)  # -> set Discriminator to require gradient
            optimizer_D.zero_grad()

            # backward D_b
            fake_B_ = fake_B_pool.query(fake_B)
            #-> real_B, fake_B
            pred_real_B = D_b(real_B)
            loss_D_real_B = criterionGAN(pred_real_B, True)

            pred_fake_B = D_b(fake_B_)
            loss_D_fake_B = criterionGAN(pred_fake_B, False)

            loss_D_B = (loss_D_real_B + loss_D_fake_B) * 0.5
            loss_D_B.backward()

            # backward D_a
            fake_A_ = fake_A_pool.query(fake_A)
            #-> real_A, fake_A
            pred_real_A = D_a(real_A)
            loss_D_real_A = criterionGAN(pred_real_A, True)

            pred_fake_A = D_a(fake_A_)
            loss_D_fake_A = criterionGAN(pred_fake_A, False)

            loss_D_A = (loss_D_real_A + loss_D_fake_A) * 0.5
            loss_D_A.backward()

            # update the gradients
            optimizer_D.step()

            ### validate here, both qualitively and quantitatively
            ## record the losses
            if cur_iter % opt.log_freq == 0:
                # loss_names = ['GAN_A', 'Adv_A', 'Idt_A', 'CTC_A', 'ZId_A', 'ZCyc_A', 'GAN_B', 'Adv_B', 'Idt_B', 'CTC_B', 'ZId_B', 'ZCyc_B']
                losses = [
                    loss_G_A.item(),
                    loss_D_A.item(),
                    loss_idt_A.item(),
                    loss_CTC_A.item(),
                    loss_Zid_A.item(),
                    loss_ZCyc_A.item(),
                    loss_G_B.item(),
                    loss_D_B.item(),
                    loss_idt_B.item(),
                    loss_CTC_B.item(),
                    loss_Zid_B.item(),
                    loss_ZCyc_B.item()
                ]
                # record
                line = ''
                for loss in losses:
                    line += '{} '.format(loss)
                record_fh.write(line[:-1] + '\n')
                # print out
                print('Epoch: %3d/%3dIter: %9d--------------------------+' %
                      (epoch, opt.epoch, i))
                field_names = loss_names[:len(loss_names) // 2]
                table = PrettyTable(field_names=field_names)
                for l_n in field_names:
                    table.align[l_n] = 'm'
                table.add_row(losses[:len(field_names)])
                print(table.get_string(reversesort=True))

                field_names = loss_names[len(loss_names) // 2:]
                table = PrettyTable(field_names=field_names)
                for l_n in field_names:
                    table.align[l_n] = 'm'
                table.add_row(losses[-len(field_names):])
                print(table.get_string(reversesort=True))

            ## visualize
            if cur_iter % opt.vis_freq == 0:
                if opt.gpu_id >= 0:
                    real_A = real_A.cpu().data
                    real_B = real_B.cpu().data
                    fake_A = fake_A.cpu().data
                    fake_B = fake_B.cpu().data
                    idt_A = idt_A.cpu().data
                    idt_B = idt_B.cpu().data
                    T_rec_A = T_rec_A.cpu().data
                    T_rec_B = T_rec_B.cpu().data

                plt.subplot(241), plt.title('real_A'), plt.imshow(
                    tensor2image_RGB(real_A[0, ...]))
                plt.subplot(242), plt.title('fake_B'), plt.imshow(
                    tensor2image_RGB(fake_B[0, ...]))
                plt.subplot(243), plt.title('idt_A'), plt.imshow(
                    tensor2image_RGB(idt_A[0, ...]))
                plt.subplot(244), plt.title('L_idt_A'), plt.imshow(
                    tensor2image_RGB(T_rec_A[0, ...]))

                plt.subplot(245), plt.title('real_B'), plt.imshow(
                    tensor2image_RGB(real_B[0, ...]))
                plt.subplot(246), plt.title('fake_A'), plt.imshow(
                    tensor2image_RGB(fake_A[0, ...]))
                plt.subplot(247), plt.title('idt_B'), plt.imshow(
                    tensor2image_RGB(idt_B[0, ...]))
                plt.subplot(248), plt.title('L_idt_B'), plt.imshow(
                    tensor2image_RGB(T_rec_B[0, ...]))

                plt.savefig(
                    os.path.join(checkpoints_pth, 'images',
                                 '%03d_%09d.jpg' % (epoch, i)))

            cur_iter += 1
            #break #-> debug

        ## till now, we finish one epoch, try to update the learning rate
        update_learning_rate(schedulers=scheduler,
                             opt=opt,
                             optimizer=optimizer_D)
        ## save the model
        if epoch % opt.ckp_freq == 0:
            #-> save models
            # torch.save(model.state_dict(), PATH)
            #-> load in models
            # model.load_state_dict(torch.load(PATH))
            # model.eval()
            if opt.gpu_id >= 0:
                E_a2Zb = E_a2Zb.cpu()
                G_Zb2b = G_Zb2b.cpu()
                T_Zb2Za = T_Zb2Za.cpu()
                D_b = D_b.cpu()

                E_b2Za = E_b2Za.cpu()
                G_Za2a = G_Za2a.cpu()
                T_Za2Zb = T_Za2Zb.cpu()
                D_a = D_a.cpu()
                '''
                torch.save( E_a2Zb.cpu().state_dict(), os.path.join(checkpoints_pth, 'epoch_%3d-E_a2b.pth' % epoch))
                torch.save( G_Zb2b.cpu().state_dict(), os.path.join(checkpoints_pth,   'epoch_%3d-G_b.pth' % epoch))
                torch.save(T_Zb2Za.cpu().state_dict(), os.path.join(checkpoints_pth, 'epoch_%3d-T_b2a.pth' % epoch))
                torch.save(    D_b.cpu().state_dict(), os.path.join(checkpoints_pth,   'epoch_%3d-D_b.pth' % epoch))

                torch.save( E_b2Za.cpu().state_dict(), os.path.join(checkpoints_pth, 'epoch_%3d-E_b2a.pth' % epoch))
                torch.save( G_Za2a.cpu().state_dict(), os.path.join(checkpoints_pth,   'epoch_%3d-G_a.pth' % epoch))
                torch.save(T_Za2Zb.cpu().state_dict(), os.path.join(checkpoints_pth, 'epoch_%3d-T_a2b.pth' % epoch))
                torch.save(    D_a.cpu().state_dict(), os.path.join(checkpoints_pth,   'epoch_%3d-D_a.pth' % epoch))
                '''
            torch.save(
                E_a2Zb.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-E_a2b.pth' % epoch))
            torch.save(
                G_Zb2b.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-G_b.pth' % epoch))
            torch.save(
                T_Zb2Za.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-T_b2a.pth' % epoch))
            torch.save(
                D_b.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-D_b.pth' % epoch))

            torch.save(
                E_b2Za.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-E_b2a.pth' % epoch))
            torch.save(
                G_Za2a.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-G_a.pth' % epoch))
            torch.save(
                T_Za2Zb.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-T_a2b.pth' % epoch))
            torch.save(
                D_a.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-D_a.pth' % epoch))
            if opt.gpu_id >= 0:
                E_a2Zb = E_a2Zb.to(device)
                G_Zb2b = G_Zb2b.to(device)
                T_Zb2Za = T_Zb2Za.to(device)
                D_b = D_b.to(device)

                E_b2Za = E_b2Za.to(device)
                G_Za2a = G_Za2a.to(device)
                T_Za2Zb = T_Za2Zb.to(device)
                D_a = D_a.to(device)
            print("+Successfully saving models in epoch: %3d.-------------+" %
                  epoch)
        #break #-> debug
    record_fh.close()
    print("≧◔◡◔≦ Congratulation! Finishing the training!")
示例#2
0
    def initialize(self, cfg):
        self.cfg = cfg

        ## set devices
        if cfg['GPU_IDS']:
            assert (torch.cuda.is_available())
            self.device = torch.device('cuda:{}'.format(cfg['GPU_IDS'][0]))
            torch.backends.cudnn.benchmark = True
            print('Using %d GPUs' % len(cfg['GPU_IDS']))
        else:
            self.device = torch.device('cpu')

        # define network
        if cfg['ARCHI'] == 'alexnet':
            self.netB = networks.netB_alexnet()
            self.netH = networks.netH_alexnet()
            if self.cfg['USE_DA'] and self.cfg['TRAIN']:
                self.netD = networks.netD_alexnet(self.cfg['DA_LAYER'])
        elif cfg['ARCHI'] == 'vgg16':
            raise NotImplementedError
            self.netB = networks.netB_vgg16()
            self.netH = networks.netH_vgg16()
            if self.cfg['USE_DA'] and self.cfg['TRAIN']:
                self.netD = netD_vgg16(self.cfg['DA_LAYER'])
        elif 'resnet' in cfg['ARCHI']:
            self.netB = networks.netB_resnet34()
            self.netH = networks.netH_resnet34()
            if self.cfg['USE_DA'] and self.cfg['TRAIN']:
                self.netD = networks.netD_resnet(self.cfg['DA_LAYER'])
        else:
            raise ValueError('Un-supported network')

        ## initialize network param.
        self.netB = networks.init_net(self.netB, cfg['GPU_IDS'], 'xavier')
        self.netH = networks.init_net(self.netH, cfg['GPU_IDS'], 'xavier')

        if self.cfg['USE_DA'] and self.cfg['TRAIN']:
            self.netD = networks.init_net(self.netD, cfg['GPU_IDS'], 'xavier')

        # loss, optimizer, and scherduler
        if cfg['TRAIN']:
            self.total_steps = 0
            ## Output path
            self.save_dir = os.path.join(
                cfg['OUTPUT_PATH'], cfg['ARCHI'],
                datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
            if not os.path.isdir(self.save_dir):
                os.makedirs(self.save_dir)
            # self.logger = Logger(self.save_dir)

            ## model names
            self.model_names = ['netB', 'netH']
            ## loss
            self.criterionGAN = networks.GANLoss().to(self.device)
            self.criterionDepth1 = torch.nn.MSELoss().to(self.device)
            self.criterionNorm = torch.nn.CosineEmbeddingLoss().to(self.device)
            self.criterionEdge = torch.nn.BCELoss().to(self.device)

            ## optimizers
            self.lr = cfg['LR']
            self.optimizers = []
            self.optimizer_B = torch.optim.Adam(self.netB.parameters(),
                                                lr=cfg['LR'],
                                                betas=(cfg['BETA1'],
                                                       cfg['BETA2']))
            self.optimizer_H = torch.optim.Adam(self.netH.parameters(),
                                                lr=cfg['LR'],
                                                betas=(cfg['BETA1'],
                                                       cfg['BETA2']))
            self.optimizers.append(self.optimizer_B)
            self.optimizers.append(self.optimizer_H)
            if cfg['USE_DA']:
                self.real_pool = ImagePool(cfg['POOL_SIZE'])
                self.syn_pool = ImagePool(cfg['POOL_SIZE'])
                self.model_names.append('netD')
                ## use SGD for discriminator
                self.optimizer_D = torch.optim.SGD(
                    self.netD.parameters(),
                    lr=cfg['LR'],
                    momentum=cfg['MOMENTUM'],
                    weight_decay=cfg['WEIGHT_DECAY'])
                self.optimizers.append(self.optimizer_D)
            ## LR scheduler
            self.schedulers = [
                networks.get_scheduler(optimizer, cfg)
                for optimizer in self.optimizers
            ]
        else:
            ## testing
            self.model_names = ['netB', 'netH']
            self.criterionDepth1 = torch.nn.MSELoss().to(self.device)
            self.criterionNorm = torch.nn.CosineEmbeddingLoss(
                reduction='none').to(self.device)

        self.load_dir = os.path.join(cfg['CKPT_PATH'])
        self.criterionNorm_eval = torch.nn.CosineEmbeddingLoss(
            reduction='none').to(self.device)

        if cfg['TEST'] or cfg['RESUME']:
            self.load_networks(cfg['EPOCH_LOAD'])
 def define_net(self, net):
     return networks.init_net(net, self.opt.init_type, self.opt.init_gain,
                              self.gpu_ids)