示例#1
0
def main(args):
    """Set up the CycleGAN networks and optimizers, then train and/or evaluate.

    Args:
        args: Parsed command-line options. Attributes read directly here are
            ``cuda``, ``lr``, ``training``, ``evaluation`` and ``epochs``;
            ``args`` is also forwarded whole to ``load_data``, ``train`` and
            ``evaluation``. When ``args.training`` is true the networks and
            optimizers are restored from a checkpoint. Two optional
            attributes control that resume:

            * ``resume_checkpoint`` -- path of the checkpoint file
              (default ``E:/cyclegan/checkpoints/model_285_200.pth``).
            * ``resume_epoch`` -- epoch to restart the training loop from
              (default ``285``).

            The defaults are the values this function previously hard-coded,
            so callers that do not set these attributes behave as before.
    """
    train_loader, test_loader = load_data(args)

    # One generator per translation direction (A->B, B->A) and one
    # discriminator per domain, following the original variable names.
    generator_a2b = CycleGAN()
    generator_b2a = CycleGAN()
    discriminator_a = Discriminator()
    discriminator_b = Discriminator()

    if args.cuda:
        generator_a2b = generator_a2b.cuda()
        generator_b2a = generator_b2a.cuda()
        discriminator_a = discriminator_a.cuda()
        discriminator_b = discriminator_b.cuda()

    # Both generators share one optimizer and both discriminators another.
    optimizer_g = optim.Adam(
        itertools.chain(generator_a2b.parameters(), generator_b2a.parameters()),
        lr=args.lr, betas=(0.5, 0.999))
    optimizer_d = optim.Adam(
        itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
        lr=args.lr, betas=(0.5, 0.999))

    if args.training:
        # Resume from a saved checkpoint. Path and epoch used to be hard-coded;
        # they can now be supplied via args, falling back to the old values.
        checkpoint_path = getattr(args, 'resume_checkpoint', None)
        if checkpoint_path is None:
            checkpoint_path = 'E:/cyclegan/checkpoints/model_{}_{}.pth'.format(285, 200)
        start_epoch = getattr(args, 'resume_epoch', None)
        if start_epoch is None:
            start_epoch = 285

        # Load to CPU first so a checkpoint written from CUDA tensors also
        # loads on a CPU-only run; load_state_dict then moves the values onto
        # the device of each (possibly CUDA) module / optimizer parameter.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        generator_a2b.load_state_dict(checkpoint['generatorA'])
        generator_b2a.load_state_dict(checkpoint['generatorB'])
        discriminator_a.load_state_dict(checkpoint['discriminatorA'])
        discriminator_b.load_state_dict(checkpoint['discriminatorB'])
        optimizer_g.load_state_dict(checkpoint['optimizerG'])
        optimizer_d.load_state_dict(checkpoint['optimizerD'])
    else:
        # Fresh weights: normal init with gain 0.02 for all four networks.
        # Previously gpu_ids=[0] was passed even when args.cuda was False.
        # Assumes init_net only touches a GPU when gpu_ids is non-empty -- confirm.
        # init_net's return value is ignored (as before), so this relies on it
        # initialising the given module's weights in place -- confirm.
        gpu_ids = [0] if args.cuda else []
        for net in (generator_a2b, generator_b2a, discriminator_a, discriminator_b):
            init_net(net, init_type='normal', init_gain=0.02, gpu_ids=gpu_ids)
        start_epoch = 1

    # NOTE(review): with args.evaluation set and args.training unset, evaluation
    # runs on freshly initialised (untrained) weights -- confirm that is intended.
    if not args.evaluation:
        cycle = nn.L1Loss()
        gan = nn.BCEWithLogitsLoss()
        identity = nn.L1Loss()

        # NOTE(review): range() excludes args.epochs, so starting from epoch 1
        # this runs args.epochs - 1 epochs -- confirm that is intended.
        for epoch in range(start_epoch, args.epochs):
            train(train_loader, generator_a2b, generator_b2a,
                  discriminator_a, discriminator_b,
                  optimizer_g, optimizer_d,
                  cycle, gan, identity, args, epoch)

    evaluation(test_loader, generator_a2b, generator_b2a, args)
示例#2
0
        Dp_Y)
    print('Dp_Y:')
    print('\t- Num of Parameters                : {:,}'.format(n_params))
    print('\t- Num of Trainable Parameters      : {:,}'.format(
        n_trainable_params))
    print('\t- Num of Non-Trainable Parameters  : {:,}'.format(
        n_non_trainable_params))
    print('==========================================================')
    summary(G_XtoY, (3, image_size[1], image_size[0]))
    summary(Dp_X, (3, image_size[1], image_size[0]))
    summary(Dg_X, (3, image_size[1], image_size[0]))

    # load weights
    if use_pretrained_weights:
        G_XtoY.load_state_dict(
            torch.load(generator_x_y_weights,
                       map_location=lambda storage, loc: storage))
        G_YtoX.load_state_dict(
            torch.load(generator_y_x_weights,
                       map_location=lambda storage, loc: storage))
        Dp_X.load_state_dict(
            torch.load(discriminator_xp_weights,
                       map_location=lambda storage, loc: storage))
        Dp_Y.load_state_dict(
            torch.load(discriminator_yp_weights,
                       map_location=lambda storage, loc: storage))
        Dg_X.load_state_dict(
            torch.load(discriminator_xg_weights,
                       map_location=lambda storage, loc: storage))
        Dg_Y.load_state_dict(
            torch.load(discriminator_yg_weights,