Example #1
0
 def recordio_train(self, module_name, model_name, file_patterns,
                    **module_kwargs):
     """Runs a short training loop on a RecordIO dataset and checks outputs.

     Instantiates `model_name` looked up on `module_name`, then trains it with a
     DefaultRecordInputGenerator over `file_patterns` for
     `params['max_train_steps']` steps into a fresh temporary model directory.
     Afterwards it verifies that the default training output files exist.

     Args:
       module_name: Object (typically a module) exposing `model_name` as an
         attribute; the attribute is called to build the model.
       model_name: Attribute name of the model factory on `module_name`.
       file_patterns: File pattern(s) passed to the record input generator.
       **module_kwargs: Forwarded both to the model factory and to
         `self._get_params`.
     """
     model_factory = getattr(module_name, model_name)
     tf_model = model_factory(**module_kwargs)

     temp_model_dir = self._test_case.create_tempdir().full_path
     params = self._get_params(model_dir=temp_model_dir, **module_kwargs)

     record_generator = default_input_generator.DefaultRecordInputGenerator(
         file_patterns, batch_size=params['batch_size'])

     # Wrap the real initializer: calls are recorded by the mock but still
     # forwarded to the genuine implementation via side_effect.
     real_initialize_system = tpu.initialize_system
     with mock.patch.object(
         tpu, 'initialize_system', autospec=True) as mock_init:
         mock_init.side_effect = real_initialize_system

         train_eval.train_eval_model(
             t2r_model=tf_model,
             input_generator_train=record_generator,
             max_train_steps=params['max_train_steps'],
             model_dir=params['model_dir'],
             use_tpu_wrapper=params['use_tpu_wrapper'])

         # Only a TPU run is expected to go through system initialization.
         if self._use_tpu:
             mock_init.assert_called()

         expected_patterns = (
             train_eval_test_utils.DEFAULT_TRAIN_FILENAME_PATTERNS)
         train_eval_test_utils.assert_output_files(
             test_case=self._test_case,
             model_dir=params['model_dir'],
             expected_output_filename_patterns=expected_patterns)
 def test_record_input_generator(self):
     """Exercises DefaultRecordInputGenerator on a single TFRecord pattern."""
     tfrecord_path = os.path.join(
         FLAGS.test_srcdir, 'tensor2robot',
         'test_data/pose_env_test_data.tfrecord')
     generator = default_input_generator.DefaultRecordInputGenerator(
         file_patterns=tfrecord_path, batch_size=BATCH_SIZE)
     self._test_input_generator(generator)
 def test_multi_record_input_generator(self):
     """Exercises DefaultRecordInputGenerator with a two-dataset map.

     Both map entries ('d1' and 'd2') point at the same TFRecord file.
     """
     tfrecord_path = os.path.join(
         FLAGS.test_srcdir, 'tensor2robot',
         'test_data/pose_env_test_data.tfrecord')
     generator = default_input_generator.DefaultRecordInputGenerator(
         dataset_map={'d1': tfrecord_path, 'd2': tfrecord_path},
         batch_size=BATCH_SIZE)
     self._test_multi_record_input_generator(generator)
Example #4
0
    def setUp(self):
        """Prepares a clean log dir, gin step limits and input generators.

        Uses FLAGS.test_tmpdir as the train log directory and deletes it if it
        already exists. Binds `train_eval_model` to 3 train steps and 2 eval
        steps through gin. Builds one record-based input generator over the
        pose_env test TFRecord and two random input generators for the meta
        train/eval inputs.
        """
        super(PoseEnvModelsTest, self).setUp()

        log_dir = FLAGS.test_tmpdir
        self._train_log_dir = log_dir
        # Remove any previous contents, presumably so output from an earlier
        # run cannot affect this test.
        if tf.io.gfile.exists(log_dir):
            tf.io.gfile.rmtree(log_dir)

        # Keep training and evaluation very short for test speed.
        for binding, value in (('train_eval_model.max_train_steps', 3),
                               ('train_eval_model.eval_steps', 2)):
            gin.bind_parameter(binding, value)

        record_path = os.path.join(FLAGS.test_srcdir, 'tensor2robot',
                                   'test_data/pose_env_test_data.tfrecord')
        self._record_input_generator = (
            default_input_generator.DefaultRecordInputGenerator(
                batch_size=BATCH_SIZE, file_patterns=record_path))

        # NOTE(review): despite "record" in their attribute names, the meta
        # train/eval generators are random-input generators.
        make_random_generator = (
            default_input_generator.DefaultRandomInputGenerator)
        self._meta_record_input_generator_train = make_random_generator(
            batch_size=BATCH_SIZE)
        self._meta_record_input_generator_eval = make_random_generator(
            batch_size=BATCH_SIZE)