示例#1
0
    # --- Settings ------------------------------------------------------------
    config_fname = 'config.ini'  # config file passed to read_config() below
    n_runs = 10  # not used in the visible lines; presumably runs per setting — TODO confirm
    N = 1000  # generated sample size for each trained model; also passed to the dataset constructor below
    subset_size = 10  # for computing DDP (2nd argument of diversity_score)
    sample_times = 1000  # for computing DDP (3rd argument of diversity_score)

    ###############################################################################
    # Plot diversity and quality scores
    print('Plotting diversity and quality scores ...')
    # Larger default font size for every figure created afterwards.
    plt.rcParams.update({'font.size': 20})

    # One iteration per (dataset, function) pair; `list_data_function` is
    # defined outside this fragment.
    for (dataset, function) in list_data_function:

        # Key identifying this example, formatted as '<dataset>+<function>';
        # presumably selects the matching section of config.ini — verify
        # against read_config.
        example_name = '{}+{}'.format(dataset, function)
        # read_config returns eight values here; only the 6th and 7th are
        # kept (as lambda0_ / lambda1_). They are not used in the visible lines.
        _, _, _, _, _, lambda0_, lambda1_, _ = read_config(
            config_fname, example_name)

        # Data: look up the dataset by name in the `datasets` module and
        # construct it with N.
        data_obj = getattr(datasets, dataset)(N)
        data = data_obj.data
        # Function: look up the function by name in the `functions` module
        # and instantiate it with no arguments.
        func_obj = getattr(functions, function)()
        func = func_obj.evaluate  # bound method; not used in the visible lines

        # Diversity, quality and overall scores computed on the dataset's own
        # samples (data_obj.data), not on generated samples.
        div_data = diversity_score(data, subset_size, sample_times)
        qa_data = quality_score(data, func_obj)
        oa_data = overall_score(data, func_obj)

        # Per-lambda score accumulators. They are not filled within the
        # visible lines; this fragment appears to be truncated here.
        list_div_lambdas = []
        list_qa_lambdas = []
        list_oa_lambdas = []
示例#2
0
if __name__ == "__main__":

    # Values swept for each regularization weight. Each sweep varies one
    # weight and holds the other at the value read from config.ini.
    list_lambdas0 = [1.0, 2.0, 3.0]
    list_lambdas1 = [0.1, 0.2, 0.3]
    n_runs = 10

    def train(lambda0, lambda1, n_runs):
        """Launch run_experiment.py for run ids 0 .. n_runs-1 of one setting.

        The mode for each run id depends on whether that run's
        synthesized.png already exists: a missing image means the run is
        trained, an existing one means it is evaluated instead. Commands are
        executed through os.system and their exit status is not checked.
        """
        for run_id in range(n_runs):
            image_path = './trained_gan/{}_{}/{}/synthesized.png'.format(lambda0, lambda1, run_id)
            image_exists = os.path.exists(image_path)
            if image_exists:
                command = 'python run_experiment.py evaluate --lambda0={} --lambda1={} --id={}'
            else:
                command = 'python run_experiment.py train --lambda0={} --lambda1={} --id={}'
            os.system(command.format(lambda0, lambda1, run_id))

    config_fname = 'config.ini'
    # Ten values come back; the three shape settings and the configured
    # lambda pair are kept, everything else is discarded.
    (latent_dim, noise_dim, bezier_degree,
     _, _, _, _,
     lambda0_, lambda1_, _) = read_config(config_fname)

    # Sweep 1: vary lambda0 while lambda1 stays at its configured value.
    for lambda0 in list_lambdas0:
        lambda1 = lambda1_
        train(lambda0, lambda1, n_runs)
    # Sweep 2: vary lambda1 while lambda0 stays at its configured value.
    for lambda1 in list_lambdas1:
        lambda0 = lambda0_
        train(lambda0, lambda1, n_runs)

    ###############################################################################
    # Plot diversity and quality scores
    print('Plotting diversity and quality scores ...')
    plt.rcParams.update({'font.size': 20})

    # Read dataset
    data_fname = './data/xs_train.npy'
    X = np.load(data_fname)
