コード例 #1
0
ファイル: run_fid.py プロジェクト: benoitgaujac/diswae
def main():
    """Compute the FID score of an existing experiment.

    The dataset config is selected from ``FLAGS.dataset`` and completed with
    the data directory, network architecture, model name and latent size.
    The experiment directory ``<out_dir>/<model>/<res_dir>`` must already
    exist. The data is then loaded, a ``Run`` is built, and
    ``run.fid_score`` is called with the weights file and FID options from
    FLAGS.

    Raises:
        ValueError: unknown ``FLAGS.dataset`` or training set smaller than
            one batch.
        Exception: the experiment directory does not exist.
    """
    # Map dataset flag -> attribute name in `configs`. Attributes are looked
    # up lazily so that only the selected config needs to exist.
    dataset_configs = {
        'dsprites': 'config_dsprites',
        'noisydsprites': 'config_noisydsprites',
        'screamdsprites': 'config_screamdsprites',
        'smallNORB': 'config_smallNORB',
        '3dshapes': 'config_3dshapes',
        '3Dchairs': 'config_3Dchairs',
        'celebA': 'config_celebA',
        'mnist': 'config_mnist',
    }
    # Explicit raise instead of `assert False`: asserts vanish under -O and
    # would leave `opts` unbound.
    if FLAGS.dataset not in dataset_configs:
        raise ValueError('Unknown dataset {}'.format(FLAGS.dataset))
    opts = getattr(configs, dataset_configs[FLAGS.dataset])

    # Set method param
    opts['data_dir'] = FLAGS.data_dir
    opts['fid'] = True
    opts['network'] = net_configs[FLAGS.net_archi]

    # Model set up; latent size depends on the dataset (10 by default).
    opts['model'] = FLAGS.model
    opts['zdim'] = {'celebA': 32, '3Dchairs': 16}.get(FLAGS.dataset, 10)

    # Locate the (already trained) experiment directory.
    opts['out_dir'] = FLAGS.out_dir
    out_subdir = os.path.join(opts['out_dir'], opts['model'])
    opts['exp_dir'] = os.path.join(out_subdir, FLAGS.res_dir)
    if not tf.io.gfile.isdir(opts['exp_dir']):
        raise Exception("Experiment doesn't exist!")

    # Reset tf graph
    tf.reset_default_graph()

    # Loading the dataset
    data = DataHandler(opts)
    if data.train_size < opts['batch_size']:
        raise ValueError('Training set too small')

    # init method
    run = Run(opts, data)

    # get fid
    run.fid_score(opts['exp_dir'], FLAGS.weights_file, FLAGS.compute_stats,
                  FLAGS.fid_inputs)
