def test_full_flow(self, mock_data_provider):
    """Smoke-test `train_lib.train` end to end on mocked input data.

    Training is configured for two steps with a batch of 16 and writes to a
    per-test temporary directory. The test passes if `train` returns without
    raising; no outputs are inspected.

    Args:
        mock_data_provider: Mock object whose `provide_data` return value is
            replaced with an `(images, labels)` pair of NumPy arrays.
            NOTE(review): presumably injected by a `mock.patch` decorator
            that is outside this view -- confirm.
    """
    params = train_lib.HParams(
        batch_size=16,
        max_number_of_steps=2,
        noise_dims=3,
        output_dir=self.get_temp_dir(),
    )
    batch = params.batch_size

    # All-zero float32 images of shape (batch, 28, 28, 1).
    fake_images = np.zeros([batch, 28, 28, 1], dtype=np.float32)

    # int32 label matrix with 10 columns: column 0 is all ones, the other
    # nine columns are all zeros.
    fake_labels = np.zeros([batch, 10], dtype=np.int32)
    fake_labels[:, 0] = 1

    mock_data_provider.provide_data.return_value = (fake_images, fake_labels)

    train_lib.train(params)
def main(_):
    """Build hyperparameters from command-line flags and run training.

    The `HParams` fields are passed by keyword rather than by position so
    that the mapping from flag to field does not depend on the field order
    of `train_lib.HParams`. The keyword names are the same ones
    `test_full_flow` uses when constructing `HParams`.

    Args:
        _: Unused positional argument.
            NOTE(review): presumably the leftover argv list handed over by
            the application runner -- confirm against the call site.
    """
    hparams = train_lib.HParams(
        batch_size=FLAGS.batch_size,
        max_number_of_steps=FLAGS.max_number_of_steps,
        noise_dims=FLAGS.noise_dims,
        output_dir=FLAGS.output_dir,
    )
    train_lib.train(hparams)