示例#3
0
import argparse
import os
import subprocess
import sys

from run_experiment import read_config

def parse_cli_args(argv=None):
    """Parse this driver's command line.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            means argparse's default, ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` whose ``mode`` is ``'train'`` or
        ``'evaluate'``; ``'train'`` when the argument is omitted.

    Raises:
        SystemExit: if ``mode`` is not an allowed value (argparse prints a
            usage error first).
    """
    parser = argparse.ArgumentParser(description='Train')
    # A positional argument only falls back to its default when it is
    # optional (nargs='?'); without it default='train' was dead code and the
    # mode was mandatory. `choices` replaces a bare `assert`, which is
    # stripped when Python runs with -O.
    parser.add_argument('mode',
                        type=str,
                        nargs='?',
                        default='train',
                        choices=['train', 'evaluate'],
                        help='train or evaluate')
    return parser.parse_args(argv)


def experiment_command(mode, lambda0, lambda1, run_id):
    """Return the argv list that runs run_experiment.py for one run.

    The list is meant for ``subprocess`` without a shell. It starts with
    ``sys.executable`` so the child runs under the same interpreter as this
    script, not whichever ``python`` happens to be first on PATH.
    """
    return [sys.executable, 'run_experiment.py', mode,
            '--lambda0={}'.format(lambda0),
            '--lambda1={}'.format(lambda1),
            '--id={}'.format(run_id)]


if __name__ == "__main__":

    args = parse_cli_args()

    list_models = ['GAN', 'PaDGAN']

    n_runs = 10
    config_fname = 'config.ini'
    # The config file does not change between models, so read it once.
    _, _, _, _, _, _, _, config_lambda0, config_lambda1, _ = read_config(
        config_fname)
    for model_name in list_models:
        # 'GAN' runs with both lambdas forced to 0; any other model uses the
        # lambda0/lambda1 values from the config file.
        if model_name == 'GAN':
            lambda0, lambda1 = 0., 0.
        else:
            lambda0, lambda1 = config_lambda0, config_lambda1
        for i in range(n_runs):
            png_path = './trained_gan/{}_{}/{}/synthesized.png'.format(
                lambda0, lambda1, i)
            # In 'train' mode, skip runs whose synthesized.png already
            # exists; 'evaluate' mode always runs.
            if args.mode == 'train' and os.path.exists(png_path):
                continue
            cmd = experiment_command(args.mode, lambda0, lambda1, i)
            returncode = subprocess.run(cmd).returncode
            # Keep going after a failed run (as before), but report the
            # failure instead of silently discarding the exit status.
            if returncode != 0:
                print('Warning: run failed with exit status {}: {}'.format(
                    returncode, ' '.join(cmd)))
示例#4
0
from matplotlib import cm
import matplotlib
import seaborn as sns
import pandas as pd

from run_experiment import read_config
from bezier_gan import BezierGAN
from simulation import evaluate
from evaluation import diversity_score
from shape_plot import plot_shape


if __name__ == "__main__":
    # NOTE(review): `np` and `plt` are used below but are not imported in the
    # visible import lines; presumably they are imported above this
    # fragment — confirm.
    
    # Read every setting from the config file. Only lambda0_/lambda1_ and the
    # rest are bound by name here; none of them is used in the visible lines.
    config_fname = 'config.ini'
    latent_dim, noise_dim, bezier_degree, train_steps, batch_size, disc_lr, gen_lr, lambda0_, lambda1_, save_interval = read_config(config_fname)
    bounds = (0., 1.)  # not used in the visible lines — TODO confirm its role
    
    # Read dataset
    data_fname = './data/xs_train.npy'
    X = np.load(data_fname)
    N = X.shape[0]  # length of X's first axis (presumably the number of training samples — verify)
    
    # Model variants compared below.
    list_models = ['GAN', 'PaDGAN']
    
    ###############################################################################
    # Plot diversity and quality scores
    print('Plotting diversity and quality scores ...')
    # Larger default font size for every figure created afterwards.
    plt.rcParams.update({'font.size': 20})
    
    n_runs = 10  # presumably runs per model; the code using it is past this fragment's end