예제 #1
0
def main():
    """Decode ten samples from the prior of a trained WAE and save them to img.png.

    Reads FLAGS.exp (only 'dir64' is supported), FLAGS.zdim (optional
    override of opts['zdim']) and FLAGS.checkpoint (restored into the model).

    Raises:
        ValueError: if FLAGS.exp is not a supported experiment.
    """
    # Explicit raise rather than `assert False`: asserts are stripped under
    # `python -O`, which would leave `opts` unbound below.
    if FLAGS.exp != 'dir64':
        raise ValueError('Unknown experiment configuration: %s' % FLAGS.exp)
    opts = configs.config_dir64

    if FLAGS.zdim is not None:
        opts['zdim'] = FLAGS.zdim

    # logging.basicConfig only configures the root logger on its first
    # effective call, so choose the level once.
    log_level = logging.DEBUG if opts['verbose'] else logging.INFO
    logging.basicConfig(level=log_level, format='%(asctime)s - %(message)s')

    wae = WAE(opts)
    wae.restore_checkpoint(FLAGS.checkpoint)

    # Draw 10 latent codes from the prior and decode them in inference mode.
    batch_noise = wae.sample_pz(10)
    sample_gen = wae.sess.run(wae.decoded,
                              feed_dict={
                                  wae.sample_noise: batch_noise,
                                  wae.is_training: False
                              })

    # Place the decoded images side by side, then rescale with (x + 1) / 2.
    # NOTE(review): assumes decoder outputs lie in [-1, 1] — confirm.
    img = np.hstack(sample_gen)
    img = (img + 1.0) / 2
    plt.imshow(img)
    plt.savefig('img.png')
예제 #2
0
파일: run.py 프로젝트: knok/wae
def main():
    """Train a WAE on the experiment selected by FLAGS.exp.

    The chosen config from `configs` is overridden by any FLAGS value that is
    not None. All parameters are written to <work_dir>/params.txt, a
    `checkpoints` subdirectory is created, and the model is trained on the
    data loaded by DataHandler.

    Raises:
        ValueError: if FLAGS.exp is unknown or the training set holds fewer
            points than one batch.
    """
    known_exps = ('celebA', 'celebA_small', 'mnist', 'mnist_small',
                  'dsprites', 'grassli', 'grassli_small', 'dir64')
    # Explicit raise: `assert` is stripped under `python -O`.
    if FLAGS.exp not in known_exps:
        raise ValueError('Unknown experiment configuration: %s' % FLAGS.exp)
    # Every experiment `foo` maps to `configs.config_foo`. The object is used
    # directly (not copied), so the overrides below modify it in place.
    opts = getattr(configs, 'config_' + FLAGS.exp)

    # (FLAGS attribute, opts key) pairs. A flag left at None keeps the
    # config's own value.
    overrides = (
        ('zdim', 'zdim'),
        ('lr', 'lr'),
        ('z_test', 'z_test'),
        ('lambda_schedule', 'lambda_schedule'),
        ('work_dir', 'work_dir'),
        ('wae_lambda', 'lambda'),
        ('enc_noise', 'e_noise'),
        ('epoch_num', 'epoch_num'),
    )
    for flag_name, opt_key in overrides:
        value = getattr(FLAGS, flag_name)
        if value is not None:
            opts[opt_key] = value

    # basicConfig only takes effect once, so pick the level up front.
    log_level = logging.DEBUG if opts['verbose'] else logging.INFO
    logging.basicConfig(level=log_level, format='%(asctime)s - %(message)s')
    utils.create_dir(opts['work_dir'])
    utils.create_dir(os.path.join(opts['work_dir'], 'checkpoints'))

    # Dump all the configs to a text file for reproducibility.
    with utils.o_gfile((opts['work_dir'], 'params.txt'), 'w') as text:
        text.write('Parameters:\n')
        for key in opts:
            text.write('%s : %s\n' % (key, opts[key]))

    # Load the dataset.
    data = DataHandler(opts)
    if data.num_points < opts['batch_size']:
        raise ValueError('Training set too small')

    # Train the WAE.
    wae = WAE(opts)
    wae.train(data)
