def test_master_prepare(self):
    self.arguments[
        "distribution_strategy"
    ] = DistributionStrategy.PARAMETER_SERVER
    with tempfile.TemporaryDirectory() as temp_dir_name:
        create_recordio_file(
            self._num_records,
            DatasetName.TEST_MODULE,
            1,
            temp_dir=temp_dir_name,
        )
        self.arguments["training_data"] = temp_dir_name
        args = self._get_args()
        args = parse_master_args(args)
        master = Master(args)
        master._set_command_in_pod_manager()
        # With the ElasticDL job service enabled, workers run a shell.
        self.assertListEqual(
            master.pod_manager._worker_command, ["/bin/bash"]
        )

        # Without the job service, the user-supplied job command is
        # passed to the shell via "-c".
        self.arguments["need_elasticdl_job_service"] = "False"
        self.arguments["job_command"] = "python --version"
        args = self._get_args()
        args = parse_master_args(args)
        master = Master(args)
        master._set_command_in_pod_manager()
        self.assertListEqual(
            master.pod_manager._worker_args, ["-c", "python --version"]
        )
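# Illustrative sketch (hypothetical, not part of the repo): the two
# assertions above together mean the worker container effectively runs
# "/bin/bash -c 'python --version'". An equivalent local invocation:
def _sketch_worker_command():
    import subprocess

    # Same command and args the pod manager configures for the worker.
    subprocess.run(["/bin/bash", "-c", "python --version"])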
def test_master_validate(self):
    with tempfile.TemporaryDirectory() as temp_dir_name:
        create_recordio_file(
            self._num_records,
            DatasetName.TEST_MODULE,
            1,
            temp_dir=temp_dir_name,
        )
        self.arguments["training_data"] = temp_dir_name

        # Disabling fault tolerance while the ElasticDL job service is
        # still required should fail validation.
        self.arguments["task_fault_tolerance"] = "False"
        args = self._get_args()
        args = parse_master_args(args)
        master = Master(args)
        with self.assertRaises(Exception):
            master.validate()

        # Without the job service, disabling fault tolerance is valid.
        self.arguments["need_elasticdl_job_service"] = "False"
        args = self._get_args()
        args = parse_master_args(args)
        master = Master(args)
        master.validate()

        # Fault tolerance requires a pod manager.
        self.arguments["task_fault_tolerance"] = "True"
        args = self._get_args()
        args = parse_master_args(args)
        master = Master(args)
        master.pod_manager = None
        with self.assertRaises(Exception):
            master.validate()
def test_recordio_data_reader(self):
    num_records = 128
    with tempfile.TemporaryDirectory() as temp_dir_name:
        shard_name = create_recordio_file(
            num_records, DatasetName.TEST_MODULE, 1, temp_dir=temp_dir_name
        )

        # Test shard creation
        expected_shards = {shard_name: (0, num_records)}
        reader = RecordIODataReader(data_dir=temp_dir_name)
        self.assertEqual(expected_shards, reader.create_shards())

        # Test record reading
        records = list(
            reader.read_records(
                _MockedTask(
                    0, num_records, shard_name, elasticdl_pb2.TRAINING
                )
            )
        )
        self.assertEqual(len(records), num_records)
        for record in records:
            parsed_record = tf.io.parse_single_example(
                record,
                {
                    "x": tf.io.FixedLenFeature([1], tf.float32),
                    "y": tf.io.FixedLenFeature([1], tf.float32),
                },
            )
            for value in parsed_record.values():
                self.assertEqual(len(value.numpy()), 1)
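# Sketch of the record schema the test above assumes: each record is a
# serialized tf.train.Example with one float feature "x" and one float
# label "y", matching the parse spec in the assertions. make_example is
# a hypothetical helper for illustration, not part of the repo.
def make_example(x, y):
    import tensorflow as tf

    feature = {
        "x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
        "y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])),
    }
    return tf.train.Example(
        features=tf.train.Features(feature=feature)
    ).SerializeToString()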
def test_create_master_for_allreduce(self):
    self.arguments[
        "distribution_strategy"
    ] = DistributionStrategy.ALLREDUCE
    with tempfile.TemporaryDirectory() as temp_dir_name:
        create_recordio_file(
            self._num_records,
            DatasetName.TEST_MODULE,
            1,
            temp_dir=temp_dir_name,
        )
        self.arguments["training_data"] = temp_dir_name
        self.arguments["custom_training_loop"] = "true"
        args = self._get_args()
        args = parse_master_args(args)
        master = ElasticdlJobService(args, TaskManager(args))
        self.assertIsNotNone(master)
def test_create_master_without_eval(self):
    self.arguments[
        "distribution_strategy"
    ] = DistributionStrategy.ALLREDUCE
    self.arguments["custom_training_loop"] = "true"
    self.arguments["model_def"] = "mnist.mnist_train_tfv2.train"
    with tempfile.TemporaryDirectory() as temp_dir_name:
        create_recordio_file(
            self._num_records,
            DatasetName.TEST_MODULE,
            1,
            temp_dir=temp_dir_name,
        )
        self.arguments["training_data"] = temp_dir_name
        args = self._get_args()
        args = parse_master_args(args)
        master = ElasticdlJobService(args, TaskManager(args))
        # No evaluation data is configured, so no evaluation service
        # should be created.
        self.assertIsNone(master.evaluation_service)
def test_master_run_and_stop(self):
    self.arguments[
        "distribution_strategy"
    ] = DistributionStrategy.PARAMETER_SERVER
    with tempfile.TemporaryDirectory() as temp_dir_name:
        create_recordio_file(
            self._num_records,
            DatasetName.TEST_MODULE,
            1,
            temp_dir=temp_dir_name,
        )
        self.arguments["training_data"] = temp_dir_name
        args = self._get_args()
        args = parse_master_args(args)
        master = Master(args)
        master.task_manager._todo.clear()

        # Mock the pod manager: all workers exited and none failed,
        # so the job should finish with exit code 0.
        master.pod_manager = Mock()
        master.pod_manager.all_workers_exited = True
        master.pod_manager.all_workers_failed = False
        exit_code = master.run()
        master.stop()
        self.assertEqual(exit_code, 0)

        # If all workers failed, the job should exit with code 1.
        master.pod_manager.all_workers_failed = True
        exit_code = master.run()
        self.assertEqual(exit_code, 1)
def testMaxCheckpointVersions(self):
    with tempfile.TemporaryDirectory() as tempdir:
        chkp_dir = os.path.join(tempdir, "testMaxCheckpointVersions")
        os.makedirs(chkp_dir)
        # Save a checkpoint every 2 model versions, keeping at most 5
        checkpointer = CheckpointService(chkp_dir, 2, 5, False)
        self.assertTrue(checkpointer.is_enabled())

        batch_size = 2
        # Launch the training
        arguments = [
            "--worker_id",
            1,
            "--job_type",
            JobType.TRAINING_ONLY,
            "--minibatch_size",
            batch_size,
            "--model_zoo",
            _model_zoo_path,
            "--model_def",
            "test_module.custom_model",
        ]
        args = parse_worker_args(arguments)
        worker = Worker(args)

        filename = create_recordio_file(128, DatasetName.TEST_MODULE, 1)
        task_d = _TaskDispatcher(
            {filename: (0, 128)},
            {},
            {},
            records_per_task=64,
            num_epochs=1,
        )
        master = MasterServicer(
            2,
            batch_size,
            worker._opt_fn(),
            task_d,
            init_var=worker._model.trainable_variables,
            checkpoint_filename_for_init="",
            checkpoint_service=checkpointer,
            evaluation_service=None,
        )
        worker._stub = InProcessMaster(master)
        worker.run()

        # 128 records / batch_size 2 = 64 minibatches; the asserted
        # filenames imply the model version advances every other
        # minibatch, reaching 32, so checkpointing every 2 versions
        # with a cap of 5 leaves versions 24 through 32.
        checkpoint_files = sorted(os.listdir(checkpointer._directory))
        self.assertEqual(
            checkpoint_files,
            [
                "model_v24.chkpt",
                "model_v26.chkpt",
                "model_v28.chkpt",
                "model_v30.chkpt",
                "model_v32.chkpt",
            ],
        )
        # Latest version should be 32
        self.assertEqual(32, checkpointer.get_latest_checkpoint_version())

        # All kept checkpoints can be loaded back
        for version in [24, 26, 28, 30, 32]:
            model = checkpointer.get_checkpoint_model(version)
            self.assertEqual(version, model.version)

        # Requesting a version that was never saved (or was evicted)
        # raises an error
        self.assertRaisesRegex(
            RuntimeError,
            "Failed to read model checkpoint from file",
            checkpointer.get_checkpoint_model,
            100,
        )
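# A standalone sanity check of the version arithmetic the assertions
# above rely on (a sketch assuming the first MasterServicer argument is
# the number of gradient reports aggregated per model version update):
def _sketch_checkpoint_versions():
    num_records, batch_size, grads_per_update = 128, 2, 2
    steps = num_records // batch_size          # 64 gradient reports
    max_version = steps // grads_per_update    # model version reaches 32
    saved = range(2, max_version + 1, 2)       # checkpoint every 2 versions
    kept = list(saved)[-5:]                    # keep only the 5 newest
    assert kept == [24, 26, 28, 30, 32]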