    def __init__(self, config):

        # Config
        self.config = config

        self.start = 0  # Unless using pre-trained model

        # Create directories if not exist
        utils.make_folder(self.config.save_path)
        utils.make_folder(self.config.model_weights_path)
        utils.make_folder(self.config.sample_images_path)

        # Copy files
        utils.write_config_to_file(self.config, self.config.save_path)
        utils.copy_scripts(self.config.save_path)

        # Check for CUDA
        utils.check_for_CUDA(self)

        # Make dataloader
        self.dataloader, self.num_of_classes = utils.make_dataloader(
            self.config.batch_size_in_gpu, self.config.dataset,
            self.config.data_path, self.config.shuffle, self.config.drop_last,
            self.config.dataloader_args, self.config.resize,
            self.config.imsize, self.config.centercrop,
            self.config.centercrop_size)

        # Data iterator
        self.data_iter = iter(self.dataloader)

        # Build G and D
        self.build_models()

        if self.config.adv_loss == 'dcgan':
            self.criterion = nn.BCELoss()
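
For reference, a minimal sketch (not part of the snippet above) of a config object carrying the fields this constructor actually reads; the field names are taken from the attribute accesses above, the values are placeholders, and the utility functions may read further fields not listed here.

# Hypothetical config for the constructor above; field names come from the
# snippet, values are placeholders.
from types import SimpleNamespace

config = SimpleNamespace(
    save_path='./results/run1',
    model_weights_path='./results/run1/weights',
    sample_images_path='./results/run1/samples',
    batch_size_in_gpu=64,
    dataset='cifar10',
    data_path='./data',
    shuffle=True,
    drop_last=True,
    dataloader_args={'num_workers': 4, 'pin_memory': True},
    resize=True,
    imsize=64,
    centercrop=False,
    centercrop_size=64,
    adv_loss='hinge',
)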
Example #2
    def __init__(self, config):

        # Images data path & Output path
        self.dataset = config.dataset
        self.data_path = config.data_path
        self.save_path = os.path.join(config.save_path, config.name)

        # Training settings
        self.batch_size = config.batch_size
        self.total_step = config.total_step
        self.d_steps_per_iter = config.d_steps_per_iter
        self.g_steps_per_iter = config.g_steps_per_iter
        self.d_lr = config.d_lr
        self.g_lr = config.g_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.inst_noise_sigma = config.inst_noise_sigma
        self.inst_noise_sigma_iters = config.inst_noise_sigma_iters
        self.start = 0  # Unless using pre-trained model

        # Image transforms
        self.shuffle = config.shuffle
        self.drop_last = config.drop_last
        self.resize = config.resize
        self.imsize = config.imsize
        self.centercrop = config.centercrop
        self.centercrop_size = config.centercrop_size
        self.tanh_scale = config.tanh_scale
        self.normalize = config.normalize

        # Step size
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.save_n_images = config.save_n_images
        self.max_frames_per_gif = config.max_frames_per_gif

        # Pretrained model
        self.pretrained_model = config.pretrained_model

        # Misc
        self.manual_seed = config.manual_seed
        self.disable_cuda = config.disable_cuda
        self.parallel = config.parallel
        self.dataloader_args = config.dataloader_args

        # Output paths
        self.model_weights_path = os.path.join(self.save_path,
                                               config.model_weights_dir)
        self.sample_path = os.path.join(self.save_path, config.sample_dir)

        # Model hyper-parameters
        self.adv_loss = config.adv_loss
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.lambda_gp = config.lambda_gp

        # Model name
        self.name = config.name

        # Create directories if not exist
        utils.make_folder(self.save_path)
        utils.make_folder(self.model_weights_path)
        utils.make_folder(self.sample_path)

        # Copy files
        utils.write_config_to_file(config, self.save_path)
        utils.copy_scripts(self.save_path)

        # Check for CUDA
        utils.check_for_CUDA(self)

        # Make dataloader
        self.dataloader, self.num_of_classes = utils.make_dataloader(
            self.batch_size, self.dataset, self.data_path, self.shuffle,
            self.drop_last, self.dataloader_args, self.resize, self.imsize,
            self.centercrop, self.centercrop_size)

        # Data iterator
        self.data_iter = iter(self.dataloader)

        # Build G and D
        self.build_models()

        # Start with pretrained model (if it exists)
        if self.pretrained_model != '':
            utils.load_pretrained_model(self)

        if self.adv_loss == 'dcgan':
            self.criterion = nn.BCELoss()
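
The stored `data_iter` is consumed during training, so it is typically refreshed once the underlying `DataLoader` is exhausted. A minimal sketch of that pattern, assuming a `(real_images, labels)` batch layout and that `utils.check_for_CUDA` sets `self.device` (neither is shown above):

    def get_next_batch(self):
        # Sketch only: advance the stored iterator, restarting it when the
        # DataLoader runs out of batches.
        try:
            real_images, labels = next(self.data_iter)
        except StopIteration:
            self.data_iter = iter(self.dataloader)
            real_images, labels = next(self.data_iter)
        return real_images.to(self.device), labels.to(self.device)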
Example #3
def main():
    global args
    args = get_config()
    args.command = 'python ' + ' '.join(sys.argv)

    # Create saving directory
    if args.unigen:
        save_dir = './results_unigen/{0}/G{1}_glr{2}_dlr{3}_dstep{4}_zdim{5}_{6}/'.format(
            args.dataset, args.dec_dist, str(args.lr), str(args.lr_d),
            str(args.d_steps_per_iter), str(args.latent_dim), args.div)
    else:
        save_dir = './results/{0}/E{1}_G{2}_glr{3}_dlr{4}_gstep{5}_dstep{6}_zdim{7}_{8}/'.format(
            args.dataset, args.enc_dist, args.dec_dist, str(args.lr),
            str(args.lr_d), str(args.g_steps_per_iter),
            str(args.d_steps_per_iter), str(args.latent_dim), args.div)

    utils.make_folder(save_dir)
    utils.write_config_to_file(args, save_dir)

    global device
    device = torch.device('cuda')

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Load datasets
    train_loader, test_loader = utils.make_dataloader(args)

    num_samples = len(train_loader.dataset)
    global num_iter_per_epoch
    num_iter_per_epoch = num_samples // args.batch_size

    # Losses file
    log_file_name = os.path.join(save_dir, 'log.txt')
    global log_file
    if args.resume:
        log_file = open(log_file_name, "at")
    else:
        log_file = open(log_file_name, "wt")

    # Build model
    if args.unigen:
        if args.dataset == 'mnist_stack':
            model = DCDecoder(args.latent_dim, 64, args.image_size, 3,
                              args.dec_dist)
            discriminator = DCDiscriminator(args.d_conv_dim, args.image_size)
        else:
            model = Generator(args.latent_dim, args.g_conv_dim,
                              args.image_size)
            discriminator = Discriminator(args.d_conv_dim, args.image_size)
        encoder_optimizer = None
        decoder_optimizer = optim.Adam(model.parameters(),
                                       lr=args.lr,
                                       betas=(args.beta1, args.beta2))
        D_optimizer = optim.Adam(discriminator.parameters(),
                                 lr=args.lr_d,
                                 betas=(args.beta1, args.beta2))

    else:
        if args.dataset == 'mog':
            model = ToyAE(data_dim=2,
                          latent_dim=args.latent_dim,
                          enc_hidden_dim=500,
                          dec_hidden_dim=500,
                          enc_dist=args.enc_dist,
                          dec_dist=args.dec_dist)
            discriminator = DiscriminatorMLP(data_dim=2,
                                             latent_dim=args.latent_dim,
                                             hidden_dim_x=400,
                                             hidden_dim_z=400,
                                             hidden_dim=400)
        elif args.dataset in ['mnist', 'mnist_stack']:
            image_channel = 3 if args.dataset == 'mnist_stack' else 1
            tanh = args.prior == 'uniform' and args.enc_dist == 'deterministic'
            model = DCAE(args.latent_dim, 64, args.image_size, image_channel,
                         args.enc_dist, args.dec_dist, tanh)
            discriminator = DCJointDiscriminator(args.latent_dim, 64,
                                                 args.image_size,
                                                 image_channel,
                                                 args.dis_fc_size)
        else:
            model = BGM(args.latent_dim, args.g_conv_dim, args.image_size, 3,
                        args.enc_dist, args.enc_arch, args.enc_fc_size,
                        args.enc_noise_dim, args.dec_dist)
            discriminator = BigJointDiscriminator(args.latent_dim,
                                                  args.d_conv_dim,
                                                  args.image_size,
                                                  args.dis_fc_size)
        encoder_optimizer = optim.Adam(model.encoder.parameters(),
                                       lr=args.lr,
                                       betas=(args.beta1, args.beta2))
        decoder_optimizer = optim.Adam(model.decoder.parameters(),
                                       lr=args.lr,
                                       betas=(args.beta1, args.beta2))
        D_optimizer = optim.Adam(discriminator.parameters(),
                                 lr=args.lr_d,
                                 betas=(args.beta1, args.beta2))

    # Load model from checkpoint
    if args.resume:
        if args.ckpt_dir != '':
            ckpt_dir = args.ckpt_dir
        else:
            ckpt_dir = save_dir + 'model' + str(args.start_epoch - 1) + '.sav'
        checkpoint = torch.load(ckpt_dir)
        model.load_state_dict(checkpoint['model'])
        discriminator.load_state_dict(checkpoint['discriminator'])
        del checkpoint

    model = nn.DataParallel(model.to(device))
    discriminator = nn.DataParallel(discriminator.to(device))

    # Fixed noise from prior p_z for generating from G
    global fixed_noise
    if args.prior == 'gaussian':
        fixed_noise = torch.randn(args.save_n_samples,
                                  args.latent_dim,
                                  device=device)
    else:
        fixed_noise = torch.rand(
            args.save_n_samples, args.latent_dim, device=device) * 2 - 1

    # Train
    for i in range(args.start_epoch, args.start_epoch + args.n_epochs):
        train_age(i, model, discriminator, encoder_optimizer,
                  decoder_optimizer, D_optimizer, train_loader,
                  args.print_every, save_dir, args.sample_every, test_loader)
        if i % args.save_model_every == 0:
            torch.save(
                {
                    'model': model.module.state_dict(),
                    'discriminator': discriminator.module.state_dict()
                }, save_dir + 'model' + str(i) + '.sav')
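
The checkpoints written above store plain state_dicts (saved from `.module`, so they load without a `DataParallel` wrapper). A minimal sketch of reloading one after the wrap; the keys mirror the `torch.save` call above, the epoch number is a placeholder, and `save_dir`, `device`, `model`, and `discriminator` refer to the objects built in `main()`:

# Sketch: reload a checkpoint produced by the training loop above.
resume_epoch = 10
ckpt = torch.load(save_dir + 'model' + str(resume_epoch) + '.sav',
                  map_location=device)
model.module.load_state_dict(ckpt['model'])
discriminator.module.load_state_dict(ckpt['discriminator'])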
Example #4
    # IMAGES DATALOADER
    train_loader, valid_loader = utils.make_dataloader(args)

    print(args)

    # OUT PATH
    if not os.path.exists(args.out_path):
        print("Making", args.out_path)
        os.makedirs(args.out_path)

    # Copy all scripts
    utils.copy_scripts(args.out_path)

    # Save all args
    utils.write_config_to_file(args, args.out_path)

    # MODEL

    torch.manual_seed(args.seed)
    if args.pth != '':
        pth_dir_name = os.path.dirname(args.pth)
        full_model_pth = os.path.join(pth_dir_name, 'model.pth')
        if os.path.exists(full_model_pth):
            print("Loading", full_model_pth)
            model = torch.load(full_model_pth)
            print("Loading pretrained state_dict", args.pth)
            model.load_state_dict(torch.load(args.pth))
            model = model.to(args.device)
        else:
            if args.model == 'baseline':