예제 #3
0
def main():
    """Train a WAE on MNIST using configs.config_mnist.

    opts['mode'] is hard-wired to 'train' here, so the 'test' branch
    (improved_wae.improved_sampling) only runs if that assignment is changed.

    Raises:
        ValueError: if a Gaussian encoder is paired with a non-normal prior,
            or the training set holds fewer points than one batch.
    """
    opts = configs.config_mnist

    opts['mode'] = 'train'

    # basicConfig only takes effect once, so pick the level up front.
    log_level = logging.DEBUG if opts['verbose'] else logging.INFO
    logging.basicConfig(level=log_level, format='%(asctime)s - %(message)s')
    utils.create_dir(opts['work_dir'])
    utils.create_dir(os.path.join(opts['work_dir'], 'checkpoints'))

    # Explicit raise: an `assert` here is stripped under `python -O`, and the
    # `return` that used to follow it was unreachable otherwise.
    if opts['e_noise'] == 'gaussian' and opts['pz'] != 'normal':
        raise ValueError(
            'Gaussian encoders compatible only with Gaussian prior')

    # Dump all the configs to a text file for reproducibility.
    with utils.o_gfile((opts['work_dir'], 'params.txt'), 'w') as text:
        text.write('Parameters:\n')
        for key in opts:
            text.write('%s : %s\n' % (key, opts[key]))

    # Load the dataset.
    data = DataHandler(opts)
    if data.num_points < opts['batch_size']:
        raise ValueError('Training set too small')

    if opts['mode'] == 'train':
        # Create the WAE model; this variant also receives the dataset size.
        wae = WAE(opts, data.num_points)
        wae.train(data)

    elif opts['mode'] == 'test':
        improved_wae.improved_sampling(opts)
예제 #4
0
파일: analogy.py 프로젝트: knok/wae
def main():
    """Decode a straight-line walk between the latent codes of two images.

    Encodes the first two images of the dataset, takes 10 evenly spaced
    points on the segment between their codes (both endpoints included),
    decodes each point and saves the resulting strip to analogy.png.

    Raises:
        ValueError: if FLAGS.exp is not a supported experiment.
    """
    # Explicit raise rather than `assert False`: asserts are stripped under
    # `python -O`, which would leave `opts` unbound below.
    if FLAGS.exp != 'dir64':
        raise ValueError('Unknown experiment configuration: %s' % FLAGS.exp)
    opts = configs.config_dir64

    if FLAGS.zdim is not None:
        opts['zdim'] = FLAGS.zdim

    # basicConfig only takes effect once, so pick the level up front.
    log_level = logging.DEBUG if opts['verbose'] else logging.INFO
    logging.basicConfig(level=log_level, format='%(asctime)s - %(message)s')

    data = DataHandler(opts)
    wae = WAE(opts)
    wae.restore_checkpoint(FLAGS.checkpoint)

    # Encode the first two images of the dataset.
    batch_img = data.data[0:2]
    enc_vec = wae.sess.run(wae.encoded,
                           feed_dict={
                               wae.sample_points: batch_img,
                               wae.is_training: False
                           })

    # Weights 0, 1/9, ..., 1. Including both endpoints makes the last frame
    # decode the second image's own code rather than stopping 90% of the way.
    # NOTE(review): assumes each code enc_vec[k] is a 1-D vector — confirm.
    num_steps = 10
    weights = np.linspace(0.0, 1.0, num_steps,
                          dtype=np.float32)[:, np.newaxis]
    gen_vec = (enc_vec[0] + weights * (enc_vec[1] - enc_vec[0])).astype(
        np.float32)

    sample_gen = wae.sess.run(wae.decoded,
                              feed_dict={
                                  wae.sample_noise: gen_vec,
                                  wae.is_training: False
                              })
    # Place the frames side by side and rescale with (x + 1) / 2.
    # NOTE(review): assumes decoder outputs lie in [-1, 1] — confirm.
    img = np.hstack(sample_gen)
    img = (img + 1.0) / 2
    plt.imshow(img)
    plt.savefig('analogy.png')
