예제 #1
0
 def _export_model(self, export_formats, export_dir):
     """Write the policy out in each supported format that was requested.

     Args:
         export_formats: Requested format names; checked by
             ``ExportFormat.validate`` before anything is written.
         export_dir: Directory in which each format gets a sub-path named
             after the format itself.

     Returns:
         dict mapping every exported format (``ExportFormat.CHECKPOINT``
         and/or ``ExportFormat.MODEL``) to the path passed to its exporter.
         Requested formats other than these two are not written here and
         do not appear in the result.
     """
     ExportFormat.validate(export_formats)
     # (format, name of the exporter method), in the order they are written.
     handlers = (
         (ExportFormat.CHECKPOINT, "export_policy_checkpoint"),
         (ExportFormat.MODEL, "export_policy_model"),
     )
     results = {}
     for fmt, method_name in handlers:
         if fmt not in export_formats:
             continue
         target = os.path.join(export_dir, fmt)
         # Resolve the exporter only when its format was requested.
         getattr(self, method_name)(target)
         results[fmt] = target
     return results
예제 #2
0
 def _export_model(self, export_formats, export_dir):
     """Export the policy as a checkpoint and/or a model, as requested.

     Args:
         export_formats: Requested format names, validated up front with
             ``ExportFormat.validate``.
         export_dir: Base directory; each export goes to
             ``os.path.join(export_dir, <format>)``.

     Returns:
         dict from each format that was written to its output path.
     """
     ExportFormat.validate(export_formats)
     written = {}

     def _write(fmt, exporter):
         # Compute the per-format destination, export, then record it.
         destination = os.path.join(export_dir, fmt)
         exporter(destination)
         written[fmt] = destination

     if ExportFormat.CHECKPOINT in export_formats:
         _write(ExportFormat.CHECKPOINT, self.export_policy_checkpoint)
     if ExportFormat.MODEL in export_formats:
         _write(ExportFormat.MODEL, self.export_policy_model)
     return written
예제 #3
0
    def _export_model(self, export_formats, export_dir):
        """Export policies in the requested formats, including raw weights.

        Besides the formats ``ExportFormat`` validates, this accepts the
        extra ``"model_weights"`` format, which saves the weights of every
        policy listed in ``self.config["export_policy_weights_ids"]``.

        Args:
            export_formats: Requested format names. ``"model_weights"`` may
                appear (possibly more than once) alongside the formats that
                ``ExportFormat.validate`` accepts.
            export_dir: Directory under which all exports are written.

        Returns:
            dict mapping ``ExportFormat.CHECKPOINT`` / ``ExportFormat.MODEL``
            to the path handed to the corresponding exporter, and
            ``("model_weights", policy_id)`` to the ``<policy_id>.dill`` file
            written for that policy.

        Raises:
            TuneError: if ``ExportFormat.validate`` rejects the requested
                formats once every ``"model_weights"`` entry is excluded.
        """
        MODEL_WEIGHTS = "model_weights"

        try:
            ExportFormat.validate(export_formats)
        except TuneError:
            if MODEL_WEIGHTS not in export_formats:
                raise
            # "model_weights" is handled below rather than by ExportFormat,
            # so validate only the remaining formats. Every occurrence is
            # dropped (not just the first) so a repeated entry is not
            # reported as unsupported.
            ExportFormat.validate(
                [fmt for fmt in export_formats if fmt != MODEL_WEIGHTS])

        exported = {}
        if ExportFormat.CHECKPOINT in export_formats:
            path = os.path.join(export_dir, ExportFormat.CHECKPOINT)
            self.export_policy_checkpoint(path)
            exported[ExportFormat.CHECKPOINT] = path
        if ExportFormat.MODEL in export_formats:
            path = os.path.join(export_dir, ExportFormat.MODEL)
            self.export_policy_model(path)
            exported[ExportFormat.MODEL] = path

        if MODEL_WEIGHTS in export_formats:
            weights_dir = os.path.join(export_dir, MODEL_WEIGHTS)
            policy_ids = self.config["export_policy_weights_ids"]
            if policy_ids:
                # Each policy's weights are written to a file inside
                # weights_dir; create it so the write does not depend on the
                # callee creating parent directories.
                os.makedirs(weights_dir, exist_ok=True)
                # Loop-invariant: look the local worker's policies up once.
                policy_map = self.workers.local_worker().policy_map
                for policy_id in policy_ids:
                    file_path = os.path.join(weights_dir, policy_id + ".dill")
                    policy_map[policy_id].save_model_weights(
                        file_path, remove_scope_prefix=policy_id)
                    # Record the file actually written, consistent with the
                    # other formats, which record the path their exporter got.
                    exported[(MODEL_WEIGHTS, policy_id)] = file_path

        return exported
예제 #4
0
파일: trainer.py 프로젝트: zommiommy/ray
    def import_model(self, import_file):
        """Load a policy model from ``import_file``.

        Only h5 files are handled at the moment: the format is not detected
        from the file but is fixed to ``"h5"``.

        Args:
            import_file (str): Path of the file to load the model from.

        Returns:
            Whatever ``self.import_policy_model_from_h5(import_file)`` returns.

        Raises:
            FileNotFoundError: if ``import_file`` does not exist.
            NotImplementedError: if the import format is not
                ``ExportFormat.H5``.
        """
        if not os.path.exists(import_file):
            message = ("`import_file` '{}' does not exist! "
                       "Can't import Model.").format(import_file)
            raise FileNotFoundError(message)

        # TODO(sven): Support checkpoint loading; until then the format is
        # always assumed to be h5.
        import_format = "h5"
        ExportFormat.validate([import_format])

        if import_format == ExportFormat.H5:
            return self.import_policy_model_from_h5(import_file)
        raise NotImplementedError
예제 #5
0
 def _export_model(self, export_formats, export_dir):
     """Validate the requested formats; this implementation exports nothing.

     Args:
         export_formats: Requested format names, checked with
             ``ExportFormat.validate``.
         export_dir: Not used by this implementation.

     Returns:
         An empty dict, since no format is written.
     """
     ExportFormat.validate(export_formats)
     nothing_exported = {}
     return nothing_exported