コード例 #2
0
ファイル: run.py プロジェクト: benoitgaujac/SWAE
def main():
    """Set up and run a training experiment for the dataset chosen by FLAGS.

    ``opts`` is built from the dataset config and command-line flags. A
    timestamped experiment directory is created under
    ``results/<out_dir>/<model>/``, with a ``checkpoints`` sub-directory.
    Logging goes to ``outputs.log`` in that directory. After the data is
    loaded, the option dump is written to ``params.txt`` and training runs.

    Raises:
        ValueError: unknown ``FLAGS.dataset`` / ``FLAGS.mode``, or training
            set smaller than one batch.
        Exception: ``FLAGS.data_dir`` is empty.
    """
    # dataset config (attribute looked up lazily, only for the chosen dataset)
    dataset_configs = {'mnist': 'config_mnist', 'svhn': 'config_svhn'}
    if FLAGS.dataset not in dataset_configs:
        raise ValueError('Unknown dataset {}'.format(FLAGS.dataset))
    opts = getattr(configs, dataset_configs[FLAGS.dataset])
    if FLAGS.data_dir:
        opts['data_dir'] = FLAGS.data_dir
    else:
        raise Exception('You must provide a data_dir')

    # Only training is supported. Fail before creating directories or
    # loading data, and use an explicit raise (asserts are stripped by -O).
    if FLAGS.mode != "train":
        raise ValueError('Unknown mode %s' % FLAGS.mode)

    # Model set up
    opts['model'] = FLAGS.model
    opts['cost'] = FLAGS.cost
    opts['beta'] = FLAGS.beta
    opts['decoder'] = FLAGS.decoder
    opts['net_archi'] = FLAGS.net_archi

    # Create directories: results/<out_dir>/<model>/<res_dir>_<beta>_<timestamp>
    results_dir = 'results'
    if not tf.io.gfile.isdir(results_dir):
        utils.create_dir(results_dir)
    opts['out_dir'] = os.path.join(results_dir, FLAGS.out_dir)
    if not tf.io.gfile.isdir(opts['out_dir']):
        utils.create_dir(opts['out_dir'])
    out_subdir = os.path.join(opts['out_dir'], opts['model'])
    if not tf.io.gfile.isdir(out_subdir):
        utils.create_dir(out_subdir)
    exp_dir = os.path.join(
        out_subdir,
        '{}_{}_{:%Y_%m_%d_%H_%M}'.format(FLAGS.res_dir, opts['beta'],
                                         datetime.now()))
    opts['exp_dir'] = exp_dir
    if not tf.io.gfile.isdir(exp_dir):
        utils.create_dir(exp_dir)
        utils.create_dir(os.path.join(exp_dir, 'checkpoints'))

    # Verbose
    logging.basicConfig(filename=os.path.join(opts['exp_dir'], 'outputs.log'),
                        level=logging.INFO,
                        format='%(asctime)s - %(message)s')

    # Experiments set up
    opts['lr'] = FLAGS.lr
    opts['it_num'] = FLAGS.num_it
    opts['print_every'] = int(opts['it_num'] / 2.)
    opts['evaluate_every'] = int(opts['it_num'] / 4.)
    # Effectively "never" save intermediate checkpoints.
    opts['save_every'] = 10000000000
    opts['save_final'] = FLAGS.save_model
    opts['save_train_data'] = FLAGS.save_data

    # Reset tf graph
    tf.reset_default_graph()

    # Loading the dataset
    data = DataHandler(opts)
    if data.train_size < opts['batch_size']:
        raise ValueError('Training set too small')

    # init method
    run = Run(opts, data)

    # Training (mode validated above). Dump all the configs to a text file.
    with utils.o_gfile((exp_dir, 'params.txt'), 'w') as text:
        text.write('Parameters:\n')
        for key in opts:
            text.write('%s : %s\n' % (key, opts[key]))
    run.train()