예제 #5
0
def main():
    """Train a WAE, or run improved sampling, for the experiment in FLAGS.exp.

    FLAGS.mode selects 'train' (build a WAE and train it) or 'test' (call
    improved_wae.improved_sampling; requires FLAGS.checkpoint). FLAGS values
    that are not None override the chosen config, which is also dumped to
    <work_dir>/params.txt.

    Raises:
        ValueError: on an unknown experiment, a missing checkpoint in test
            mode, a Gaussian encoder with a non-normal prior, or a training
            set smaller than one batch.
    """
    known_exps = ('celebA', 'celebA_small', 'mnist', 'mnist_small',
                  'dsprites', 'grassli', 'grassli_small')
    # Explicit raises throughout: `assert` is stripped under `python -O`.
    if FLAGS.exp not in known_exps:
        raise ValueError('Unknown experiment configuration: %s' % FLAGS.exp)
    # Experiment `foo` maps to `configs.config_foo`, which is modified in place.
    opts = getattr(configs, 'config_' + FLAGS.exp)

    opts['mode'] = FLAGS.mode
    if opts['mode'] == 'test':
        if FLAGS.checkpoint is None:
            raise ValueError('Checkpoint must be provided')
        opts['checkpoint'] = FLAGS.checkpoint

    # (FLAGS attribute, opts key) pairs. None means "keep the config value".
    overrides = (
        ('zdim', 'zdim'),
        ('pz', 'pz'),
        ('lr', 'lr'),
        ('w_aef', 'w_aef'),
        ('z_test', 'z_test'),
        ('lambda_schedule', 'lambda_schedule'),
        ('work_dir', 'work_dir'),
        ('wae_lambda', 'lambda'),
        ('enc_noise', 'e_noise'),
    )
    for flag_name, opt_key in overrides:
        value = getattr(FLAGS, flag_name)
        if value is not None:
            opts[opt_key] = value

    # basicConfig only takes effect once, so pick the level up front.
    log_level = logging.DEBUG if opts['verbose'] else logging.INFO
    logging.basicConfig(level=log_level, format='%(asctime)s - %(message)s')
    utils.create_dir(opts['work_dir'])
    utils.create_dir(os.path.join(opts['work_dir'], 'checkpoints'))

    if opts['e_noise'] == 'gaussian' and opts['pz'] != 'normal':
        raise ValueError(
            'Gaussian encoders compatible only with Gaussian prior')

    # Dump all the configs to a text file for reproducibility.
    with utils.o_gfile((opts['work_dir'], 'params.txt'), 'w') as text:
        text.write('Parameters:\n')
        for key in opts:
            text.write('%s : %s\n' % (key, opts[key]))

    # Load the dataset.
    data = DataHandler(opts)
    if data.num_points < opts['batch_size']:
        raise ValueError('Training set too small')

    if opts['mode'] == 'train':
        wae = WAE(opts)
        wae.train(data)

    elif opts['mode'] == 'test':
        improved_wae.improved_sampling(opts)
