def test_build_model_beam(self):
    """Check the static shapes returned by build_model under beam search.

    build_model is called with beam_size=8. The returned inputs and targets
    must both have static shape (1, None, 1, 1), and outputs must have
    static shape (None, None). The fourth returned value is not checked.

    NOTE(review): hparams_set, model_name and problem_name are not defined
    in this method; presumably they are module-level values set elsewhere
    in the file — confirm.
    """
    inputs, targets, outputs, _unused = visualization.build_model(
        hparams_set,
        model_name,
        self.data_dir,
        problem_name,
        beam_size=8,
    )

    # inputs and targets are expected to share the same static shape, so
    # check both against one expected value, in the original order.
    feature_shape = (1, None, 1, 1)
    for feature in (inputs, targets):
        self.assertAllEqual(feature_shape, feature.shape.as_list())

    self.assertAllEqual((None, None), outputs.shape.as_list())