コード例 #1
0
ファイル: train_test.py プロジェクト: sts-sadr/gan-2
    def test_full_flow(self, mock_data_provider):
        hparams = train_lib.HParams(batch_size=16,
                                    max_number_of_steps=2,
                                    noise_dims=3,
                                    output_dir=self.get_temp_dir())

        # Construct mock inputs.
        mock_imgs = np.zeros([hparams.batch_size, 28, 28, 1], dtype=np.float32)
        mock_lbls = np.concatenate(
            (np.ones([hparams.batch_size, 1], dtype=np.int32),
             np.zeros([hparams.batch_size, 9], dtype=np.int32)),
            axis=1)
        mock_data_provider.provide_data.return_value = (mock_imgs, mock_lbls)

        train_lib.train(hparams)
コード例 #2
0
ファイル: train.py プロジェクト: zhouyonglong/gan
def main(_):
    hparams = train_lib.HParams(FLAGS.batch_size, FLAGS.max_number_of_steps,
                                FLAGS.noise_dims, FLAGS.output_dir)
    train_lib.train(hparams)