예제 #6
0
파일: run.py 프로젝트: csadrian/wae
def main():
    """Configure and run a (Sinkhorn) WAE experiment.

    Builds opts from configs.config_<FLAGS.exp> plus FLAGS overrides, writes
    them to <work_dir>/params.txt, loads the data, then dispatches on
    FLAGS.mode:
      * 'train'    -- train a WAE (logged to Neptune when NEPTUNE_API_TOKEN
                      is set in the environment);
      * 'test'     -- improved_wae.improved_sampling(opts);
      * 'generate' -- fideval.generate(opts);
      * 'draw'     -- picture_plot.createimgs(opts).

    Raises:
        ValueError: on an unknown experiment, a missing checkpoint in test
            mode, recalculate_size < batch_size, mover_ratio outside [0, 1],
            a Gaussian encoder with a non-normal prior, or a training set
            smaller than one batch.
    """
    known_exps = ('celebA', 'celebA_small', 'mnist', 'mnist_ord',
                  'mnist_small', 'dsprites', 'grassli', 'grassli_small',
                  'syn_constant_uniform', 'syn_2_constant_uniform',
                  'checkers', 'noise', 'noise_unif')
    # Explicit raises throughout: `assert` is stripped under `python -O`.
    if FLAGS.exp not in known_exps:
        raise ValueError('Unknown experiment configuration: %s' % FLAGS.exp)
    # Experiment `foo` maps to `configs.config_foo`, which is modified in place.
    opts = getattr(configs, 'config_' + FLAGS.exp)

    def apply_overrides(pairs):
        """Copy each FLAGS attribute that is not None into opts[key]."""
        for flag_name, opt_key in pairs:
            value = getattr(FLAGS, flag_name)
            if value is not None:
                opts[opt_key] = value

    opts['exp'] = FLAGS.exp
    opts['seed'] = FLAGS.seed

    opts['mode'] = FLAGS.mode
    if opts['mode'] == 'test':
        if FLAGS.checkpoint is None:
            raise ValueError('Checkpoint must be provided')
        opts['checkpoint'] = FLAGS.checkpoint

    if FLAGS.batch_size is not None:
        opts['batch_size'] = FLAGS.batch_size

    # recalculate_size defaults to batch_size and may not be smaller than it.
    # It is set only here; re-assigning it from FLAGS later would replace
    # this default with None.
    if FLAGS.recalculate_size is not None:
        opts['recalculate_size'] = FLAGS.recalculate_size
        if opts['recalculate_size'] < opts['batch_size']:
            raise ValueError(
                "recalculate_size should be at least as large as batch_size")
    else:
        opts['recalculate_size'] = opts['batch_size']

    apply_overrides((
        ('zdim', 'zdim'),
        ('pz', 'pz'),
        ('lr', 'lr'),
        ('lr_schedule', 'lr_schedule'),
        ('w_aef', 'w_aef'),
        ('z_test', 'z_test'),
        ('lambda_schedule', 'lambda_schedule'),
        ('work_dir', 'work_dir'),
        ('wae_lambda', 'lambda'),
        ('enc_noise', 'e_noise'),
        ('z_test_scope', 'z_test_scope'),
        ('length_lambda', 'length_lambda'),
    ))

    # Unlike the overrides, grad_clip is always written (None disables it).
    opts['grad_clip'] = FLAGS.grad_clip

    apply_overrides((
        ('rec_lambda', 'rec_lambda'),
        ('zxz_lambda', 'zxz_lambda'),
        ('train_size', 'train_size'),
    ))

    # nat_size falls back to the train_size flag; -1 is resolved to the
    # effective train size once the data is loaded (see below).
    if FLAGS.nat_size is not None:
        opts['nat_size'] = FLAGS.nat_size
    else:
        opts['nat_size'] = FLAGS.train_size

    # Flags copied verbatim, including None.
    for name in ('nat_resampling', 'sinkhorn_sparse', 'sinkhorn_sparsifier',
                 'sparsifier_freq', 'sinkhorn_unbiased',
                 'feed_by_score_from_epoch', 'stay_lambda', 'mover_ratio'):
        opts[name] = getattr(FLAGS, name)

    if not 0 <= opts['mover_ratio'] <= 1:
        raise ValueError("mover_ratio must be in [0,1]")

    apply_overrides((
        ('sinkhorn_iters', 'sinkhorn_iters'),
        ('sinkhorn_epsilon', 'sinkhorn_epsilon'),
        ('name', 'name'),
        ('tags', 'tags'),
        ('epoch_num', 'epoch_num'),
        ('e_pretrain', 'e_pretrain'),
        ('shuffle', 'shuffle'),
    ))

    # DEBUG logging is deliberately not enabled here, even with
    # opts['verbose'] set.
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
    utils.create_dir(opts['work_dir'])
    utils.create_dir(os.path.join(opts['work_dir'], 'checkpoints'))

    if opts['e_noise'] == 'gaussian' and opts['pz'] != 'normal':
        raise ValueError(
            'Gaussian encoders compatible only with Gaussian prior')

    # Dump all the configs to a text file for reproducibility.
    with utils.o_gfile((opts['work_dir'], 'params.txt'), 'w') as text:
        text.write('Parameters:\n')
        for key in opts:
            text.write('%s : %s\n' % (key, opts[key]))

    # Load the dataset.
    data = DataHandler(opts)
    if data.num_points < opts['batch_size']:
        raise ValueError('Training set too small')

    # An explicit train_size caps the data used; otherwise use all of it.
    train_size = opts.get('train_size')
    if train_size is None:
        train_size = data.num_points
    print("Train size:", train_size)

    if opts['nat_size'] == -1:
        opts['nat_size'] = train_size

    use_neptune = "NEPTUNE_API_TOKEN" in os.environ

    if opts['mode'] == 'train':
        neptune_exp = None
        try:
            if use_neptune:
                neptune.init(project_qualified_name="csadrian/global-sinkhorn")
                neptune_exp = neptune.create_experiment(
                    params=opts,
                    name=opts['name'],
                    upload_source_files=['wae.py', 'run.py', 'models.py'])

                for tag in opts['tags'].split(','):
                    neptune.append_tag(tag)

            # Create the WAE model and restrict the handler to train_size.
            wae = WAE(opts, train_size)
            data.num_points = train_size

            wae.train(data)
        finally:
            # Close the Neptune experiment even when training raises, so the
            # run is not left open.
            if neptune_exp is not None:
                neptune_exp.stop()

    elif opts['mode'] == 'test':
        improved_wae.improved_sampling(opts)

    elif opts['mode'] == 'generate':
        fideval.generate(opts)

    elif opts['mode'] == 'draw':
        picture_plot.createimgs(opts)
