def setUp(self):
    """
    Build the fixtures used by the driver tests: cluster settings, fake training/schema params,
    a temporary directory tree, a mocked model and one driver of each kind
    :return: None
    """
    # Cluster settings the tests compare the execution context against
    self.task_type = "worker"
    self.worker_index = 0
    self.num_workers = 5
    set_fake_tf_config(task_type=self.task_type, worker_index=self.worker_index)

    self.base_training_params = setup_fake_base_training_params()
    self.schema_params = setup_fake_schema_params()

    # Temporary layout: base_dir/{output_dir, input_dir}
    self.base_dir = base_dir = tempfile.mkdtemp()
    self.output_dir = output_dir = tempfile.mkdtemp(dir=base_dir)
    self.input_dir = input_dir = tempfile.mkdtemp(dir=base_dir)
    # A nested dummy folder guarantees that input_dir is not empty
    tempfile.mkdtemp(dir=input_dir)

    # Mocked model: training and validation data share input_dir, checkpoints go to output_dir
    model = Mock()
    model.metadata_file = os.path.join(
        os.getcwd(), "test/resources/metadata/tensor_metadata.json")
    model.validation_data_dir = input_dir
    model.checkpoint_path = output_dir
    model.training_data_dir = input_dir
    self.mock_model = model

    # Both drivers are constructed from the very same arguments
    driver_kwargs = {
        "base_training_params": self.base_training_params,
        "model": model,
    }
    self.fixed_effect_driver = FixedEffectDriver(**driver_kwargs)
    self.random_effect_driver = RandomEffectDriver(**driver_kwargs)
def test_drivers_without_tfconfig(self):
    """
    Test fixed and random effect driver constructors without TF_CONFIG, expect to set local mode
    :return: None
    """
    # Remove TF_CONFIG if it is set; pop() with a default is a no-op when the key is absent
    os.environ.pop(constants.TF_CONFIG, None)

    expected_fe_execution_context = {
        constants.TASK_TYPE: 'worker',
        constants.TASK_INDEX: 0,
        constants.CLUSTER_SPEC: {
            "worker": ["localhost:2222"]
        },
        constants.NUM_WORKERS: 1,
        constants.NUM_SHARDS: 1,
        constants.SHARD_INDEX: 0,
        constants.IS_CHIEF: True
    }
    fixed_effect_driver = FixedEffectDriver(
        base_training_params=self.base_training_params,
        model=self.mock_model)
    self.assertEqual(fixed_effect_driver.execution_context,
                     expected_fe_execution_context)

    random_effect_driver = RandomEffectDriver(
        base_training_params=self.base_training_params,
        model=self.mock_model)
    # The random effect expectation differs only in the cluster spec. Work on a copy so that
    # the fixed effect expectation above is not mutated through a shared reference.
    expected_re_execution_context = dict(expected_fe_execution_context)
    expected_re_execution_context[constants.CLUSTER_SPEC] = None
    self.assertEqual(random_effect_driver.execution_context,
                     expected_re_execution_context)
def get_driver(base_training_params, raw_model_params):
    """
    Create driver and associated dependencies, based on type. Only linear, estimator-based models supported for now
    :param base_training_params: Parsed base training parameters common to all models. This could include the path
        to training data, validation data, metadata file path, learning rate etc.
    :param raw_model_params: Raw model parameters, representing model-specific requirements. For example, a CNN
        might expose filter_size as a parameter, a text-based model might expose the size of its word embedding
        matrix as a parameter
    :return: Fixed or Random effect driver
    :raises ValueError: If the training stage is neither the fixed effect nor the random effect stage
    """
    driver_type = base_training_params[constants.STAGE]
    # The model is built before the stage is validated, so any error raised by the factory surfaces first
    model = ModelFactory.get_model(base_training_params, raw_model_params)

    if driver_type == constants.FIXED_EFFECT:
        logger.info("Instantiating fixed effect model and driver")
        return FixedEffectDriver(base_training_params=base_training_params,
                                 model=model)
    if driver_type == constants.RANDOM_EFFECT:
        logger.info("Instantiating random effect model and driver")
        return RandomEffectDriver(base_training_params=base_training_params,
                                  model=model)

    # ValueError is a subclass of Exception, so callers catching Exception keep working;
    # the offending value is included to make misconfiguration easy to diagnose
    raise ValueError("Unknown training stage: {!r}".format(driver_type))
