Exemplo n.º 1
0
    def test_build_graph(self, mock_data_provider):
        """Smoke-test `train_lib.train` with a mocked input pipeline.

        Builds HParams with `max_number_of_steps=0` and a fixed batch size of
        16. It then makes `mock_data_provider.provide_data` return all-zero
        float32 images of shape (batch, 32, 32, 3) and labels of shape
        (batch, 10) that are one-hot for index 0. The test passes if
        `train_lib.train` completes without raising; nothing else is asserted.

        Args:
          mock_data_provider: mock standing in for the data-provider module.
            It is presumably injected by a patch decorator that is not
            visible here (TODO confirm).
        """
        hparams = train_lib.HParams(
            batch_size=16,
            max_number_of_steps=0,
            generator_lr=0.0002,
            discriminator_lr=0.0002,
            master='',
            train_log_dir='/tmp/tfgan_logdir/cifar/',
            ps_replicas=0,
            task=0,
        )
        n = hparams.batch_size

        # Labels: every row is the one-hot vector for index 0 out of 10.
        labels = np.zeros([n, 10], dtype=np.float32)
        labels[:, 0] = 1.0

        # Images: blank 32x32 RGB frames; only shape/dtype matter for the graph.
        images = np.zeros([n, 32, 32, 3], dtype=np.float32)

        mock_data_provider.provide_data.return_value = (images, labels)
        train_lib.train(hparams)
Exemplo n.º 2
0
def main(_):
    """Build training hyperparameters from command-line FLAGS and train.

    Args:
      _: ignored positional argument. It is presumably the leftover argv
        passed by the app runner (TODO confirm against the caller).
    """
    # Pass every field by keyword, not position. Reordering HParams fields
    # then cannot silently swap values such as the two learning rates.
    hparams = train_lib.HParams(
        batch_size=FLAGS.batch_size,
        max_number_of_steps=FLAGS.max_number_of_steps,
        generator_lr=FLAGS.generator_lr,
        discriminator_lr=FLAGS.discriminator_lr,
        master=FLAGS.master,
        train_log_dir=FLAGS.train_log_dir,
        ps_replicas=FLAGS.ps_replicas,
        task=FLAGS.task,
    )
    train_lib.train(hparams)