コード例 #3
0
ファイル: run_wae.py プロジェクト: benoitgaujac/tdwae
def main():
    """Train or evaluate a (stacked) WAE on the dataset chosen by FLAGS.

    The regularisation weights are picked from a small hyper-parameter grid
    indexed by the 1-based ``FLAGS.id``, which wraps around modulo the grid
    size. A timestamped experiment directory is created under
    ``results/<out_dir>/<model>/l<schedule>_sig<sigmoid>/``. ``FLAGS.mode``
    is dispatched to training, latent interpolation, FID, test losses or
    the VLAE experiment.

    Raises:
        ValueError: unknown ``FLAGS.dataset`` / ``FLAGS.mode``, or training
            set smaller than one batch.
    """
    # Select dataset to use (config attribute looked up lazily).
    dataset_configs = {
        'mnist': 'config_mnist',
        'smallNORB': 'config_smallNORB',
        'celebA': 'config_celebA',
    }
    if FLAGS.dataset not in dataset_configs:
        raise ValueError('Unknown dataset {}'.format(FLAGS.dataset))
    opts = getattr(configs, dataset_configs[FLAGS.dataset])

    # Fail fast on an unsupported mode, before any directory is created.
    if FLAGS.mode not in ('train', 'vizu', 'fid', 'test', 'vlae_exp'):
        raise ValueError('Unknown mode %s' % FLAGS.mode)

    # model
    nlatents = opts['nlatents']
    opts['model'] = FLAGS.model
    opts['encoder'] = [FLAGS.encoder] * nlatents
    opts['archi'] = [FLAGS.net_archi] * nlatents
    opts['obs_cost'] = FLAGS.cost
    # NOTE: overwritten below by the hyper-parameter grid, so
    # FLAGS.lmba_schedule currently has no effect.
    opts['lambda_schedule'] = FLAGS.lmba_schedule
    opts['enc_sigma_pen'] = FLAGS.enc_sigma_pen
    opts['dec_sigma_pen'] = FLAGS.dec_sigma_pen

    # lambda grid: (schedule, sigmoid, lambda_rec, lambda_match)
    lambda_rec = [0.01, 0.1]
    lambda_match = [0.0001, 0.001]
    schedule = ['constant']
    sigmoid = [False]
    grid = list(itertools.product(schedule, sigmoid, lambda_rec, lambda_match))
    coef_id = (FLAGS.id - 1) % len(grid)
    sche, sig, lrec, lmatch = grid[coef_id]
    opts['lambda_schedule'] = sche
    opts['use_sigmoid'] = sig
    # One weight per intermediate latent layer, plus the matching weight
    # for the last layer.
    opts['lambda_init'] = [
        lrec * log(n + 1.0001) / opts['zdim'][n] for n in range(nlatents - 1)
    ] + [lmatch / 100]
    opts['lambda'] = [
        lrec**(n + 1) / opts['zdim'][n] for n in range(nlatents - 1)
    ] + [lmatch]

    # Create directories
    results_dir = 'results'
    if not tf.io.gfile.isdir(results_dir):
        utils.create_dir(results_dir)
    opts['out_dir'] = os.path.join(results_dir, FLAGS.out_dir)
    if not tf.io.gfile.isdir(opts['out_dir']):
        utils.create_dir(opts['out_dir'])
    out_subdir = os.path.join(opts['out_dir'], opts['model'])
    if not tf.io.gfile.isdir(out_subdir):
        utils.create_dir(out_subdir)
    out_subdir = os.path.join(out_subdir, 'l' + sche + '_sig' + str(sig))
    if not tf.io.gfile.isdir(out_subdir):
        utils.create_dir(out_subdir)
    if opts['model'] == 'stackedwae':
        exp_name = '{}_{}layers_lrec{}_lmatch{}_{:%Y_%m_%d_%H_%M}'.format(
            FLAGS.res_dir, nlatents, lrec, lmatch, datetime.now())
    else:
        exp_name = '{}_lmatch{}_{:%Y_%m_%d_%H_%M}'.format(
            FLAGS.res_dir, lmatch, datetime.now())
    exp_dir = os.path.join(out_subdir, exp_name)
    opts['exp_dir'] = exp_dir
    if not tf.io.gfile.isdir(exp_dir):
        utils.create_dir(exp_dir)
        utils.create_dir(os.path.join(exp_dir, 'checkpoints'))

    # getting weights path
    # NOTE(review): exp_dir is a freshly created timestamped directory, so a
    # relative weights_file points into an empty checkpoints folder;
    # presumably an absolute path is expected here (os.path.join then
    # discards the prefix) -- confirm.
    if FLAGS.weights_file is not None:
        WEIGHTS_PATH = os.path.join(exp_dir, 'checkpoints', FLAGS.weights_file)
    else:
        WEIGHTS_PATH = None

    # Verbose
    logging.basicConfig(filename=os.path.join(exp_dir, 'outputs.log'),
                        level=logging.INFO,
                        format='%(asctime)s - %(message)s')

    # run set up
    opts['vizu_splitloss'] = FLAGS.losses
    opts['vizu_fullrec'] = FLAGS.reconstructions
    opts['vizu_embedded'] = FLAGS.embedded
    opts['vizu_latent'] = FLAGS.latents
    opts['vizu_pz_grid'] = FLAGS.grid
    opts['vizu_stochasticity'] = FLAGS.stoch
    opts['fid'] = FLAGS.fid
    opts['it_num'] = FLAGS.num_it
    opts['print_every'] = int(opts['it_num'] / 4)
    opts['evaluate_every'] = int(opts['it_num'] / 50)
    if FLAGS.batch_size is not None:
        opts['batch_size'] = FLAGS.batch_size
    opts['lr'] = FLAGS.lr
    opts['use_trained'] = FLAGS.use_trained
    # Effectively "never" save intermediate checkpoints.
    opts['save_every'] = 10000000000
    opts['save_final'] = FLAGS.save_model
    opts['save_train_data'] = FLAGS.save_data

    # Reset tf graph
    tf.compat.v1.reset_default_graph()

    # Loading the dataset
    opts['data_dir'] = FLAGS.data_dir
    data = DataHandler(opts)
    if data.train_size < opts['batch_size']:
        raise ValueError('Training set too small')

    # build model
    run = Run(opts, data)

    # Training/testing/vizu (mode validated at the top)
    if FLAGS.mode == "train":
        # Dumping all the configs to the text file
        with utils.o_gfile((opts['exp_dir'], 'params.txt'), 'w') as text:
            text.write('Parameters:\n')
            for key in opts:
                text.write('%s : %s\n' % (key, opts[key]))
        run.train(WEIGHTS_PATH)
    elif FLAGS.mode == "vizu":
        opts['rec_loss_nsamples'] = 1
        opts['sample_recons'] = False
        run.latent_interpolation(opts['exp_dir'], WEIGHTS_PATH)
    elif FLAGS.mode == "fid":
        run.fid_score(WEIGHTS_PATH)
    elif FLAGS.mode == "test":
        run.test_losses(WEIGHTS_PATH)
    else:  # "vlae_exp"
        run.vlae_experiment(WEIGHTS_PATH)