class TestDriver(tf.test.TestCase):
    """
    Test Fixed and Random effect drivers
    """

    def setUp(self):
        """
        Create the fake params, temporary directories, mocked model and drivers shared by all tests
        :return: None
        """
        self.task_type = "worker"
        self.worker_index = 0
        # NOTE(review): presumably matches the cluster size configured by set_fake_tf_config - confirm
        self.num_workers = 5
        set_fake_tf_config(task_type=self.task_type,
                           worker_index=self.worker_index)
        self.base_training_params = setup_fake_base_training_params()
        self.schema_params = setup_fake_schema_params()

        self.base_dir = tempfile.mkdtemp()
        self.output_dir = tempfile.mkdtemp(dir=self.base_dir)
        self.input_dir = tempfile.mkdtemp(dir=self.base_dir)
        # Create a dummy folder inside input_dir
        # This is to make sure the input_dir is not empty
        tempfile.mkdtemp(dir=self.input_dir)

        self.mock_model = Mock()
        self.mock_model.metadata_file = os.path.join(
            os.getcwd(), "test/resources/metadata/tensor_metadata.json")
        self.mock_model.validation_data_dir = self.input_dir
        self.mock_model.checkpoint_path = self.output_dir
        self.mock_model.training_data_dir = self.input_dir

        self.fixed_effect_driver = FixedEffectDriver(
            base_training_params=self.base_training_params,
            model=self.mock_model)
        self.random_effect_driver = RandomEffectDriver(
            base_training_params=self.base_training_params,
            model=self.mock_model)

    def tearDown(self):
        """
        Remove the temporary directory tree created in setUp
        :return: None
        """
        # Clean up the checkpoint dir created by the driver
        tf.io.gfile.rmtree(self.base_dir)

    def _partitions_for_worker(self):
        """
        Read the partition list file and pick the partitions the random effect worker should work on
        :return: List of partition indices: every NUM_WORKERS-th entry of the list, starting at TASK_INDEX
        """
        with tf.io.gfile.GFile(
                self.base_training_params.partition_list_file) as f:
            line = f.readline()
        all_partitions = [int(index) for index in line.split(',')]
        execution_context = self.random_effect_driver.execution_context
        first = execution_context[constants.TASK_INDEX]
        stride = execution_context[constants.NUM_WORKERS]
        return all_partitions[first::stride]

    def test_drivers_without_tfconfig(self):
        """
        Test fixed and random effect driver constructors without TF_CONFIG, expect to set local mode
        :return: None
        """
        # Remove TF_CONFIG if it is set; pop() with a default is a no-op when the key is absent
        os.environ.pop(constants.TF_CONFIG, None)

        expected_fe_execution_context = {
            constants.TASK_TYPE: 'worker',
            constants.TASK_INDEX: 0,
            constants.CLUSTER_SPEC: {
                "worker": ["localhost:2222"]
            },
            constants.NUM_WORKERS: 1,
            constants.NUM_SHARDS: 1,
            constants.SHARD_INDEX: 0,
            constants.IS_CHIEF: True
        }
        fixed_effect_driver = FixedEffectDriver(
            base_training_params=self.base_training_params,
            model=self.mock_model)
        self.assertEqual(fixed_effect_driver.execution_context,
                         expected_fe_execution_context)

        random_effect_driver = RandomEffectDriver(
            base_training_params=self.base_training_params,
            model=self.mock_model)
        # The random effect expectation differs only in the cluster spec. Work on a copy so that
        # the fixed effect expectation above is not mutated through a shared reference.
        expected_re_execution_context = dict(expected_fe_execution_context)
        expected_re_execution_context[constants.CLUSTER_SPEC] = None
        self.assertEqual(random_effect_driver.execution_context,
                         expected_re_execution_context)

    def test_fixed_effect_cluster_spec(self):
        """
        Test the cluster specification for fixed effect training
        :return: None
        """
        fe_execution_context = self.fixed_effect_driver.execution_context
        # Assert cluster specification
        self.assertEqual(fe_execution_context[constants.TASK_INDEX],
                         self.worker_index)
        self.assertEqual(fe_execution_context[constants.TASK_TYPE],
                         self.task_type)
        self.assertEqual(fe_execution_context[constants.NUM_WORKERS],
                         self.num_workers)
        self.assertEqual(fe_execution_context[constants.NUM_SHARDS],
                         self.num_workers)
        self.assertEqual(fe_execution_context[constants.SHARD_INDEX],
                         self.worker_index)

    def test_random_effect_cluster_spec(self):
        """
        Test the cluster specification for random effect training
        :return: None
        """
        re_execution_context = self.random_effect_driver.execution_context
        # Assert cluster specification
        self.assertEqual(re_execution_context[constants.TASK_INDEX],
                         self.worker_index)
        self.assertEqual(re_execution_context[constants.TASK_TYPE],
                         self.task_type)
        self.assertEqual(re_execution_context[constants.NUM_WORKERS],
                         self.num_workers)
        self.assertEqual(re_execution_context[constants.NUM_SHARDS], 1)
        self.assertEqual(re_execution_context[constants.SHARD_INDEX], 0)

    def test_fixed_effect_training(self):
        """
        Test the fixed effect driver during training
        :return: None
        """
        # Run training
        self.fixed_effect_driver.run_training(schema_params=self.schema_params,
                                              export_model=False,
                                              output_model_dir=None)
        # Assert model is trained only once with the right parameters
        self.mock_model.train.assert_called_once_with(
            training_data_dir=self.mock_model.training_data_dir,
            validation_data_dir=self.mock_model.validation_data_dir,
            metadata_file=self.mock_model.metadata_file,
            checkpoint_path=self.mock_model.checkpoint_path,
            execution_context=self.fixed_effect_driver.execution_context,
            schema_params=self.schema_params)

    def test_random_effect_training(self):
        """
        Test the random effect driver during training
        :return: None
        """
        # Gather the expected calls to the train() method of the mock model, one per partition
        train_calls = []
        for partition_index in self._partitions_for_worker():
            checkpoint_path = self.random_effect_driver._anchor_directory(
                self.mock_model.checkpoint_path, partition_index)
            training_data_dir = self.random_effect_driver._anchor_directory(
                self.mock_model.training_data_dir, partition_index)
            validation_data_dir = self.random_effect_driver._anchor_directory(
                self.mock_model.validation_data_dir, partition_index)
            # Create training_data_dir
            os.mkdir(training_data_dir)
            # Make sure training_data_dir is not empty
            tempfile.mkdtemp(dir=training_data_dir)
            train_calls.append(
                mock.call(
                    training_data_dir=training_data_dir,
                    validation_data_dir=validation_data_dir,
                    metadata_file=self.mock_model.metadata_file,
                    checkpoint_path=checkpoint_path,
                    execution_context=self.random_effect_driver.execution_context,
                    schema_params=self.schema_params))

        # Run training
        self.random_effect_driver.run_training(
            schema_params=self.schema_params,
            export_model=False,
            output_model_dir=None)

        # Assert model was called with the right calls
        self.mock_model.train.assert_has_calls(train_calls)

    def test_fixed_effect_inference(self):
        """
        Test the fixed effect driver during inference
        :return: None
        """
        self.base_training_params.action = constants.ACTION_INFERENCE
        # Run inference
        self.fixed_effect_driver.run_inference(
            schema_params=self.schema_params)

        # Expect one predict() call on the training data followed by one on the validation data
        inference_calls = [
            mock.call(
                output_dir=os.path.join(
                    self.base_training_params.training_score_dir),
                input_data_path=self.mock_model.training_data_dir,
                metadata_file=self.mock_model.metadata_file,
                checkpoint_path=self.mock_model.checkpoint_path,
                execution_context=self.fixed_effect_driver.execution_context,
                schema_params=self.schema_params),
            mock.call(
                output_dir=os.path.join(
                    self.base_training_params.validation_score_dir),
                input_data_path=self.mock_model.validation_data_dir,
                metadata_file=self.mock_model.metadata_file,
                checkpoint_path=self.mock_model.checkpoint_path,
                execution_context=self.fixed_effect_driver.execution_context,
                schema_params=self.schema_params),
        ]

        # Assert model was called with the right calls
        self.mock_model.predict.assert_has_calls(inference_calls)

    def test_random_effect_inference(self):
        """
        Test the random effect driver during inference
        :return: None
        """
        self.base_training_params.action = constants.ACTION_INFERENCE
        # Run inference
        self.random_effect_driver.run_inference(
            schema_params=self.schema_params)

        # The expected calls use the un-anchored checkpoint path for every partition
        checkpoint_path = self.mock_model.checkpoint_path
        partition_dir_prefix = RandomEffectDriver._RANDOM_EFFECT_PARTITION_DIR_PREFIX

        # Gather the expected calls to the predict() method of the mock model:
        # per partition, one on the training data followed by one on the validation data
        infer_calls = []
        for partition_index in self._partitions_for_worker():
            training_data_dir = self.random_effect_driver._anchor_directory(
                self.mock_model.training_data_dir, partition_index)
            validation_data_dir = self.random_effect_driver._anchor_directory(
                self.mock_model.validation_data_dir, partition_index)
            partition_dir = partition_dir_prefix + str(partition_index)
            infer_calls.append(
                mock.call(
                    output_dir=os.path.join(
                        self.base_training_params.training_score_dir,
                        partition_dir),
                    input_data_path=training_data_dir,
                    metadata_file=self.mock_model.metadata_file,
                    checkpoint_path=checkpoint_path,
                    execution_context=self.random_effect_driver.execution_context,
                    schema_params=self.schema_params))
            infer_calls.append(
                mock.call(
                    output_dir=os.path.join(
                        self.base_training_params.validation_score_dir,
                        partition_dir),
                    input_data_path=validation_data_dir,
                    metadata_file=self.mock_model.metadata_file,
                    checkpoint_path=checkpoint_path,
                    execution_context=self.random_effect_driver.execution_context,
                    schema_params=self.schema_params))

        # Assert model was called with the right calls
        self.mock_model.predict.assert_has_calls(infer_calls)