예제 #7
0
파일: run.py 프로젝트: paruby/wae
def main():
    """Build a WAE for FLAGS.exp and train it or run its test routine.

    FLAGS values that are not None override the chosen config. When
    FLAGS.smart_cost is True, opts['cost'] is rebuilt as a list of
    (cost name, weight) entries from the *_w flags. The final config is
    dumped to <work_dir>/params.txt.

    Raises:
        ValueError: on an unknown experiment, a missing checkpoint in test
            mode, a Gaussian encoder with a non-normal prior, or a training
            set smaller than one batch.
    """
    known_exps = ('celebA', 'celebA_small', 'celebA_ae_patch_var',
                  'celebA_sylvain_adv', 'celebA_adv', 'mnist', 'mnist_small',
                  'dsprites', 'grassli', 'grassli_small', 'cifar')
    # Explicit raises throughout: `assert` is stripped under `python -O`.
    if FLAGS.exp not in known_exps:
        raise ValueError('Unknown experiment configuration: %s' % FLAGS.exp)
    # Experiment `foo` maps to `configs.config_foo`, which is modified in place.
    opts = getattr(configs, 'config_' + FLAGS.exp)

    opts['mode'] = FLAGS.mode
    if opts['mode'] == 'test':
        if FLAGS.checkpoint is None:
            raise ValueError('Checkpoint must be provided')
        opts['checkpoint'] = FLAGS.checkpoint

    # (FLAGS attribute, opts key) pairs. None means "keep the config value".
    overrides = (
        ('zdim', 'zdim'),
        ('pz', 'pz'),
        ('lr', 'lr'),
        ('w_aef', 'w_aef'),
        ('z_test', 'z_test'),
        ('lambda_schedule', 'lambda_schedule'),
        ('work_dir', 'work_dir'),
        ('wae_lambda', 'lambda'),
        ('celebA_crop', 'celebA_crop'),
        ('enc_noise', 'e_noise'),
        ('e_num_filters', 'e_num_filters'),
        ('g_num_filters', 'g_num_filters'),
    )
    for flag_name, opt_key in overrides:
        value = getattr(FLAGS, flag_name)
        if value is not None:
            opts[opt_key] = value

    if FLAGS.smart_cost is True:
        opts['cost'] = []
        if FLAGS.patch_var_w is not None:
            opts['cost'].append(('patch_variances', FLAGS.patch_var_w))
        if FLAGS.l2sq_w is not None:
            opts['cost'].append(('l2sq', FLAGS.l2sq_w))
        # The adversarial reconstruction cost needs both of its weights; when
        # it is used, the cross_p_w and diag_p_w entries are zeroed.
        if (FLAGS.sylvain_adv_c_w is not None
                and FLAGS.sylvain_emb_c_w is not None):
            adv_c_w = FLAGS.sylvain_adv_c_w
            emb_c_w = FLAGS.sylvain_emb_c_w
            opts['cost'].append(
                ('_sylvain_recon_loss_using_disc_conv', [adv_c_w, emb_c_w]))
            opts['cross_p_w'] = 0
            opts['diag_p_w'] = 0
        if FLAGS.adv_c_num_units is not None:
            opts['adv_c_num_units'] = FLAGS.adv_c_num_units
        if FLAGS.adv_c_patches_size is not None:
            opts['adv_c_patches_size'] = FLAGS.adv_c_patches_size
        if FLAGS.adv_use_sq is not None:
            opts['adv_use_sq'] = FLAGS.adv_use_sq

    # basicConfig only takes effect once, so pick the level up front.
    log_level = logging.DEBUG if opts['verbose'] else logging.INFO
    logging.basicConfig(level=log_level, format='%(asctime)s - %(message)s')
    utils.create_dir(opts['work_dir'])
    utils.create_dir(os.path.join(opts['work_dir'], 'checkpoints'))

    if opts['e_noise'] == 'gaussian' and opts['pz'] != 'normal':
        raise ValueError(
            'Gaussian encoders compatible only with Gaussian prior')

    # Dump all the configs to a text file for reproducibility.
    with utils.o_gfile((opts['work_dir'], 'params.txt'), 'w') as text:
        text.write('Parameters:\n')
        for key in opts:
            text.write('%s : %s\n' % (key, opts[key]))

    # Load the dataset.
    data = DataHandler(opts)
    if data.num_points < opts['batch_size']:
        raise ValueError('Training set too small')

    # The model is built for every mode; the mode only picks what runs next.
    wae = WAE(opts)
    if opts['mode'] == 'train':
        wae.train(data)

    elif opts['mode'] == 'test':
        wae.test()
