Example #1
0
    def test_master_prepare(self):
        """Check the worker command/args the master gives its pod manager.

        Under the default arguments the test expects the worker command to
        be ``["/bin/bash"]``. After disabling the ElasticDL job service and
        supplying a ``job_command``, it expects the worker args to be
        ``["-c", <job_command>]``.
        """

        def build_master():
            # Parse the current ``self.arguments`` into a new Master and let
            # it populate the pod manager's command settings.
            fresh = Master(parse_master_args(self._get_args()))
            fresh._set_command_in_pod_manager()
            return fresh

        self.arguments[
            "distribution_strategy"] = DistributionStrategy.PARAMETER_SERVER
        with tempfile.TemporaryDirectory() as data_dir:
            create_recordio_file(
                self._num_records,
                DatasetName.TEST_MODULE,
                1,
                temp_dir=data_dir,
            )
            self.arguments["training_data"] = data_dir

            # Default configuration.
            master = build_master()
            self.assertListEqual(
                master.pod_manager._worker_command, ["/bin/bash"]
            )

            # Job service disabled, explicit command supplied.
            user_command = "python --version"
            self.arguments["need_elasticdl_job_service"] = "False"
            self.arguments["job_command"] = user_command
            master = build_master()
            self.assertListEqual(
                master.pod_manager._worker_args, ["-c", user_command]
            )
Example #2
0
    def test_master_validate(self):
        """Exercise ``Master.validate`` under three configurations.

        1. ``task_fault_tolerance`` off, default job-service setting:
           ``validate`` is expected to raise.
        2. Additionally ``need_elasticdl_job_service`` off: ``validate``
           is expected to pass.
        3. ``task_fault_tolerance`` back on and ``pod_manager`` cleared:
           ``validate`` is expected to raise.
        """

        def fresh_master():
            # Build a Master from whatever ``self.arguments`` holds now.
            return Master(parse_master_args(self._get_args()))

        with tempfile.TemporaryDirectory() as data_dir:
            create_recordio_file(
                self._num_records,
                DatasetName.TEST_MODULE,
                1,
                temp_dir=data_dir,
            )
            self.arguments["training_data"] = data_dir

            # Case 1: no fault tolerance -> invalid.
            self.arguments["task_fault_tolerance"] = "False"
            candidate = fresh_master()
            with self.assertRaises(Exception):
                candidate.validate()

            # Case 2: no fault tolerance and no job service -> valid.
            self.arguments["need_elasticdl_job_service"] = "False"
            candidate = fresh_master()
            candidate.validate()

            # Case 3: fault tolerance requested but no pod manager -> invalid.
            self.arguments["task_fault_tolerance"] = "True"
            candidate = fresh_master()
            candidate.pod_manager = None
            with self.assertRaises(Exception):
                candidate.validate()
    def test_recordio_data_reader(self):
        """Check shard discovery and record reading of ``RecordIODataReader``.

        One RecordIO file with ``num_records`` records is written into a
        temporary directory. The reader must report exactly one shard,
        ``{shard_name: (0, num_records)}``, and must yield ``num_records``
        records. Each record must parse into exactly the ``x`` and ``y``
        features, each holding a single value.
        """
        num_records = 128
        with tempfile.TemporaryDirectory() as temp_dir_name:
            shard_name = create_recordio_file(num_records,
                                              DatasetName.TEST_MODULE,
                                              1,
                                              temp_dir=temp_dir_name)

            # Test shards creation
            expected_shards = {shard_name: (0, num_records)}
            reader = RecordIODataReader(data_dir=temp_dir_name)
            self.assertEqual(expected_shards, reader.create_shards())

            # Test records reading
            records = list(
                reader.read_records(
                    _MockedTask(0, num_records, shard_name,
                                elasticdl_pb2.TRAINING)))
            self.assertEqual(len(records), num_records)

            # The parsing spec is loop-invariant, so build it once.
            feature_spec = {
                "x": tf.io.FixedLenFeature([1], tf.float32),
                "y": tf.io.FixedLenFeature([1], tf.float32),
            }
            for record in records:
                parsed_record = tf.io.parse_single_example(
                    record, feature_spec)
                # Exactly the requested features must come back...
                self.assertCountEqual(parsed_record.keys(),
                                      feature_spec.keys())
                # ...and each must hold a single value.
                for value in parsed_record.values():
                    self.assertEqual(len(value.numpy()), 1)
 def test_create_master_for_allreduce(self):
     """Build an ``ElasticdlJobService`` for an AllReduce custom-loop job.

     Only asserts that construction succeeds and yields a non-None object.
     """
     self.arguments[
         "distribution_strategy"] = DistributionStrategy.ALLREDUCE
     with tempfile.TemporaryDirectory() as data_dir:
         create_recordio_file(
             self._num_records,
             DatasetName.TEST_MODULE,
             1,
             temp_dir=data_dir,
         )
         self.arguments["training_data"] = data_dir
         self.arguments["custom_training_loop"] = "true"
         parsed_args = parse_master_args(self._get_args())
         job_service = ElasticdlJobService(
             parsed_args, TaskManager(parsed_args)
         )
         self.assertIsNotNone(job_service)
 def test_create_master_without_eval(self):
     """An AllReduce custom-loop job service should have no evaluation service.

     Uses the ``mnist.mnist_train_tfv2.train`` model definition and expects
     the resulting ``evaluation_service`` attribute to be ``None``.
     """
     self.arguments[
         "distribution_strategy"] = DistributionStrategy.ALLREDUCE
     self.arguments["custom_training_loop"] = "true"
     self.arguments["model_def"] = "mnist.mnist_train_tfv2.train"
     with tempfile.TemporaryDirectory() as data_dir:
         create_recordio_file(
             self._num_records,
             DatasetName.TEST_MODULE,
             1,
             temp_dir=data_dir,
         )
         self.arguments["training_data"] = data_dir
         parsed_args = parse_master_args(self._get_args())
         job_service = ElasticdlJobService(
             parsed_args, TaskManager(parsed_args)
         )
         self.assertIsNone(job_service.evaluation_service)
