Example #1
def main():

    # Select dataset to use
    if FLAGS.dataset == 'dsprites':
        opts = configs.config_dsprites
    elif FLAGS.dataset == 'noisydsprites':
        opts = configs.config_noisydsprites
    elif FLAGS.dataset == 'screamdsprites':
        opts = configs.config_screamdsprites
    elif FLAGS.dataset == 'smallNORB':
        opts = configs.config_smallNORB
    elif FLAGS.dataset == '3dshapes':
        opts = configs.config_3dshapes
    elif FLAGS.dataset == '3Dchairs':
        opts = configs.config_3Dchairs
    elif FLAGS.dataset == 'celebA':
        opts = configs.config_celebA
    elif FLAGS.dataset == 'mnist':
        opts = configs.config_mnist
    else:
        assert False, 'Unknown dataset'

    # Set method params
    opts['data_dir'] = FLAGS.data_dir
    opts['fid'] = True
    opts['network'] = net_configs[FLAGS.net_archi]

    # Model set up
    opts['model'] = FLAGS.model
    if FLAGS.dataset == 'celebA':
        opts['zdim'] = 32
    elif FLAGS.dataset == '3Dchairs':
        opts['zdim'] = 16
    else:
        opts['zdim'] = 10

    # Create directories
    opts['out_dir'] = FLAGS.out_dir
    out_subdir = os.path.join(opts['out_dir'], opts['model'])
    opts['exp_dir'] = os.path.join(out_subdir, FLAGS.res_dir)
    if not tf.io.gfile.isdir(opts['exp_dir']):
        raise Exception("Experiment doesn't exist: %s" % opts['exp_dir'])

    # Reset tf graph
    tf.compat.v1.reset_default_graph()

    # Loading the dataset
    data = DataHandler(opts)
    assert data.train_size >= opts['batch_size'], 'Training set too small'

    # build model
    run = Run(opts, data)

    # get fid
    run.fid_score(opts['exp_dir'], FLAGS.weights_file, FLAGS.compute_stats, FLAGS.fid_inputs)
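
Note: the if/elif dataset dispatch that opens this example could equally be a
lookup table. A minimal sketch (the configs module and its attributes are
taken from the example above; the helper name select_config is hypothetical):

def select_config(dataset):
    """Return the config for `dataset`, mirroring the chain in main()."""
    configs_by_name = {
        'dsprites': configs.config_dsprites,
        'noisydsprites': configs.config_noisydsprites,
        'screamdsprites': configs.config_screamdsprites,
        'smallNORB': configs.config_smallNORB,
        '3dshapes': configs.config_3dshapes,
        '3Dchairs': configs.config_3Dchairs,
        'celebA': configs.config_celebA,
        'mnist': configs.config_mnist,
    }
    try:
        return configs_by_name[dataset]
    except KeyError:
        raise ValueError('Unknown dataset: %s' % dataset)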
Example #2
def main():

    # Select dataset to use
    if FLAGS.dataset == 'mnist':
        opts = configs.config_mnist
    elif FLAGS.dataset == 'smallNORB':
        opts = configs.config_smallNORB
    elif FLAGS.dataset == 'celebA':
        opts = configs.config_celebA
    else:
        assert False, 'Unknown dataset'

    # model
    opts['model'] = FLAGS.model
    opts['encoder'] = [FLAGS.encoder] * opts['nlatents']
    opts['archi'] = [FLAGS.net_archi] * opts['nlatents']
    opts['obs_cost'] = FLAGS.cost
    opts['lambda_schedule'] = FLAGS.lmba_schedule
    opts['enc_sigma_pen'] = FLAGS.enc_sigma_pen
    opts['dec_sigma_pen'] = FLAGS.dec_sigma_pen

    # lambda sweep
    lambda_rec = [0.01, 0.1]
    lambda_match = [0.0001, 0.001]
    schedule = ['constant']
    sigmoid = [False]
    grid = list(itertools.product(schedule, sigmoid, lambda_rec, lambda_match))
    # Map the 1-based job-array id onto a cell of the hyperparameter grid
    idx = (FLAGS.id - 1) % len(grid)
    sche, sig, lrec, lmatch = grid[idx]
    opts['lambda_schedule'] = sche
    opts['use_sigmoid'] = sig
    opts['lambda_init'] = [
        lrec * log(n + 1.0001) / opts['zdim'][n]
        for n in range(opts['nlatents'] - 1)
    ] + [lmatch / 100]
    opts['lambda'] = [
        lrec**(n + 1) / opts['zdim'][n] for n in range(opts['nlatents'] - 1)
    ] + [lmatch]
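    # Worked instance of the two schedules (assumed values: nlatents=3,
    # zdim=[32, 16, 8], lrec=0.01, lmatch=0.0001):
    #   lambda = [0.01**1 / 32, 0.01**2 / 16, 0.0001]
    #          = [3.125e-04, 6.25e-06, 1e-04]
    # i.e. the per-layer reconstruction weight decays as lrec**(n+1), scaled
    # by that layer's latent dimension, with the matching penalty last.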

    # Create directories
    results_dir = 'results'
    if not tf.io.gfile.isdir(results_dir):
        utils.create_dir(results_dir)
    opts['out_dir'] = os.path.join(results_dir, FLAGS.out_dir)
    if not tf.io.gfile.isdir(opts['out_dir']):
        utils.create_dir(opts['out_dir'])
    out_subdir = os.path.join(opts['out_dir'], opts['model'])
    if not tf.io.gfile.isdir(out_subdir):
        utils.create_dir(out_subdir)
    out_subdir = os.path.join(out_subdir, 'l' + sche + '_sig' + str(sig))
    if not tf.io.gfile.isdir(out_subdir):
        utils.create_dir(out_subdir)
    opts['exp_dir'] = FLAGS.res_dir
    if opts['model'] == 'stackedwae':
        exp_dir = os.path.join(
            out_subdir, '{}_{}layers_lrec{}_lmatch{}_{:%Y_%m_%d_%H_%M}'.format(
                opts['exp_dir'], opts['nlatents'], lrec, lmatch,
                datetime.now()))
    else:
        exp_dir = os.path.join(
            out_subdir,
            '{}_lmatch{}_{:%Y_%m_%d_%H_%M}'.format(opts['exp_dir'], lmatch,
                                                   datetime.now()))
    opts['exp_dir'] = exp_dir
    if not tf.io.gfile.isdir(exp_dir):
        utils.create_dir(exp_dir)
        utils.create_dir(os.path.join(exp_dir, 'checkpoints'))

    # getting weights path
    if FLAGS.weights_file is not None:
        WEIGHTS_PATH = os.path.join(opts['exp_dir'], 'checkpoints',
                                    FLAGS.weights_file)
    else:
        WEIGHTS_PATH = None

    # Logging set up
    logging.basicConfig(filename=os.path.join(exp_dir, 'outputs.log'),
                        level=logging.INFO,
                        format='%(asctime)s - %(message)s')

    # run set up
    opts['vizu_splitloss'] = FLAGS.losses
    opts['vizu_fullrec'] = FLAGS.reconstructions
    opts['vizu_embedded'] = FLAGS.embedded
    opts['vizu_latent'] = FLAGS.latents
    opts['vizu_pz_grid'] = FLAGS.grid
    opts['vizu_stochasticity'] = FLAGS.stoch
    opts['fid'] = FLAGS.fid
    opts['it_num'] = FLAGS.num_it
    # Print 4 times and evaluate 50 times over the course of training
    opts['print_every'] = int(opts['it_num'] / 4)
    opts['evaluate_every'] = int(opts['it_num'] / 50)
    if FLAGS.batch_size is not None:
        opts['batch_size'] = FLAGS.batch_size
    opts['lr'] = FLAGS.lr
    opts['use_trained'] = FLAGS.use_trained
    opts['save_every'] = 10_000_000_000  # effectively never checkpoint mid-run
    opts['save_final'] = FLAGS.save_model
    opts['save_train_data'] = FLAGS.save_data

    # Reset tf graph
    tf.compat.v1.reset_default_graph()

    # Loading the dataset
    opts['data_dir'] = FLAGS.data_dir
    data = DataHandler(opts)
    assert data.train_size >= opts['batch_size'], 'Training set too small'

    # build model
    run = Run(opts, data)

    # Training/testing/vizu
    if FLAGS.mode == "train":
        # Dumping all the configs to the text file
        with utils.o_gfile((opts['exp_dir'], 'params.txt'), 'w') as text:
            text.write('Parameters:\n')
            for key in opts:
                text.write('%s : %s\n' % (key, opts[key]))
        run.train(WEIGHTS_PATH)
    elif FLAGS.mode == "vizu":
        opts['rec_loss_nsamples'] = 1
        opts['sample_recons'] = False
        run.latent_interpolation(opts['exp_dir'], WEIGHTS_PATH)
    elif FLAGS.mode == "fid":
        run.fid_score(WEIGHTS_PATH)
    elif FLAGS.mode == "test":
        run.test_losses(WEIGHTS_PATH)
    elif FLAGS.mode == "vlae_exp":
        run.vlae_experiment(WEIGHTS_PATH)
    else:
        assert False, 'Unknown mode %s' % FLAGS.mode
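
Note: these examples reference a module-level FLAGS object that the snippets
never define. One plausible setup using argparse (flag names and mode choices
are taken from the code; defaults and help strings are hypothetical):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='mnist')
parser.add_argument('--model', default='stackedwae')
parser.add_argument('--mode', default='train',
                    choices=['train', 'vizu', 'fid', 'test', 'vlae_exp'])
parser.add_argument('--id', type=int, default=1,
                    help='1-based job-array id indexing the hyperparameter grid')
parser.add_argument('--num_it', type=int, default=100000)
parser.add_argument('--batch_size', type=int, default=None)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--weights_file', default=None)
FLAGS = parser.parse_args()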
Example #3
def main():

    # Select dataset to use
    if FLAGS.dataset == 'mnist':
        opts = configs.config_mnist
    elif FLAGS.dataset == 'smallNORB':
        opts = configs.config_smallNORB
    elif FLAGS.dataset == 'celebA':
        opts = configs.config_celebA
    else:
        assert False, 'Unknown dataset'

    # model
    opts['model'] = FLAGS.model
    opts['encoder'] = [FLAGS.encoder] * opts['nlatents']
    opts['use_sigmoid'] = FLAGS.sigmoid
    opts['archi'] = [FLAGS.net_archi] * opts['nlatents']
    opts['obs_cost'] = FLAGS.cost
    opts['lambda_schedule'] = FLAGS.lmba_schedule
    opts['enc_sigma_pen'] = FLAGS.enc_sigma_pen
    opts['dec_sigma_pen'] = FLAGS.dec_sigma_pen

    # beta sweep over the regularisation weight
    beta = [0.0001, 1.]
    idx = (FLAGS.id - 1) % len(beta)
    opts['lambda_init'] = [beta[idx]] * opts['nlatents']
    opts['lambda'] = [1.] * opts['nlatents']
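    # With two beta values, consecutive 1-based job-array ids alternate
    # between the settings: id=1 -> lambda_init = [0.0001] * nlatents,
    # id=2 -> [1.] * nlatents, id=3 wraps back to 0.0001, and so on.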

    # Create directories
    results_dir = 'results'
    if not tf.io.gfile.isdir(results_dir):
        utils.create_dir(results_dir)
    opts['out_dir'] = os.path.join(results_dir, FLAGS.out_dir)
    if not tf.io.gfile.isdir(opts['out_dir']):
        utils.create_dir(opts['out_dir'])
    out_subdir = os.path.join(opts['out_dir'], opts['model'])
    if not tf.io.gfile.isdir(out_subdir):
        utils.create_dir(out_subdir)
    opts['exp_dir'] = FLAGS.res_dir
    exp_dir = os.path.join(
        out_subdir,
        '{}_{}layers_lreg{}_{:%Y_%m_%d_%H_%M}'.format(opts['exp_dir'],
                                                      opts['nlatents'],
                                                      beta[idx],
                                                      datetime.now()))
    opts['exp_dir'] = exp_dir
    if not tf.io.gfile.isdir(exp_dir):
        utils.create_dir(exp_dir)
        utils.create_dir(os.path.join(exp_dir, 'checkpoints'))

    # getting weights path
    if FLAGS.weights_file is not None:
        WEIGHTS_PATH = os.path.join(opts['exp_dir'], 'checkpoints',
                                    FLAGS.weights_file)
    else:
        WEIGHTS_PATH = None

    # Logging set up
    logging.basicConfig(filename=os.path.join(exp_dir, 'outputs.log'),
                        level=logging.INFO,
                        format='%(asctime)s - %(message)s')

    # run set up
    opts['vizu_splitloss'] = FLAGS.losses
    opts['vizu_fullrec'] = FLAGS.reconstructions
    opts['vizu_embedded'] = FLAGS.embedded
    opts['vizu_latent'] = FLAGS.latents
    opts['fid'] = FLAGS.fid
    opts['it_num'] = FLAGS.num_it
    # Print 4 times and evaluate 50 times over the course of training
    opts['print_every'] = int(opts['it_num'] / 4)
    opts['evaluate_every'] = int(opts['it_num'] / 50)
    if FLAGS.batch_size is not None:
        opts['batch_size'] = FLAGS.batch_size
    opts['lr'] = FLAGS.lr
    opts['use_trained'] = FLAGS.use_trained
    opts['save_every'] = 10_000_000_000  # effectively never checkpoint mid-run
    opts['save_final'] = FLAGS.save_model
    opts['save_train_data'] = FLAGS.save_data

    # Reset tf graph
    tf.compat.v1.reset_default_graph()

    # Loading the dataset
    opts['data_dir'] = FLAGS.data_dir
    data = DataHandler(opts)
    assert data.train_size >= opts['batch_size'], 'Training set too small'

    # build model
    run = Run(opts, data)

    # Training/testing/vizu
    if FLAGS.mode == "train":
        # Dumping all the configs to the text file
        with utils.o_gfile((opts['exp_dir'], 'params.txt'), 'w') as text:
            text.write('Parameters:\n')
            for key in opts:
                text.write('%s : %s\n' % (key, opts[key]))
        run.train(WEIGHTS_PATH)
    elif FLAGS.mode == "vizu":
        opts['rec_loss_nsamples'] = 1
        opts['sample_recons'] = False
        run.latent_interpolation(opts['exp_dir'], WEIGHTS_PATH)
    elif FLAGS.mode == "fid":
        run.fid_score(WEIGHTS_PATH)
    elif FLAGS.mode == "test":
        run.test_losses(WEIGHTS_PATH)
    elif FLAGS.mode == "vlae_exp":
        run.vlae_experiment(WEIGHTS_PATH)
    else:
        assert False, 'Unknown mode %s' % FLAGS.mode
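
Note: Examples #2 and #3 both map a 1-based job-array id onto a hyperparameter
grid via modular indexing. A minimal, self-contained illustration using the
grid values from Example #2:

import itertools

schedule = ['constant']
sigmoid = [False]
lambda_rec = [0.01, 0.1]
lambda_match = [0.0001, 0.001]
grid = list(itertools.product(schedule, sigmoid, lambda_rec, lambda_match))

for job_id in range(1, 6):
    sche, sig, lrec, lmatch = grid[(job_id - 1) % len(grid)]
    print(job_id, sche, sig, lrec, lmatch)
# ids 1-4 cover the four (lrec, lmatch) cells; id 5 wraps around to the first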