예제 #8
0
def main():
    """Train a WAE sequentially over the tasks of a continual-learning dataset.

    Each task index i in range(len(data)) is trained with wae.train and then
    evaluated with wae.test. Task 0 uses zero weights for the trans/reg/f
    losses. Later tasks enable them, with rec and trans weights scaled by
    opts['task_num']. The "former" values returned by wae.former() after
    each task are passed into the training of the next one.
    """
    # Other dataset choices: 'shuffle_mnist', or 'cifar_10' together with
    # configs.config_cifar.
    dataset = 'disjoint_mnist'
    opts = configs.config_mnist
    task_num = opts['task_num']
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

    data = setdata.set_data(dataset, task_num)
    # Informational output is logged at INFO, not ERROR.
    logging.info(opts)

    session_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=session_config) as sess:
        wae = WAE(opts, sess)
        sess.run(wae.init)
        trainable_vars = tf.trainable_variables()
        logging.info(trainable_vars)
        logging.info(len(trainable_vars))

        # Initial "former" parameters (the author describes them as zeros;
        # see WAE.former_init).
        (random_z, pseudo_G_z, pseudo_T_z, w_old, b_old, f_g_z,
         pred) = wae.former_init()

        for i in range(len(data)):
            if i == 0:
                # First task: no transfer or regularisation loss.
                lambda_list = {
                    'wae_lambda': 1,
                    'rec_lambda': 0.1,
                    'trans_lambda': 0.0,
                    'reg_lambda': 0.0,
                    'f_lambda': 0.0,
                    'main_lambda': 1.0
                }
            else:
                lambda_list = {
                    'wae_lambda': 1.0,
                    'rec_lambda': 0.5 * opts['task_num'],
                    'trans_lambda': 0.01 * opts['task_num'],
                    'reg_lambda': 0.01,
                    'f_lambda': 1.0,
                    'main_lambda': 1.0
                }
            logging.info("task " + str(i) + ":")
            logging.info(lambda_list)
            wae.train(data, i, random_z, pseudo_G_z, pseudo_T_z, w_old,
                      b_old, f_g_z, pred, lambda_list)
            wae.test(data, True)

            # Save this task's parameters for use by the next task.
            (random_z, pseudo_G_z, pseudo_T_z, w_old, b_old, f_g_z,
             pred) = wae.former()
            if i == 0:
                # Type/shape diagnostics, visible only with DEBUG logging.
                logging.debug('%s', type(random_z))
                logging.debug('%s', type(pseudo_G_z))
                logging.debug('%s', type(pseudo_T_z[-1]))
                logging.debug('%s', pseudo_G_z.shape)