コード例 #4
0
ファイル: run.py プロジェクト: benoitgaujac/diswae
def main():
    """Train a disentanglement model with coefficients taken from FLAGS.

    ``BetaTCVAE`` / ``FactorVAE`` take the scalar ``FLAGS.beta``, and
    ``TCWAE_MWS`` / ``TCWAE_GAN`` take ``[FLAGS.beta, FLAGS.gamma]``.
    A timestamped experiment directory is created under
    ``results/<out_dir>/<model>/``, logging is configured there, and the
    model is trained.

    Raises:
        ValueError: unknown ``FLAGS.dataset`` / ``FLAGS.mode``, or training
            set smaller than one batch.
        Exception: unsupported model.
    """
    # Select dataset to use (config attribute looked up lazily).
    dataset_configs = {
        'dsprites': 'config_dsprites',
        'noisydsprites': 'config_noisydsprites',
        'screamdsprites': 'config_screamdsprites',
        'smallNORB': 'config_smallNORB',
        '3dshapes': 'config_3dshapes',
        '3Dchairs': 'config_3Dchairs',
        'celebA': 'config_celebA',
        'mnist': 'config_mnist',
    }
    if FLAGS.dataset not in dataset_configs:
        raise ValueError('Unknown dataset {}'.format(FLAGS.dataset))
    opts = getattr(configs, dataset_configs[FLAGS.dataset])

    # Only training is supported; fail before creating any directory.
    if FLAGS.mode != "train":
        raise ValueError('Unknown mode %s' % FLAGS.mode)

    # Set method param
    opts['fid'] = FLAGS.fid
    opts['cost'] = FLAGS.cost  #l2, l2sq, l2sq_norm, l1, xentropy
    if FLAGS.net_archi:
        opts['network'] = net_configs[FLAGS.net_archi]
    elif FLAGS.dataset == 'celebA':
        opts['network'] = net_configs['conv_rae']
    else:
        opts['network'] = net_configs['conv_locatello']

    # Model set up; latent size depends on the dataset (10 by default).
    opts['model'] = FLAGS.model
    opts['zdim'] = {'celebA': 32, '3Dchairs': 16}.get(FLAGS.dataset, 10)
    opts['lr'] = 0.0001

    # Objective function coefficients: a scalar or a [beta, gamma] pair.
    two_coeff_models = ('TCWAE_MWS', 'TCWAE_GAN')
    if opts['model'] in ('BetaTCVAE', 'FactorVAE'):
        opts['obj_fn_coeffs'] = FLAGS.beta
    elif opts['model'] in two_coeff_models:
        opts['obj_fn_coeffs'] = [FLAGS.beta, FLAGS.gamma]
    else:
        raise Exception('Unknown {} model for {}'.format(
            opts['model'], FLAGS.dataset))

    # Create directories
    results_dir = 'results'
    if not tf.io.gfile.isdir(results_dir):
        utils.create_dir(results_dir)
    opts['out_dir'] = os.path.join(results_dir, FLAGS.out_dir)
    if not tf.io.gfile.isdir(opts['out_dir']):
        utils.create_dir(opts['out_dir'])
    out_subdir = os.path.join(opts['out_dir'], opts['model'])
    if not tf.io.gfile.isdir(out_subdir):
        utils.create_dir(out_subdir)
    # The directory name embeds the coefficient(s) and a timestamp.
    if opts['model'] in two_coeff_models:
        exp_name = '{}_{}_{}_{:%Y_%m_%d_%H_%M}'.format(
            FLAGS.res_dir, opts['obj_fn_coeffs'][0], opts['obj_fn_coeffs'][1],
            datetime.now())
    else:
        exp_name = '{}_{}_{:%Y_%m_%d_%H_%M}'.format(
            FLAGS.res_dir, opts['obj_fn_coeffs'], datetime.now())
    exp_dir = os.path.join(out_subdir, exp_name)
    opts['exp_dir'] = exp_dir
    if not tf.io.gfile.isdir(exp_dir):
        utils.create_dir(exp_dir)
        utils.create_dir(os.path.join(exp_dir, 'checkpoints'))

    # Verbose
    logging.basicConfig(filename=os.path.join(exp_dir, 'outputs.log'),
                        level=logging.INFO,
                        format='%(asctime)s - %(message)s')

    # Experiments set up
    opts['it_num'] = FLAGS.num_it
    opts['print_every'] = int(opts['it_num'] / 2.)
    opts['evaluate_every'] = int(opts['it_num'] / 4.)
    # Effectively "never" save intermediate checkpoints.
    opts['save_every'] = 10000000000
    opts['save_final'] = FLAGS.save_model
    opts['save_train_data'] = FLAGS.save_data
    opts['vizu_encSigma'] = False

    # Reset tf graph
    tf.reset_default_graph()

    # Loading the dataset
    opts['data_dir'] = FLAGS.data_dir
    opts['stage_to_scratch'] = FLAGS.stage_to_scratch
    opts['scratch_dir'] = FLAGS.scratch_dir
    data = DataHandler(opts)
    if data.train_size < opts['batch_size']:
        raise ValueError('Training set too small')

    # init method
    run = Run(opts, data)

    # Training (mode validated above). Dump all the configs to a text file.
    with utils.o_gfile((exp_dir, 'params.txt'), 'w') as text:
        text.write('Parameters:\n')
        for key in opts:
            text.write('%s : %s\n' % (key, opts[key]))
    run.train()
