Code example #1
File: search_generator.py — Project: smorad/ray
    def create_trial_if_possible(self, experiment_spec: Dict,
                                 output_path: str) -> Optional[Trial]:
        """Ask the searcher for a config and build a ``Trial`` from it.

        Returns ``None`` either when the searcher has finished (this also
        sets ``self._finished``) or when no suggestion is available yet.
        """
        logger.debug("creating trial")
        trial_id = Trial.generate_id()
        suggested_config = self.searcher.suggest(trial_id)

        # The searcher signals permanent completion with a sentinel value.
        if suggested_config == Searcher.FINISHED:
            self._finished = True
            logger.debug("Searcher has finished.")
            return

        # No suggestion right now; the caller may try again later.
        if suggested_config is None:
            return

        # Work on copies so neither the experiment spec nor the searcher's
        # suggestion is mutated by the merge.
        spec = copy.deepcopy(experiment_spec)
        spec["config"] = merge_dicts(spec["config"],
                                     copy.deepcopy(suggested_config))

        # Create a new trial_id if duplicate trial is created
        flattened_config = resolve_nested_dict(spec["config"])
        self._counter += 1
        tag = f"{self._counter}_{format_vars(flattened_config)}"
        return create_trial_from_spec(
            spec,
            output_path,
            self._parser,
            evaluated_params=flatten_dict(suggested_config),
            experiment_tag=tag,
            trial_id=trial_id,
        )
Code example #2
    def _read_from_disk(self, policy_id):
        """Load a policy's pickled state from disk and re-instantiate it.

        The policy object is rebuilt from its stored spec, then its saved
        state is pushed into the fresh instance.
        """
        # The cache must not already hold this ID.
        assert policy_id not in self.cache

        # NOTE(review): pickle.load on a disk file — only safe as long as
        # the target directory's contents are trusted.
        file_path = self.path + "/" + policy_id + self.extension
        with open(file_path, "rb") as f:
            saved_state = pickle.load(f)

        spec = self.policy_specs[policy_id]
        # Merge the per-policy config override into the shared config.
        merged_conf = merge_dicts(self.policy_config, spec.config)

        # Re-create the policy object from its spec (cls, obs-space,
        # act-space, config).
        self.create_policy(
            policy_id,
            spec.policy_class,
            spec.observation_space,
            spec.action_space,
            spec.config,
            merged_conf,
        )
        # Restore the saved state into the newly created policy.
        self[policy_id].set_state(saved_state)
Code example #3
    def _train(self):
        """Run one training pass and one validation pass, merging the stats.

        You may want to override this if using a custom LR scheduler.
        """
        results = self.trainer.train(max_retries=10, profile=True)
        # Fold the validation stats into the training stats for one report.
        results = merge_dicts(results, self.trainer.validate(profile=True))
        return results
Code example #4
File: torch_trainer.py — Project: hngenc/ray
    def step(self):
        """Calls `self.trainer.train()` and `self.trainer.validate()` once.

        Raises:
            DeprecationWarning: if the subclass still implements the
                removed ``_train`` hook.
        """
        if self._implements_method("_train"):
            # Fail loudly: the legacy hook no longer runs, so silently
            # ignoring an override would change user behavior.
            # Fixed: the two adjacent literals previously concatenated to
            # "...now removed.Override ..." (missing space).
            raise DeprecationWarning(
                "Trainable._train is deprecated and is now removed. "
                "Override Trainable.step instead.")

        train_stats = self.trainer.train(max_retries=0, profile=True)
        validation_stats = self.trainer.validate(profile=True)
        # Combine the two result dicts into a single report.
        stats = merge_dicts(train_stats, validation_stats)
        return stats
Code example #5
File: torch_trainer.py — Project: yynst2/ray
    def step(self):
        """Run one train + validate cycle.

        Calls `self.trainer.train()` and `self.trainer.validate()` once.
        """
        # The legacy ``_train`` hook is being phased out; raise if a
        # subclass still overrides it.
        if self._is_overridden("_train"):
            raise DeprecationWarning(
                "Trainable._train is deprecated and will be removed in "
                "a future version of Ray. Override Trainable.step instead.")

        stats = self.trainer.train(max_retries=10, profile=True)
        # Fold validation results into the training stats.
        stats = merge_dicts(stats, self.trainer.validate(profile=True))
        return stats
Code example #6
File: trainer.py — Project: SymbioticLab/Fluid
    def _train(self):
        """Run one train + validate cycle and attach the host plan.

        You may want to override this if using a custom LR scheduler.
        """
        epoch_stats = self.trainer.train(max_retries=10, profile=True)
        stats = merge_dicts(epoch_stats, self.trainer.validate(profile=True))

        # Expose the trial's resource placement in the result dict so it
        # is visible to whoever consumes these stats.
        stats["hostplan"] = self.trainer.config.get(
            "all_fluid_trial_resources", {})
        return stats
Code example #7
 def _train(self):
     """Run one training pass and one validation pass, returning merged stats."""
     epoch_stats = self.trainer.train(max_retries=0, profile=True)
     val_stats = self.trainer.validate(profile=True)
     return merge_dicts(epoch_stats, val_stats)