예제 #9
0
파일: run.py 프로젝트: benoitgaujac/ss_swae
def main():
    """Train, test, regress or visualise a WAE, as selected by FLAGS.mode.

    Outputs go under <opts['method']>/<opts['work_dir']>, which gets a
    `checkpoints` subdirectory and a params.txt dump of the config.

    Raises:
        ValueError: on an unknown FLAGS.exp or FLAGS.mode, or a training set
            smaller than one batch.
    """
    known_exps = ('celebA', 'celebA_small', 'mnist', 'mnist_small',
                  'dsprites', 'grassli', 'grassli_small')
    # Explicit raises throughout: `assert` is stripped under `python -O`.
    if FLAGS.exp not in known_exps:
        raise ValueError('Unknown experiment configuration: %s' % FLAGS.exp)
    # Experiment `foo` maps to `configs.config_foo`, which is modified in place.
    opts = getattr(configs, 'config_' + FLAGS.exp)

    # Check the mode before any directories are created or data is loaded,
    # so a typo fails fast instead of silently doing nothing.
    valid_modes = ('train', 'test', 'reg', 'vizu')
    if FLAGS.mode not in valid_modes:
        raise ValueError('Unknown mode: %s' % FLAGS.mode)

    # Select the training method.
    if FLAGS.method:
        opts['method'] = FLAGS.method

    # Working directory.
    if FLAGS.work_dir:
        opts['work_dir'] = FLAGS.work_dir

    # basicConfig only takes effect once, so pick the level up front.
    log_level = logging.DEBUG if opts['verbose'] else logging.INFO
    logging.basicConfig(level=log_level, format='%(asctime)s - %(message)s')

    # Create directories.
    utils.create_dir(opts['method'])
    work_dir = os.path.join(opts['method'], opts['work_dir'])
    utils.create_dir(work_dir)
    utils.create_dir(os.path.join(work_dir, 'checkpoints'))

    # Dump all the configs to a text file for reproducibility.
    with utils.o_gfile((work_dir, 'params.txt'), 'w') as text:
        text.write('Parameters:\n')
        for key in opts:
            text.write('%s : %s\n' % (key, opts[key]))

    # Load the dataset.
    data = DataHandler(opts)
    if data.num_points < opts['batch_size']:
        raise ValueError('Training set too small')

    # Start from a clean TF graph before building the model.
    tf.reset_default_graph()

    wae = WAE(opts)

    # Each mode name is also the name of the WAE method that implements it.
    # NOTE(review): these methods receive opts['work_dir'], not the
    # method-prefixed work_dir built above — confirm WAE adds the prefix itself.
    run_mode = getattr(wae, FLAGS.mode)
    run_mode(data, opts['work_dir'], FLAGS.weights_file)
예제 #10
0
파일: evaluate.py 프로젝트: dntai/PAE
    if DATASET == 'grid':
        data_dist = Grid()
    elif DATASET == 'low_dim_embed':
        data_dist = LowDimEmbed()
    elif DATASET == 'color_mnist':
        data_dist = CMNIST(os.path.join('data', 'mnist'))
    elif DATASET == 'cifar_100':
        data_dist = CIFAR100(os.path.join('data', 'cifar-100'))
        NOISE_PERTURB = 0.005
    if METHOD == 'pae':
        from pae import PAE
        model = PAE(data_dist, noise_dist, flags, args)
    elif METHOD == 'wae':
        from wae import WAE
        model = WAE(data_dist, noise_dist, flags, args)
    elif METHOD == 'cougan':
        noise_dist = NormalNoise(flags.NOISE_DIM)
        from cougan import CoulombGAN
        model = CoulombGAN(data_dist, noise_dist, flags, args)
    elif METHOD == 'bigan':
        from bigan import BiGAN
        model = BiGAN(data_dist, noise_dist, flags, args)
    elif METHOD == 'veegan':
        if DATASET == 'grid':
            flags.NOISE_DIM = 254  # refer to author's implementation
            flags.GEN_ARCH[0] = flags.NOISE_DIM
            flags.ENC_ARCH[-1] = flags.NOISE_DIM
            flags.DISC_ARCH[0] = flags.NOISE_DIM + flags.DATA_DIM
            #flags.DISC_ARCH = [flags.NOISE_DIM+flags.DATA_DIM, 128, 1]
        elif DATASET == 'low_dim_embed':