コード例 #5
0
def _select_obj_fn_coeffs(dataset, model, cost, run_id):
    """Pick the objective-function coefficients for one run of a sweep.

    ``run_id`` is 1-based and wraps around modulo the size of the grid for
    ``(dataset, model, cost)``. Returns a scalar for single-coefficient
    models and a ``[beta, gamma]`` list for ``TCWAE_MWS`` / ``TCWAE_GAN``.

    Raises:
        Exception: unknown model for celebA / 3Dchairs.
        NotImplementedError: unknown model for any other dataset.
    """
    def pick(grid):
        # Select one grid entry; 1-based run_id, wrapped around.
        return grid[(run_id - 1) % len(grid)]

    def pick_pair(betas, gammas):
        # Select one (beta, gamma) from their Cartesian product, as a list.
        return list(pick(list(itertools.product(betas, gammas))))

    if dataset == 'celebA':
        if model in ('BetaTCVAE', 'FactorVAE'):
            return pick([1, 2, 4, 6, 8, 10])
        if model == 'TCWAE_MWS':
            return pick_pair([1., 2., 5.], [0.1, 0.25, 0.5])
        if model == 'TCWAE_GAN':
            return pick_pair([1., 2., 5., 10., 15.], [1., 2., 5.])
        raise Exception('Unknown {} model for celebA'.format(model))

    if dataset == '3Dchairs':
        if model == 'BetaTCVAE':
            return pick([1, 2, 4, 6, 8, 10])
        if model == 'FactorVAE':
            return pick([1., 2., 4.])
        if model == 'TCWAE_MWS':
            return pick_pair([1., 2., 5.], [0.1, 0.25, 0.5])
        if model == 'TCWAE_GAN':
            return pick_pair([0.1, 0.5], [0.1, 0.5, 1.])
        raise Exception('Unknown {} model for 3Dchairs'.format(model))

    # Every other dataset.
    if model == 'BetaTCVAE':
        return pick([1, 2, 4, 6, 8, 10])
    if model == 'FactorVAE':
        return pick([1, 10, 25, 50, 75, 100])
    if model == 'WAE':
        return pick([1, 5, 10, 25, 50, 100])
    if model == 'TCWAE_MWS':
        # NOTE(review): compares against 'xent', while the cost comment in
        # main lists 'xentropy' -- confirm which name the config uses.
        if cost == 'xent':
            return pick_pair([1, 2, 4, 6, 8, 10], [1, 2, 4, 6, 8, 10])
        betas = [2., 5.]
        gammas = [0.1, 0.25, 0.5, 0.75, 1., 2., 5.]
        # Sweep both (beta, gamma) and the mirrored (gamma, beta) pairs.
        grid = (list(itertools.product(betas, gammas)) +
                list(itertools.product(gammas, betas)))
        return list(pick(grid))
    if model == 'TCWAE_GAN':
        if cost == 'xent':
            return pick_pair([1, 10, 25, 50, 75, 100],
                             [1, 10, 25, 50, 75, 100])
        return pick_pair([0.5, 1, 2.5, 5, 7.5, 10],
                         [0.5, 1.0, 2.5, 5.0, 7.5, 10.0])
    if model == 'TCWAE_MWS_MI':
        return pick([0.1, 0.25, 0.5, 0.75, 1, 10])
    if model == 'TCWAE_GAN_MI':
        return pick([1, 2, 4, 6, 8, 10])
    raise NotImplementedError('Model type not recognised')


