def recordio_train(self, module_name, model_name, file_patterns,
                   **module_kwargs):
    """Trains the model with a RecordIO dataset for a few steps.

    Args:
      module_name: Object (despite the name, not a string) on which the
        attribute `model_name` is looked up with `getattr`; the result is
        called to build the model.
      model_name: Name of the attribute of `module_name` used to build the
        model.
      file_patterns: Forwarded as the first positional argument of
        `DefaultRecordInputGenerator`.
      **module_kwargs: Forwarded both to the model constructor and to
        `self._get_params`.
    """
    model_factory = getattr(module_name, model_name)
    tf_model = model_factory(**module_kwargs)

    temp_model_dir = self._test_case.create_tempdir().full_path
    params = self._get_params(model_dir=temp_model_dir, **module_kwargs)

    train_input = default_input_generator.DefaultRecordInputGenerator(
        file_patterns, batch_size=params['batch_size'])

    # Keep a reference to the real function so the autospec'd mock forwards
    # to it: initialization still happens, but calls are recorded.
    real_initialize_system = tpu.initialize_system
    with mock.patch.object(
            tpu, 'initialize_system', autospec=True) as patched_init:
        patched_init.side_effect = real_initialize_system

        train_eval.train_eval_model(
            t2r_model=tf_model,
            input_generator_train=train_input,
            max_train_steps=params['max_train_steps'],
            model_dir=params['model_dir'],
            use_tpu_wrapper=params['use_tpu_wrapper'])

        # Only a TPU run is expected to go through system initialization.
        if self._use_tpu:
            patched_init.assert_called()

        train_eval_test_utils.assert_output_files(
            test_case=self._test_case,
            model_dir=params['model_dir'],
            expected_output_filename_patterns=(
                train_eval_test_utils.DEFAULT_TRAIN_FILENAME_PATTERNS))
def test_record_input_generator(self):
    """Runs the shared input-generator checks on the pose_env TFRecord."""
    record_path = os.path.join(
        FLAGS.test_srcdir,
        'tensor2robot',
        'test_data/pose_env_test_data.tfrecord',
    )
    generator = default_input_generator.DefaultRecordInputGenerator(
        file_patterns=record_path, batch_size=BATCH_SIZE)
    self._test_input_generator(generator)
def test_multi_record_input_generator(self):
    """Runs the multi-dataset checks with two datasets over one TFRecord.

    Both entries ('d1' and 'd2') of the dataset map point at the same
    pose_env test file.
    """
    record_path = os.path.join(
        FLAGS.test_srcdir,
        'tensor2robot',
        'test_data/pose_env_test_data.tfrecord',
    )
    generator = default_input_generator.DefaultRecordInputGenerator(
        dataset_map={name: record_path for name in ('d1', 'd2')},
        batch_size=BATCH_SIZE)
    self._test_multi_record_input_generator(generator)
def setUp(self):
    """Resets the log dir, binds short gin schedules, builds generators."""
    super(PoseEnvModelsTest, self).setUp()
    record_path = os.path.join(
        FLAGS.test_srcdir,
        'tensor2robot',
        'test_data/pose_env_test_data.tfrecord',
    )

    log_dir = FLAGS.test_tmpdir
    self._train_log_dir = log_dir
    # Wipe the directory so each test starts from an empty log dir
    # (presumably to avoid picking up artifacts from earlier runs).
    if tf.io.gfile.exists(log_dir):
        tf.io.gfile.rmtree(log_dir)

    # Keep training/eval very short for tests.
    for binding, value in (('train_eval_model.max_train_steps', 3),
                           ('train_eval_model.eval_steps', 2)):
        gin.bind_parameter(binding, value)

    self._record_input_generator = (
        default_input_generator.DefaultRecordInputGenerator(
            batch_size=BATCH_SIZE, file_patterns=record_path))
    # NOTE(review): despite the "meta_record" names, both of these are
    # random (not record-backed) input generators.
    random_generator_cls = default_input_generator.DefaultRandomInputGenerator
    self._meta_record_input_generator_train = random_generator_cls(
        batch_size=BATCH_SIZE)
    self._meta_record_input_generator_eval = random_generator_cls(
        batch_size=BATCH_SIZE)