Example #6
0
 def test_master_run_and_stop(self):
     """Check ``Master.run`` exit codes against a mocked pod manager.

     The master's to-do task list is emptied and its pod manager is replaced
     by a ``Mock`` reporting that all workers exited. ``run`` is expected to
     return 0 when the workers did not all fail, and 1 once
     ``all_workers_failed`` is set.
     """
     self.arguments[
         "distribution_strategy"] = DistributionStrategy.PARAMETER_SERVER
     with tempfile.TemporaryDirectory() as data_dir:
         create_recordio_file(
             self._num_records,
             DatasetName.TEST_MODULE,
             1,
             temp_dir=data_dir,
         )
         self.arguments["training_data"] = data_dir
         master = Master(parse_master_args(self._get_args()))

         # Nothing left to dispatch; workers are reported as finished.
         master.task_manager._todo.clear()
         master.pod_manager = Mock()
         fake_pods = master.pod_manager
         fake_pods.all_workers_exited = True
         fake_pods.all_workers_failed = False

         success_code = master.run()
         master.stop()
         self.assertEqual(success_code, 0)

         # Same setup, but every worker is now reported as failed.
         # NOTE(review): unlike the first run, this run() is not followed by
         # stop(); confirm run() releases its resources on this path.
         master.pod_manager.all_workers_failed = True
         failure_code = master.run()
         self.assertEqual(failure_code, 1)
Example #7
0
    def testMaxCheckpointVersions(self):
        """Check that ``CheckpointService`` retains only the newest checkpoints.

        A worker trains through an in-process ``MasterServicer`` on 128
        records split into tasks of 64 for one epoch, with minibatch size 2.
        The checkpointer saves every 2 steps and keeps at most 5 files.
        Afterwards exactly versions 24, 26, 28, 30 and 32 must remain on
        disk, each loadable with a matching ``version``. Requesting a missing
        version must raise ``RuntimeError``.
        """
        with tempfile.TemporaryDirectory() as tempdir:
            chkp_dir = os.path.join(tempdir, "testMaxCheckpointVersions")
            os.makedirs(chkp_dir)
            # Save checkpoints every 2 steps, and keep 5 checkpoints at most
            checkpointer = CheckpointService(chkp_dir, 2, 5, False)
            self.assertTrue(checkpointer.is_enabled())

            batch_size = 2
            # Launch the training
            arguments = [
                "--worker_id",
                1,
                "--job_type",
                JobType.TRAINING_ONLY,
                "--minibatch_size",
                batch_size,
                "--model_zoo",
                _model_zoo_path,
                "--model_def",
                "test_module.custom_model",
            ]
            args = parse_worker_args(arguments)
            worker = Worker(args)

            # Write the training data inside ``tempdir`` so it is removed
            # together with the directory. The other call sites pass
            # ``temp_dir`` too. Without it the file presumably lands outside
            # ``tempdir`` and is left behind after the test.
            filename = create_recordio_file(
                128, DatasetName.TEST_MODULE, 1, temp_dir=tempdir
            )
            task_d = _TaskDispatcher({filename: (0, 128)}, {}, {},
                                     records_per_task=64,
                                     num_epochs=1)
            master = MasterServicer(
                2,
                batch_size,
                worker._opt_fn(),
                task_d,
                init_var=worker._model.trainable_variables,
                checkpoint_filename_for_init="",
                checkpoint_service=checkpointer,
                evaluation_service=None,
            )

            worker._stub = InProcessMaster(master)
            worker.run()

            # We should have 5 checkpoints when the training finishes
            checkpoint_files = sorted(os.listdir(checkpointer._directory))
            self.assertEqual(
                checkpoint_files,
                [
                    "model_v24.chkpt",
                    "model_v26.chkpt",
                    "model_v28.chkpt",
                    "model_v30.chkpt",
                    "model_v32.chkpt",
                ],
            )
            # Latest version should be 32
            self.assertEqual(32, checkpointer.get_latest_checkpoint_version())
            # Check all checkpoints
            for version in [24, 26, 28, 30, 32]:
                model = checkpointer.get_checkpoint_model(version)
                self.assertEqual(version, model.version)
            # Checkpoint not found
            with self.assertRaisesRegex(
                RuntimeError, "Failed to read model checkpoint from file"
            ):
                checkpointer.get_checkpoint_model(100)