def main():
    """Train one run of a coefficient sweep for the dataset/model in FLAGS.

    The objective coefficients come from a per-(dataset, model, cost) grid
    indexed by ``FLAGS.id`` (see ``_select_obj_fn_coeffs``). A timestamped
    experiment directory is created under ``results/<out_dir>/<model>/``,
    logging is configured there, and the model is trained.

    Raises:
        ValueError: unknown ``FLAGS.dataset`` / ``FLAGS.mode``, or training
            set smaller than one batch.
        Exception / NotImplementedError: unsupported model.
    """
    # Select dataset to use (config attribute looked up lazily).
    dataset_configs = {
        'dsprites': 'config_dsprites',
        'noisydsprites': 'config_noisydsprites',
        'screamdsprites': 'config_screamdsprites',
        'smallNORB': 'config_smallNORB',
        '3dshapes': 'config_3dshapes',
        '3Dchairs': 'config_3Dchairs',
        'celebA': 'config_celebA',
        'mnist': 'config_mnist',
    }
    if FLAGS.dataset not in dataset_configs:
        raise ValueError('Unknown dataset {}'.format(FLAGS.dataset))
    opts = getattr(configs, dataset_configs[FLAGS.dataset])

    # Only training is supported; fail before creating any directory.
    if FLAGS.mode != "train":
        raise ValueError('Unknown mode %s' % FLAGS.mode)

    # Set method param
    opts['fid'] = FLAGS.fid
    opts['cost'] = FLAGS.cost  #l2, l2sq, l2sq_norm, l1, xentropy
    if FLAGS.net_archi:
        opts['network'] = net_configs[FLAGS.net_archi]
    elif FLAGS.dataset == 'celebA':
        opts['network'] = net_configs['conv_rae']
    else:
        opts['network'] = net_configs['conv_locatello']

    # Model set up; latent size depends on the dataset (10 by default).
    opts['model'] = FLAGS.model
    opts['zdim'] = {'celebA': 32, '3Dchairs': 16}.get(FLAGS.dataset, 10)
    opts['lr'] = 0.0001

    # Objective function coefficients
    opts['obj_fn_coeffs'] = _select_obj_fn_coeffs(FLAGS.dataset, opts['model'],
                                                  opts['cost'], FLAGS.id)

    # Create directories
    results_dir = 'results'
    if not tf.io.gfile.isdir(results_dir):
        utils.create_dir(results_dir)
    opts['out_dir'] = os.path.join(results_dir, FLAGS.out_dir)
    if not tf.io.gfile.isdir(opts['out_dir']):
        utils.create_dir(opts['out_dir'])
    out_subdir = os.path.join(opts['out_dir'], opts['model'])
    if not tf.io.gfile.isdir(out_subdir):
        utils.create_dir(out_subdir)
    # The directory name embeds the coefficient(s) and a timestamp.
    if opts['model'] in ('TCWAE_MWS', 'TCWAE_GAN'):
        exp_name = '{}_{}_{}_{:%Y_%m_%d_%H_%M}'.format(
            FLAGS.res_dir, opts['obj_fn_coeffs'][0], opts['obj_fn_coeffs'][1],
            datetime.now())
    else:
        exp_name = '{}_{}_{:%Y_%m_%d_%H_%M}'.format(
            FLAGS.res_dir, opts['obj_fn_coeffs'], datetime.now())
    exp_dir = os.path.join(out_subdir, exp_name)
    opts['exp_dir'] = exp_dir
    if not tf.io.gfile.isdir(exp_dir):
        utils.create_dir(exp_dir)
        utils.create_dir(os.path.join(exp_dir, 'checkpoints'))

    # Verbose
    logging.basicConfig(filename=os.path.join(exp_dir, 'outputs.log'),
                        level=logging.INFO,
                        format='%(asctime)s - %(message)s')

    # Experiments set up
    opts['it_num'] = FLAGS.num_it
    opts['print_every'] = int(opts['it_num'] / 2.)
    opts['evaluate_every'] = int(opts['it_num'] / 4.)
    # Effectively "never" save intermediate checkpoints.
    opts['save_every'] = 10000000000
    opts['save_final'] = FLAGS.save_model
    opts['save_train_data'] = FLAGS.save_data
    opts['vizu_encSigma'] = False

    # Reset tf graph
    tf.reset_default_graph()

    # Loading the dataset
    opts['data_dir'] = FLAGS.data_dir
    opts['stage_to_scratch'] = FLAGS.stage_to_scratch
    opts['scratch_dir'] = FLAGS.scratch_dir
    data = DataHandler(opts)
    if data.train_size < opts['batch_size']:
        raise ValueError('Training set too small')

    # init method
    run = Run(opts, data)

    # Training (mode validated above). Dump all the configs to a text file.
    with utils.o_gfile((exp_dir, 'params.txt'), 'w') as text:
        text.write('Parameters:\n')
        for key in opts:
            text.write('%s : %s\n' % (key, opts[key]))
    run.train()
