def main(args):
    """Build the CycleGAN networks, optionally resume them, then evaluate or train.

    Two generators and two discriminators are created, moved to CUDA when
    requested, and paired with one Adam optimizer for both generators and one
    for both discriminators.

    Args:
        args: Parsed run options. This function reads ``args.cuda``,
            ``args.lr``, ``args.training``, ``args.evaluation`` and
            ``args.epochs``; the whole object is also forwarded to
            ``load_data``, ``train`` and ``evaluation``.
    """
    train_loader, test_loader = load_data(args)

    generator_a2b = CycleGAN()
    generator_b2a = CycleGAN()
    discriminator_a = Discriminator()
    discriminator_b = Discriminator()

    if args.cuda:
        generator_a2b = generator_a2b.cuda()
        generator_b2a = generator_b2a.cuda()
        discriminator_a = discriminator_a.cuda()
        discriminator_b = discriminator_b.cuda()

    # One optimizer steps both generators together, another both discriminators.
    optimizer_g = optim.Adam(
        itertools.chain(generator_a2b.parameters(), generator_b2a.parameters()),
        lr=args.lr,
        betas=(0.5, 0.999),
    )
    optimizer_d = optim.Adam(
        itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
        lr=args.lr,
        betas=(0.5, 0.999),
    )

    if args.training:
        # Resume from a fixed checkpoint file; the path and epoch are hard-coded.
        resume_epoch = 285
        path = 'E:/cyclegan/checkpoints/model_{}_{}.pth'.format(resume_epoch, 200)
        # Without map_location, torch.load restores tensors onto the device they
        # were saved from, which fails on a machine without CUDA. Remap to CPU
        # only when CUDA was not requested; otherwise keep the default.
        map_location = None if args.cuda else torch.device('cpu')
        checkpoint = torch.load(path, map_location=map_location)
        generator_a2b.load_state_dict(checkpoint['generatorA'])
        generator_b2a.load_state_dict(checkpoint['generatorB'])
        discriminator_a.load_state_dict(checkpoint['discriminatorA'])
        discriminator_b.load_state_dict(checkpoint['discriminatorB'])
        optimizer_g.load_state_dict(checkpoint['optimizerG'])
        optimizer_d.load_state_dict(checkpoint['optimizerD'])
        start_epoch = resume_epoch
    else:
        # Fresh run: initialise every network the same way.
        # NOTE(review): gpu_ids=[0] is passed even when args.cuda is false;
        # confirm init_net tolerates that on CPU-only runs.
        for net in (generator_a2b, generator_b2a, discriminator_a, discriminator_b):
            init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[0])
        start_epoch = 1

    if args.evaluation:
        evaluation(test_loader, generator_a2b, generator_b2a, args)
    else:
        cycle = nn.L1Loss()
        gan = nn.BCEWithLogitsLoss()
        identity = nn.L1Loss()

        for epoch in range(start_epoch, args.epochs):
            train(
                train_loader,
                generator_a2b,
                generator_b2a,
                discriminator_a,
                discriminator_b,
                optimizer_g,
                optimizer_d,
                cycle,
                gan,
                identity,
                args,
                epoch,
            )
            # NOTE(review): the original indentation was lost; evaluation is
            # assumed to run after every training epoch - confirm.
            evaluation(test_loader, generator_a2b, generator_b2a, args)
def load_checkpoint():
    """Construct a CycleGAN from the ``C`` configuration and load a checkpoint.

    The checkpoint file name is ``C.CKPT_PREFIX`` formatted with
    ``str(C.n_ckpt)``, located under ``C.EXP_DIR``/``C.TAG``. The resolved
    path is printed before the model is built.

    Returns:
        The CycleGAN instance after ``load_checkpoint`` has been called on it,
        moved to CUDA when ``C.use_cuda`` is truthy.
    """
    experiment_dir = os.path.join(C.EXP_DIR, C.TAG)
    model_path = os.path.join(experiment_dir, C.CKPT_PREFIX % str(C.n_ckpt))
    print("load model at %s" % model_path)

    # Both generators are built from the same configuration values.
    generator_args = (
        C.g_conv_ch,
        C.g_trans_ch,
        C.g_kernels,
        C.g_strides,
        C.g_n_res_block,
        C.g_leaky_slop,
    )
    first_generator = Generator(*generator_args)
    second_generator = Generator(*generator_args)
    first_discriminator = Discriminator(C.nc_input)
    second_discriminator = Discriminator(C.nc_input)

    model = CycleGAN(
        first_generator,
        second_generator,
        first_discriminator,
        second_discriminator,
    )
    model.load_checkpoint(model_path)

    if C.use_cuda:
        model.cuda()
    return model