示例#1
0
def get_encoder_decoder(vocab):
    """Build the (encoder, decoder) pair selected by the global ``args``.

    The encoder is picked by the first truthy flag on the global ``args``
    namespace:

    * ``args.pretrain_rnn``     -> ``EncoderRNN`` over the vocabulary.
    * ``args.gan_embedding``    -> ``discriminator`` of the whole model
      pickled in ``DCGAN_embed_2.tch``.
    * ``args.progan_embedding`` -> discriminator (``dis``) of a depth-7,
      latent-size-256 ProGAN restored from ``progan_weights/GAN_*_6.pth``.
    * otherwise                 -> ``EncoderCNN``.

    The decoder is always a ``DecoderRNNOld`` sized from ``args`` and the
    vocabulary.

    Args:
        vocab: Vocabulary; ``len(vocab)`` gives the vocabulary size and the
            object itself is handed to the decoder.

    Returns:
        ``(encoder, decoder)``, both moved to the global ``device``.
    """
    if args.pretrain_rnn:
        encoder = EncoderRNN(len(vocab),
                             args.embed_size,
                             args.encoder_rnn_hidden_size,
                             num_layers=args.num_layers).to(device)
    elif args.gan_embedding:
        # torch.load unpickles arbitrary objects: only use trusted files.
        # map_location lets a checkpoint saved on GPU be restored on the
        # device this run uses (e.g. a CPU-only machine).
        # NOTE: PyTorch >= 2.6 defaults to weights_only=True, which refuses
        # whole pickled modules; this call may then need weights_only=False.
        gan = torch.load('DCGAN_embed_2.tch', map_location=device).to(device)
        encoder = gan.discriminator
    elif args.progan_embedding:
        # Build the ProGAN on the configured device instead of hard-coding
        # CUDA, so this branch also works when ``device`` is the CPU.
        # torch.device() accepts either a device string or a torch.device.
        pro_gan = pg.ProGAN(depth=7,
                            latent_size=256,
                            device=torch.device(device))
        pro_gan.dis.load_state_dict(
            torch.load('progan_weights/GAN_DIS_6.pth', map_location=device))
        # Optimizer states (GAN_*_OPTIM_6.pth) are deliberately not restored:
        # only the discriminator's weights are used as the encoder.
        pro_gan.gen.load_state_dict(
            torch.load('progan_weights/GAN_GEN_6.pth', map_location=device))
        pro_gan.gen_shadow.load_state_dict(
            torch.load('progan_weights/GAN_GEN_SHADOW_6.pth',
                       map_location=device))
        print("Loaded proGAN weights.", flush=True)
        encoder = pro_gan.dis.to(device)
    else:
        encoder = EncoderCNN(args.embed_size).to(device)

    decoder = DecoderRNNOld(args.embed_size,
                            args.decoder_rnn_hidden_size,
                            len(vocab),
                            args.num_layers,
                            vocab,
                            device=device).to(device)
    return encoder, decoder
示例#2
0
def train_model(device_to_run, *, depth=7, latent_size=512,
                batch_sizes=(5, 5, 5, 5, 5, 5, 5),
                num_epochs=(10, 15, 20, 20, 20, 20, 20),
                fade_ins=(50, 50, 50, 50, 50, 50, 50),
                num_workers=3):
    """Train a ProGAN on the ``LandscapeImages`` dataset.

    All keyword arguments default to the values this function previously
    hard-coded, so ``train_model(device)`` behaves exactly as before.

    Args:
        device_to_run: Device handed to ``pg.ProGAN``.
        depth: Number of depth levels of the ProGAN.
        latent_size: Size of the generator's latent vector.
        batch_sizes: Batch size per depth level; needs ``depth`` entries.
        num_epochs: Number of epochs per depth level; needs ``depth`` entries.
        fade_ins: Fade-in percentage per depth level (passed as
            ``fade_in_percentage``); needs ``depth`` entries.
        num_workers: Data-loader worker count passed to ``gan.train``.

    Returns:
        The trained ``pg.ProGAN`` instance.

    Raises:
        ValueError: If any per-depth schedule does not have ``depth`` entries.
    """
    # Validate before doing any expensive work (dataset loading, GAN setup).
    schedules = {
        'batch_sizes': batch_sizes,
        'num_epochs': num_epochs,
        'fade_ins': fade_ins,
    }
    for name, schedule in schedules.items():
        if len(schedule) != depth:
            raise ValueError(
                '{} has {} entries but depth is {}; one entry per depth level '
                'is required'.format(name, len(schedule), depth))

    dataset = LandscapeImages()
    gan = pg.ProGAN(device=device_to_run, latent_size=latent_size, depth=depth)

    # Pass lists, as the original did, whatever sequence type the caller used.
    gan.train(dataset=dataset,
              epochs=list(num_epochs),
              batch_sizes=list(batch_sizes),
              fade_in_percentage=list(fade_ins),
              num_workers=num_workers)
    return gan
示例#3
0
    current_depth = state['current_depth']
    epoch = state['epoch']
    print('Loaded {}'.format(basename))
    return current_depth, epoch


if __name__ == '__main__':
    # Script entry point: build a ProGAN, restore it from ``ckpt_dir`` with
    # load_checkpoint, then load the checkpoint file named
    # 'checkpoint-<current_depth>-10.ckpt' for the previous depth level
    # (the '-10' presumably denotes epoch 10 — confirm against the saver).
    # NOTE(review): ``device`` and ``ckpt_dir`` are not defined in this view;
    # presumably module-level globals set earlier in the file — confirm.

    # some parameters:
    depth = 6
    latent_size = 512

    # ======================================================================
    # This line creates the PRO-GAN
    # ======================================================================
    pro_gan = pg.ProGAN(depth=depth, latent_size=latent_size, device=device)
    # gen = pg.Generator(depth=depth, latent_size=latent_size, use_eql=False).to(device)
    # ======================================================================

    # load_checkpoint returns (current_depth, epoch) read from the saved
    # state; presumably it also loads the weights into pro_gan — verify.
    # NOTE(review): ``epoch`` is not used anywhere in the visible code.
    current_depth, epoch = load_checkpoint(pro_gan, ckpt_dir)
    # if epoch != 10:
    # current_depth -= 1

    # OUT_DIR = out_dir
    # for current_depth in range(depth):
    # Step back to the previous depth level.
    # NOTE(review): if the checkpoint reports depth 0 this yields -1 and the
    # basename below becomes 'checkpoint--1-10.ckpt'.
    current_depth -= 1
    # ``if True:`` stands in for the disabled for-loop above; the return
    # value of this second load_checkpoint call is discarded.
    if True:
        load_checkpoint(pro_gan,
                        ckpt_dir,
                        basename='checkpoint-{}-10.ckpt'.format(current_depth))
        # out_dir = '{}-{}'.format(OUT_DIR, current_depth)