コード例 #6
0
ファイル: run_vae.py プロジェクト: benoitgaujac/tdwae
def main():
    """Train or evaluate a stacked VAE on the dataset chosen by FLAGS.

    The initial regularisation weight comes from a small grid indexed by
    the 1-based ``FLAGS.id``, wrapping around modulo the grid size, and is
    applied to every latent layer. A timestamped experiment directory is
    created under ``results/<out_dir>/<model>/``. ``FLAGS.mode`` is
    dispatched to training, latent interpolation, FID, test losses or the
    VLAE experiment.

    Raises:
        ValueError: unknown ``FLAGS.dataset`` / ``FLAGS.mode``, or training
            set smaller than one batch.
    """
    # Select dataset to use (config attribute looked up lazily).
    dataset_configs = {
        'mnist': 'config_mnist',
        'smallNORB': 'config_smallNORB',
        'celebA': 'config_celebA',
    }
    if FLAGS.dataset not in dataset_configs:
        raise ValueError('Unknown dataset {}'.format(FLAGS.dataset))
    opts = getattr(configs, dataset_configs[FLAGS.dataset])

    # Fail fast on an unsupported mode, before any directory is created.
    if FLAGS.mode not in ('train', 'vizu', 'fid', 'test', 'vlae_exp'):
        raise ValueError('Unknown mode %s' % FLAGS.mode)

    # model
    nlatents = opts['nlatents']
    opts['model'] = FLAGS.model
    opts['encoder'] = [FLAGS.encoder] * nlatents
    opts['use_sigmoid'] = FLAGS.sigmoid
    opts['archi'] = [FLAGS.net_archi] * nlatents
    opts['obs_cost'] = FLAGS.cost
    opts['lambda_schedule'] = FLAGS.lmba_schedule
    opts['enc_sigma_pen'] = FLAGS.enc_sigma_pen
    opts['dec_sigma_pen'] = FLAGS.dec_sigma_pen

    # lambda: same initial weight for every latent layer, picked from a grid.
    beta = [0.0001, 1.]
    coef_id = (FLAGS.id - 1) % len(beta)
    lambda_reg = beta[coef_id]
    opts['lambda_init'] = [lambda_reg] * nlatents
    opts['lambda'] = [1.] * nlatents

    # Create directories
    results_dir = 'results'
    if not tf.io.gfile.isdir(results_dir):
        utils.create_dir(results_dir)
    opts['out_dir'] = os.path.join(results_dir, FLAGS.out_dir)
    if not tf.io.gfile.isdir(opts['out_dir']):
        utils.create_dir(opts['out_dir'])
    out_subdir = os.path.join(opts['out_dir'], opts['model'])
    if not tf.io.gfile.isdir(out_subdir):
        utils.create_dir(out_subdir)
    exp_dir = os.path.join(
        out_subdir,
        '{}_{}layers_lreg{}_{:%Y_%m_%d_%H_%M}'.format(FLAGS.res_dir, nlatents,
                                                      lambda_reg,
                                                      datetime.now()))
    opts['exp_dir'] = exp_dir
    if not tf.io.gfile.isdir(exp_dir):
        utils.create_dir(exp_dir)
        utils.create_dir(os.path.join(exp_dir, 'checkpoints'))

    # getting weights path
    # NOTE(review): exp_dir is a freshly created timestamped directory, so a
    # relative weights_file points into an empty checkpoints folder;
    # presumably an absolute path is expected here (os.path.join then
    # discards the prefix) -- confirm.
    if FLAGS.weights_file is not None:
        WEIGHTS_PATH = os.path.join(exp_dir, 'checkpoints', FLAGS.weights_file)
    else:
        WEIGHTS_PATH = None

    # Verbose
    logging.basicConfig(filename=os.path.join(exp_dir, 'outputs.log'),
                        level=logging.INFO,
                        format='%(asctime)s - %(message)s')

    # run set up
    opts['vizu_splitloss'] = FLAGS.losses
    opts['vizu_fullrec'] = FLAGS.reconstructions
    opts['vizu_embedded'] = FLAGS.embedded
    opts['vizu_latent'] = FLAGS.latents
    opts['fid'] = FLAGS.fid
    opts['it_num'] = FLAGS.num_it
    opts['print_every'] = int(opts['it_num'] / 4)
    opts['evaluate_every'] = int(opts['it_num'] / 50)
    if FLAGS.batch_size is not None:
        opts['batch_size'] = FLAGS.batch_size
    opts['lr'] = FLAGS.lr
    opts['use_trained'] = FLAGS.use_trained
    # Effectively "never" save intermediate checkpoints.
    opts['save_every'] = 10000000000
    opts['save_final'] = FLAGS.save_model
    opts['save_train_data'] = FLAGS.save_data

    # Reset tf graph
    tf.compat.v1.reset_default_graph()

    # Loading the dataset
    opts['data_dir'] = FLAGS.data_dir
    data = DataHandler(opts)
    if data.train_size < opts['batch_size']:
        raise ValueError('Training set too small')

    # build model
    run = Run(opts, data)

    # Training/testing/vizu (mode validated at the top)
    if FLAGS.mode == "train":
        # Dumping all the configs to the text file
        with utils.o_gfile((opts['exp_dir'], 'params.txt'), 'w') as text:
            text.write('Parameters:\n')
            for key in opts:
                text.write('%s : %s\n' % (key, opts[key]))
        run.train(WEIGHTS_PATH)
    elif FLAGS.mode == "vizu":
        opts['rec_loss_nsamples'] = 1
        opts['sample_recons'] = False
        run.latent_interpolation(opts['exp_dir'], WEIGHTS_PATH)
    elif FLAGS.mode == "fid":
        run.fid_score(WEIGHTS_PATH)
    elif FLAGS.mode == "test":
        run.test_losses(WEIGHTS_PATH)
    else:  # "vlae_exp"
        run.vlae_experiment(WEIGHTS_PATH)