コード例 #1
0
def train_cifar_gan(args):
    """Build a CIFAR dataset and an ``SSGanTrainer``, then run ``train_ss``.

    Attributes read from ``args``: ``cifar_ds_path``, ``how_many_labeled``,
    ``learning_rate``, ``latent_dim``, ``out_weights_dir`` and
    ``epochs_num``.
    """
    batch_size = 64
    # Hard-coded number of classes handed to the trainer (presumably the
    # 10 CIFAR-10 classes -- confirm if other CIFAR variants are used).
    num_classes = 10

    dataset = CifarDataset(args.cifar_ds_path, batch_size,
                           args.how_many_labeled)

    # Start from the default Adam options and override only the learning rate.
    options = adam_learning_options()
    options['lr'] = args.learning_rate

    trainer = SSGanTrainer(args.latent_dim, dataset, num_classes,
                           args.out_weights_dir, batch_size, options)
    trainer.train_ss(args.epochs_num)
コード例 #2
0
def train_cifar_gan(args):
    """Train a supervised CNN on CIFAR with ``SupCnnTrainer``.

    Despite the function name, no GAN trainer is built here: the model is
    created by ``cifar_gan_discriminator``, which is passed to
    ``SupCnnTrainer`` as its model-construction function.

    Attributes read from ``args``: ``cifar_ds_path``, ``how_many_labeled``,
    ``learning_rate``, ``epochs_num`` and ``testing_interval``.
    """
    batch_size = 64

    supervised_ds = SupervisedCifarDataset(args.cifar_ds_path, batch_size,
                                           args.how_many_labeled)

    # Default Adam options with the learning rate taken from the CLI args.
    options = adam_learning_options()
    options['lr'] = args.learning_rate

    trainer = SupCnnTrainer(supervised_ds, batch_size,
                            cifar_gan_discriminator, options)
    trainer.train(args.epochs_num, args.testing_interval)
コード例 #3
0
    def __init__(self,
                 dataset,
                 batch_size,
                 create_model_func,
                 train_options=None):
        """Store training settings and create the input placeholders.

        Args:
            dataset: Data source. Must provide ``img_size()``,
                ``classes_num()`` and ``size()``.
            batch_size: Fixed batch dimension of both placeholders.
            create_model_func: Callable used to build the model. It is only
                stored here.
            train_options: Training options (e.g. the mapping returned by
                ``adam_learning_options()``). When omitted, a fresh
                ``adam_learning_options()`` result is created for this
                instance.
        """
        if train_options is None:
            # Build the defaults per call. A default evaluated in the
            # signature would be a single mutable object shared by every
            # instance created without explicit options.
            train_options = adam_learning_options()

        self.dataset = dataset
        self.batch_size = batch_size
        self.train_options = train_options

        self.img_size = dataset.img_size()
        self.classes_num = dataset.classes_num()

        # Square 3-channel images: [batch, img_size, img_size, 3].
        self.input_images = tf.placeholder(
            tf.float32,
            shape=[batch_size, self.img_size, self.img_size, 3],
            name="input_images")
        # One float row per example, classes_num wide (presumably one-hot
        # labels -- confirm against the dataset).
        self.labels = tf.placeholder(tf.float32,
                                     shape=[batch_size, self.classes_num])
        self.create_model_func = create_model_func

        # NOTE(review): self.logger is not set in this method; presumably a
        # class-level logger defined elsewhere.
        self.logger.info("Initialization finished, dataset size = %d" %
                         dataset.size())
コード例 #4
0
    def __init__(self,
                 latent_dim,
                 dataset,
                 classes_num,
                 out_weights_dir,
                 batch_size=64,
                 train_options=None):
        """Set up the semi-supervised GAN trainer's state and placeholders.

        Args:
            latent_dim: Latent dimension, forwarded to the base constructor.
            dataset: Data source, forwarded to the base constructor and
                stored as ``self.dataset``.
            classes_num: Width of the ``labels`` placeholder.
            out_weights_dir: Output weights directory, forwarded to the base
                constructor.
            batch_size: Batch size, forwarded to the base constructor.
            train_options: Training options. When omitted, a fresh
                ``adam_learning_options()`` result is created for this
                instance.
        """
        if train_options is None:
            # Build the defaults per call. A default evaluated in the
            # signature would be a single mutable object shared by every
            # instance created without explicit options.
            train_options = adam_learning_options()

        super(SSGanTrainer,
              self).__init__(latent_dim, dataset, out_weights_dir, batch_size,
                             train_options)

        self.dataset = dataset
        self.classes_num = classes_num
        self.testing_interval = config.TESTING_INTERVAL
        self.testing_iterations = config.TESTING_ITERATIONS

        # self.batch_size and self.img_size are not assigned in this method;
        # they are expected to be set by the base-class constructor.
        self.labels = tf.placeholder(dtype=tf.float32,
                                     shape=[self.batch_size, self.classes_num])
        # Separate input for the unlabeled images, same square 3-channel
        # layout: [batch, img_size, img_size, 3].
        self.unlabeled_images = tf.placeholder(
            dtype=tf.float32,
            shape=[self.batch_size, self.img_size, self.img_size, 3],
            name='unlabeled_images')
