Example No. 1
# Standard imports used by this example; load_data, CycleGAN, Discriminator,
# init_net, train and evaluation are assumed to be defined elsewhere in the
# project this snippet comes from.
import itertools

import torch
import torch.nn as nn
import torch.optim as optim


def main(args):
    train_loader, test_loader = load_data(args)

    # one generator per translation direction (A->B and B->A)
    GeneratorA2B = CycleGAN()
    GeneratorB2A = CycleGAN()

    # one discriminator per image domain
    DiscriminatorA = Discriminator()
    DiscriminatorB = Discriminator()

    # move the networks to GPU when requested
    if args.cuda:
        GeneratorA2B = GeneratorA2B.cuda()
        GeneratorB2A = GeneratorB2A.cuda()

        DiscriminatorA = DiscriminatorA.cuda()
        DiscriminatorB = DiscriminatorB.cuda()

    # a single optimizer for both generators, another for both discriminators
    optimizerG = optim.Adam(itertools.chain(GeneratorA2B.parameters(), GeneratorB2A.parameters()), lr=args.lr, betas=(0.5, 0.999))
    optimizerD = optim.Adam(itertools.chain(DiscriminatorA.parameters(), DiscriminatorB.parameters()), lr=args.lr, betas=(0.5, 0.999))

    if args.training:
        # resume: restore networks and optimizers from a saved checkpoint
        path = 'E:/cyclegan/checkpoints/model_{}_{}.pth'.format(285, 200)

        checkpoint = torch.load(path)
        GeneratorA2B.load_state_dict(checkpoint['generatorA'])
        GeneratorB2A.load_state_dict(checkpoint['generatorB'])
        DiscriminatorA.load_state_dict(checkpoint['discriminatorA'])
        DiscriminatorB.load_state_dict(checkpoint['discriminatorB'])
        optimizerG.load_state_dict(checkpoint['optimizerG'])
        optimizerD.load_state_dict(checkpoint['optimizerD'])

        start_epoch = 285
    else:
        # fresh run: initialize all networks with a normal distribution (gain 0.02)
        init_net(GeneratorA2B, init_type='normal', init_gain=0.02, gpu_ids=[0])
        init_net(GeneratorB2A, init_type='normal', init_gain=0.02, gpu_ids=[0])

        init_net(DiscriminatorA, init_type='normal', init_gain=0.02, gpu_ids=[0])
        init_net(DiscriminatorB, init_type='normal', init_gain=0.02, gpu_ids=[0])
        start_epoch = 1

    # evaluation-only mode: skip training and run evaluation on the test set
    if args.evaluation:
        evaluation(test_loader, GeneratorA2B, GeneratorB2A, args)
    else:
        # losses: cycle-consistency (L1), adversarial (BCEWithLogits) and identity (L1)
        cycle = nn.L1Loss()
        gan = nn.BCEWithLogitsLoss()
        identity = nn.L1Loss()

        for epoch in range(start_epoch, args.epochs):
            train(train_loader, GeneratorA2B, GeneratorB2A, DiscriminatorA, DiscriminatorB, optimizerG, optimizerD, cycle, gan, identity, args, epoch)
        evaluation(test_loader, GeneratorA2B, GeneratorB2A, args)
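
For reference, a minimal driver for this example, assuming an argparse interface whose flag names simply mirror the attributes that main() reads from args (lr, epochs, cuda, training, evaluation); the real project may parse these differently, and load_data likely reads additional attributes (batch size, data paths) not shown here.

# Hypothetical command-line driver; flag names mirror the attributes used in
# main() above and are not taken from the original project.
import argparse

import torch

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='CycleGAN training/evaluation')
    parser.add_argument('--lr', type=float, default=0.0002)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--cuda', action='store_true', help='use the GPU if available')
    parser.add_argument('--training', action='store_true', help='resume from the saved checkpoint')
    parser.add_argument('--evaluation', action='store_true', help='run evaluation only')
    args = parser.parse_args()

    # fall back to CPU when CUDA is unavailable
    args.cuda = args.cuda and torch.cuda.is_available()

    main(args)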
Example No. 2
# os and torch.optim are used below; checkpoints_dir, samples_dir, logs_dir,
# experiment_id, g_lr, d_lr, beta1, beta2 and CycleGAN are assumed to be
# defined earlier in the script this snippet comes from.
import os

import torch.optim as optim

# configure full paths
checkpoints_dir = os.path.join(checkpoints_dir, experiment_id)
samples_dir = os.path.join(samples_dir, experiment_id)
logs_dir = os.path.join(logs_dir, experiment_id)

# make directories (os.makedirs with exist_ok=True is a portable 'mkdir -p')
os.makedirs(checkpoints_dir, exist_ok=True)
os.makedirs(samples_dir, exist_ok=True)
os.makedirs(logs_dir, exist_ok=True)

## create models
# call the function to get models
G_XtoY, G_YtoX, Dp_X, Dp_Y, Dg_X, Dg_Y = CycleGAN(n_res_blocks=2)

# define optimizer parameters
g_params = list(G_XtoY.parameters()) + list(
    G_YtoX.parameters())  # Get generator parameters

# Create optimizers for the generators and discriminators
g_optimizer = optim.Adam(g_params, g_lr, [beta1, beta2])
dp_x_optimizer = optim.Adam(Dp_X.parameters(), d_lr, [beta1, beta2])
dp_y_optimizer = optim.Adam(Dp_Y.parameters(), d_lr, [beta1, beta2])
dg_x_optimizer = optim.Adam(Dg_X.parameters(), d_lr, [beta1, beta2])
dg_y_optimizer = optim.Adam(Dg_Y.parameters(), d_lr, [beta1, beta2])


# count the total and trainable parameters in a model
def count_model_parameters(model):
    n_params = sum(p.numel() for p in model.parameters())
    n_trainable_params = sum(p.numel() for p in model.parameters()
                             if p.requires_grad)
    return n_params, n_trainable_params
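
As a quick usage check, and assuming count_model_parameters returns the two counts as sketched above, the helper can be applied to each of the networks created earlier:

# Example usage (assumes count_model_parameters returns (total, trainable)).
for name, model in [('G_XtoY', G_XtoY), ('G_YtoX', G_YtoX),
                    ('Dp_X', Dp_X), ('Dp_Y', Dp_Y),
                    ('Dg_X', Dg_X), ('Dg_Y', Dg_Y)]:
    total, trainable = count_model_parameters(model)
    print('{}: {} parameters ({} trainable)'.format(name, total, trainable))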