Example #1
import datetime
import json
import os

import dateutil.tz
import tensorflow as tf  # TF 1.x API (tf.Session)

# Project-local imports are assumed to be in scope (module paths depend on
# this repository's layout): datasets, Uniform, Categorical, MeanBernoulli,
# RegularizedGAN, InfoGANTrainer, ImprovedGAN, and mkdir_p.


def train(configPath):
    with open(configPath, 'r') as f:
        d = json.load(f)

        train_dataset = d['train_dataset']
        train_Folder = d['train_Folder']
        val_dataset = d['val_dataset']
        output_size = d['output_size']
        categories = d['categories']
        batch_size = d['batch_size']

        # dict.has_key() was removed in Python 3; use .get() / `in` instead.
        improvedGan = d.get('improvedGan', False)
        print("Using improvedGAN", improvedGan)
        if 'exp_name' in d:
            exp_name = d['exp_name']
        else:
            # Fall back to the config file name without its extension.
            filename = os.path.split(configPath)[-1]
            assert len(filename.split(".")) == 2
            exp_name = filename.split(".")[0]

        arch = d['arch']

        semiSup = d.get('semiSup', False)
        trainSplit = d.get('trainSplit', 0.7)

    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')

    root_log_dir = os.path.join("logs", train_dataset)
    root_checkpoint_dir = os.path.join("ckt", train_dataset)
    root_samples_dir = os.path.join("samples", train_dataset)
    updates_per_epoch = 100
    max_epoch = 50

    if exp_name is None:
        exp_name = "t-%s_v-%s_o-%d" % (train_dataset, val_dataset, output_size)
        if categories is not None:
            exp_name += "_c-%d" % categories
        exp_name += "_%s" % timestamp

    print("Experiment Name: %s" % (exp_name))

    log_dir = os.path.join(root_log_dir, exp_name)
    checkpoint_dir = os.path.join(root_checkpoint_dir, exp_name)
    samples_dir = os.path.join(root_samples_dir, exp_name)

    mkdir_p(log_dir)
    mkdir_p(checkpoint_dir)
    mkdir_p(samples_dir)

    output_dist = None
    network_type = arch
    if train_dataset == "mnist":
        print("Creating train dataset ")
        dataset = datasets.MnistDataset()
        output_dist = MeanBernoulli(dataset.image_dim)
        print("CREATED train dataset ")
        network_type = 'mnist'

        print("Creating VAL dataset ")
        val_dataset = dataset
    elif train_dataset == "dataFolder":
        dataset = datasets.DataFolder(train_Folder,
                                      batch_size,
                                      out_size=output_size,
                                      validation_proportion=(1 - trainSplit))

        print("Folder datasets created ")
        val_dataset = dataset
    else:
        dataset = datasets.Dataset(name=train_dataset,
                                   batch_size=batch_size,
                                   output_size=output_size)
        val_dataset = datasets.Dataset(name=val_dataset,
                                       batch_size=batch_size,
                                       output_size=output_size)

    # Each latent_spec entry is a (distribution, regularized) pair; the
    # regularized entries are the codes the mutual-information term acts on.
    latent_spec = [(Uniform(100), False)]
    if categories is not None:
        latent_spec.append((Categorical(categories), True))

    is_reg = any(reg for _, reg in latent_spec)
    model = RegularizedGAN(output_dist=output_dist,
                           latent_spec=latent_spec,
                           is_reg=is_reg,
                           batch_size=batch_size,
                           image_shape=dataset.image_shape,
                           network_type=network_type,
                           impr=improvedGan)
    if not improvedGan:
        algo = InfoGANTrainer(model=model,
                              dataset=dataset,
                              val_dataset=val_dataset,
                              batch_size=batch_size,
                              isTrain=True,
                              exp_name=exp_name,
                              log_dir=log_dir,
                              checkpoint_dir=checkpoint_dir,
                              samples_dir=samples_dir,
                              max_epoch=max_epoch,
                              info_reg_coeff=1.0,
                              generator_learning_rate=2e-3,
                              discriminator_learning_rate=2e-3,
                              semiSup=semiSup)
    else:
        algo = ImprovedGAN(model=model,
                           dataset=dataset,
                           val_dataset=val_dataset,
                           batch_size=batch_size,
                           isTrain=True,
                           exp_name=exp_name,
                           log_dir=log_dir,
                           checkpoint_dir=checkpoint_dir,
                           samples_dir=samples_dir,
                           max_epoch=max_epoch,
                           info_reg_coeff=1.0,
                           generator_learning_rate=2e-3,
                           discriminator_learning_rate=2e-3,
                           semiSup=semiSup)

    #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.03)
    #config=tf.ConfigProto(gpu_options=gpu_options)
    #device_name = "/gpu:0"
    #with tf.device(device_name):
    algo.init_opt()
    with tf.Session() as sess:
        algo.train(sess)
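
For reference, train() above expects a JSON config whose keys match the d[...] reads at the top of the function. A minimal sketch, with every value illustrative (improvedGan, semiSup, trainSplit, and exp_name are optional and default to False, False, 0.7, and the config file's base name, respectively):

import json

config = {
    "train_dataset": "dataFolder",      # or "mnist", or a named Dataset
    "train_Folder": "/path/to/images",  # only read by the dataFolder branch
    "val_dataset": "dataFolder",
    "output_size": 64,
    "categories": 10,                   # or None to skip the Categorical code
    "batch_size": 128,
    "arch": "dcgan",                    # network_type; value is illustrative
}

with open("myexp.json", "w") as f:
    json.dump(config, f)

train("myexp.json")  # exp_name falls back to "myexp"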
Example #2
    checkpoint_dir = os.path.join(root_checkpoint_dir, exp_name)

    mkdir_p(log_dir)
    mkdir_p(checkpoint_dir)

    dataset = MnistDataset()

    latent_spec = [
        (Uniform(62), False),
        (Categorical(10), True),
        (Uniform(1, fix_std=True), True),
        (Uniform(1, fix_std=True), True),
    ]

    model = RegularizedGAN(
        output_dist=MeanBernoulli(dataset.image_dim),
        latent_spec=latent_spec,
        batch_size=batch_size,
        image_shape=dataset.image_shape,
        network_type="mnist",
    )

    algo = InfoGANTrainer(
        model=model,
        dataset=dataset,
        batch_size=batch_size,
        exp_name=exp_name,
        log_dir=log_dir,
        checkpoint_dir=checkpoint_dir,
        max_epoch=max_epoch,
        updates_per_epoch=updates_per_epoch,
    )

    algo.train()
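
Each example on this page uses the same latent_spec convention: every entry pairs a distribution with a flag marking whether the InfoGAN mutual-information regularizer applies to it (True for latent codes, False for plain noise). A hypothetical helper, not part of either project, that separates the two groups:

def split_latent_spec(latent_spec):
    # Each entry is a (distribution, regularized) pair: regularized entries
    # are the latent codes, the rest are incompressible noise.
    noise = [dist for dist, regularized in latent_spec if not regularized]
    codes = [dist for dist, regularized in latent_spec if regularized]
    return noise, codes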
Example #3
File: run.py Project: rootkit/gan-lib
import os

import tensorflow as tf

# Project-local imports are assumed to be in scope: get_timestamp,
# make_exists, the dataset classes (MnistDataset, CelebADataset,
# HorseOrZebraDataset), the distributions (Uniform, Categorical,
# MeanBernoulli, MeanGaussian), and the model, loss, loss-builder,
# and trainer classes used below.


def train(model_name, learning_params):
    timestamp = get_timestamp()

    root_log_dir = os.path.join('logs', model_name)
    root_checkpoint_dir = os.path.join('ckt', model_name)
    experiment_name = '{}_{}'.format(model_name, timestamp)
    log_dir = os.path.join(root_log_dir, experiment_name)
    checkpoint_dir = os.path.join(root_checkpoint_dir, experiment_name)
    make_exists(log_dir)
    make_exists(checkpoint_dir)

    batch_size = learning_params['batch_size']
    updates_per_epoch = learning_params['updates_per_epoch']
    max_epoch = learning_params['max_epoch']
    trainer = learning_params['trainer']

    if model_name == 'mnist_infogan':
        output_dataset = MnistDataset()
        latent_spec = [
            (Uniform(62), False),
            (Categorical(10), True),
            (Uniform(1, fix_std=True), True),
            (Uniform(1, fix_std=True), True),
        ]
        model = MNISTInfoGAN(
            batch_size=batch_size,
            output_dataset=output_dataset,
            output_dist=MeanBernoulli(output_dataset.image_dim),
            latent_spec=latent_spec,
        )
    elif model_name == 'mnist_wasserstein':
        output_dataset = MnistDataset()
        latent_spec = [
            (Uniform(62), False),
        ]
        model = MNISTInfoGAN(
            batch_size=batch_size,
            output_dataset=output_dataset,
            output_dist=MeanBernoulli(output_dataset.image_dim),
            final_activation=None,
            latent_spec=latent_spec,
        )
    elif model_name == 'celebA_infogan':
        output_dataset = CelebADataset()
        latent_spec = [
            (Uniform(128), False),
            (Categorical(10), True),
            (Categorical(10), True),
            (Categorical(10), True),
            (Categorical(10), True),
            (Categorical(10), True),
            (Categorical(10), True),
            (Categorical(10), True),
            (Categorical(10), True),
            (Categorical(10), True),
            (Categorical(10), True),
        ]
        model = CelebAInfoGAN(
            batch_size=batch_size,
            output_dataset=output_dataset,
            output_dist=MeanGaussian(output_dataset.image_dim, fix_std=True),
            latent_spec=latent_spec,
        )
    elif model_name == 'celebA_wasserstein':
        output_dataset = CelebADataset()
        latent_spec = [
            (Uniform(128), False),
        ]
        model = CelebAInfoGAN(
            batch_size=batch_size,
            output_dataset=output_dataset,
            output_dist=MeanGaussian(output_dataset.image_dim, fix_std=True),
            final_activation=None,
            latent_spec=latent_spec,
        )
    elif model_name == 'horse_zebra':
        horse_dataset = HorseOrZebraDataset('horse')
        zebra_dataset = HorseOrZebraDataset('zebra')
        horse2zebra_model = Horse2Zebra_CycleGAN(
            input_dataset=horse_dataset,
            batch_size=batch_size,
            output_dataset=zebra_dataset,
            output_dist=MeanGaussian(zebra_dataset.image_dim, fix_std=True),
            final_activation=None,
            scope_suffix='_horse2zebra',
        )
        zebra2horse_model = Horse2Zebra_CycleGAN(
            input_dataset=zebra_dataset,
            batch_size=batch_size,
            output_dataset=horse_dataset,
            output_dist=MeanGaussian(horse_dataset.image_dim, fix_std=True),
            final_activation=None,
            scope_suffix='_zebra2horse',
        )
    else:
        raise ValueError('Invalid model_name: {}'.format(model_name))

    if trainer == 'infogan':
        d_optim = tf.train.AdamOptimizer(2e-4, beta1=0.5)
        g_optim = tf.train.AdamOptimizer(1e-3, beta1=0.5)
        loss = GANLoss()
        loss_builder = InfoGANLossBuilder(
            model=model,
            loss=loss,
            batch_size=batch_size,
            g_optimizer=g_optim,
            d_optimizer=d_optim,
        )
        algo = GANTrainer(
            loss_builder=loss_builder,
            exp_name=experiment_name,
            log_dir=log_dir,
            checkpoint_dir=checkpoint_dir,
            max_epoch=max_epoch,
            updates_per_epoch=updates_per_epoch,
        )

    elif trainer == 'wasserstein':
        d_optim = tf.train.AdamOptimizer(2e-4, beta1=0.5)
        g_optim = tf.train.AdamOptimizer(1e-3, beta1=0.5)
        loss = WassersteinGANLoss()
        loss_builder = GANLossBuilder(
            model=model,
            loss=loss,
            batch_size=batch_size,
            g_optimizer=g_optim,
            d_optimizer=d_optim,
        )
        algo = WassersteinGANTrainer(
            loss_builder=loss_builder,
            exp_name=experiment_name,
            log_dir=log_dir,
            checkpoint_dir=checkpoint_dir,
            max_epoch=max_epoch,
            updates_per_epoch=updates_per_epoch,
        )

    elif trainer == 'test':
        d_optim = tf.train.AdamOptimizer(2e-4, beta1=0.5)
        g_optim = tf.train.AdamOptimizer(1e-3, beta1=0.5)
        loss = GANLoss()
        loss_builder = InfoGANLossBuilder(
            model=model,
            loss=loss,
            dataset=output_dataset,
            batch_size=batch_size,
            discrim_optimizer=d_optim,
            generator_optimizer=g_optim,
        )
        algo = GANTrainer(
            loss_builder=loss_builder,
            exp_name=experiment_name,
            log_dir=log_dir,
            checkpoint_dir=checkpoint_dir,
            max_epoch=max_epoch,
            updates_per_epoch=updates_per_epoch,
        )
    elif trainer == 'cycle_gan':
        d_optim = tf.train.AdamOptimizer(2e-4, beta1=0.5)
        g_optim = tf.train.AdamOptimizer(2e-4, beta1=0.5)
        loss = LeastSquaresGANLoss()
        horse2zebra_loss_builder = GANLossBuilder(
            model=horse2zebra_model,
            loss=loss,
            batch_size=batch_size,
        )
        zebra2horse_loss_builder = GANLossBuilder(
            model=zebra2horse_model,
            loss=loss,
            batch_size=batch_size,
        )
        loss_builders = [horse2zebra_loss_builder, zebra2horse_loss_builder]
        loss_builder = CycleGANLossBuilder(
            loss_builders,
            g_optimizer=g_optim,
            d_optimizer=d_optim,
        )
        algo = GANTrainer(
            loss_builder=loss_builder,
            exp_name=experiment_name,
            log_dir=log_dir,
            checkpoint_dir=checkpoint_dir,
            max_epoch=max_epoch,
            updates_per_epoch=updates_per_epoch,
        )
    else:
        raise ValueError('Invalid trainer: {}'.format(trainer))

    algo.train()
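
A call exercising the mnist_infogan branch above might look like this; the four keys are exactly those train() reads from learning_params, and the values are illustrative:

learning_params = {
    'batch_size': 128,
    'updates_per_epoch': 100,
    'max_epoch': 50,
    'trainer': 'infogan',
}
train('mnist_infogan', learning_params)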