コード例 #5
0
ファイル: vae_trainer.py プロジェクト: plazowicz/pydata2017
    def __init__(self,
                 latent_dim,
                 dataset,
                 out_weights_dir,
                 train_options=None,
                 kl_loss_weight=0.0005):
        """Store VAE training settings and create the input placeholder.

        Args:
            latent_dim: Latent dimension, stored as ``self.latent_dim``.
            dataset: Data source. Its ``batch_size`` attribute and
                ``img_size()`` method are read here.
            out_weights_dir: Directory for weight dumps, ensured via
                ``tl.files.exists_or_mkdir``.
            train_options: Training options. When omitted, a fresh
                ``adam_learning_options()`` result is created for this
                instance.
            kl_loss_weight: Weight stored for the KL loss term.
        """
        if train_options is None:
            # Build the defaults per call. A default evaluated in the
            # signature would be a single mutable object shared by every
            # instance created without explicit options.
            train_options = adam_learning_options()

        self.out_weights_dir = out_weights_dir
        self.train_options = train_options
        self.latent_dim = latent_dim
        # The batch size comes from the dataset, not from a constructor
        # argument.
        self.batch_size = dataset.batch_size

        # NOTE(review): self.logger is not set in this method; presumably a
        # class-level logger defined elsewhere.
        self.logger.info("Creating output weights directory %s" %
                         self.out_weights_dir)
        tl.files.exists_or_mkdir(out_weights_dir)
        self.img_size = dataset.img_size()
        # Square 3-channel images: [batch, img_size, img_size, 3].
        self.input_images = tf.placeholder(
            tf.float32,
            shape=[self.batch_size, self.img_size, self.img_size, 3],
            name='input_images')

        self.kl_loss_weight = kl_loss_weight
        self.weights_dump_interval = config.WEIGHTS_DUMP_INTERVAL
        self.dataset = dataset
コード例 #6
0
    def __init__(self, latent_dim, transformer, out_weights_dir, batch_size=64,
                 train_options=None):
        """Store GAN training settings and create the input placeholders.

        Args:
            latent_dim: Width (second dimension) of the ``z`` placeholder.
            transformer: Data source. Only its ``img_size()`` is used here.
            out_weights_dir: Directory for weight dumps, ensured via
                ``tl.files.exists_or_mkdir``.
            batch_size: Fixed batch dimension of ``input_images``.
            train_options: Training options. When omitted, a fresh
                ``adam_learning_options()`` result is created for this
                instance.
        """
        if train_options is None:
            # Build the defaults per call. A default evaluated in the
            # signature would be a single mutable object shared by every
            # instance created without explicit options.
            train_options = adam_learning_options()

        self.out_weights_dir = out_weights_dir
        self.transformer = transformer
        self.latent_dim = latent_dim
        self.batch_size = batch_size

        # NOTE(review): self.logger is not set in this method; presumably a
        # class-level logger defined elsewhere.
        self.logger.info("Creating output weights directory %s" % self.out_weights_dir)
        tl.files.exists_or_mkdir(out_weights_dir)

        self.img_size = transformer.img_size()
        self.train_options = train_options
        self.weights_dump_interval = config.WEIGHTS_DUMP_INTERVAL

        # Square 3-channel images: [batch, img_size, img_size, 3].
        self.input_images = tf.placeholder(
            tf.float32,
            shape=[self.batch_size, self.img_size, self.img_size, 3],
            name='input_images')
        # The batch dimension of z is left open (None), so any number of
        # rows can be fed.
        self.z = tf.placeholder(tf.float32, shape=[None, self.latent_dim],
                                name='z')
コード例 #7
0
def train_cifar_gan(args):
    """Train a ``VaeTrainer`` on an ``UnsupervisedCifarDataSet``.

    Despite the function name, this builds a VAE trainer, not a GAN.

    Attributes read from ``args``: ``cifar_ds_path``, ``learning_rate``,
    ``latent_dim``, ``out_weights_dir`` and ``epochs_num``.
    """
    batch_size = 64
    unsup_ds = UnsupervisedCifarDataSet(args.cifar_ds_path, batch_size)

    # Default Adam options with the learning rate taken from the CLI args.
    options = adam_learning_options()
    options['lr'] = args.learning_rate

    trainer = VaeTrainer(args.latent_dim, unsup_ds, args.out_weights_dir,
                         options)
    trainer.train(args.epochs_num)