Example #1
    def apply(
        self, method: Method, config: Config = None
    ) -> "ContinualRLSetting.Results":
        """Apply the given method on this setting to producing some results. """
        # Use the supplied config, or parse one from the arguments that were
        # used to create `self`.
        self.config: Config
        if config is not None:
            self.config = config
            logger.debug(f"Using Config {self.config}")
        elif isinstance(getattr(method, "config", None), Config):
            self.config = method.config
            logger.debug(f"Using Config from the Method: {self.config}")
        else:
            logger.debug(f"Parsing the Config from the command-line.")
            self.config = Config.from_args(self._argv, strict=False)
            logger.debug(f"Resulting Config: {self.config}")

        # TODO: Test to make sure that this doesn't cause any other
        # display-related bugs.
        # Call this method, which creates a virtual display if necessary.
        self.config.get_display()

        # TODO: Should we really overwrite the method's 'config' attribute here?
        if not getattr(method, "config", None):
            method.config = self.config

        # TODO: Remove `Setting.configure(method)` entirely, from everywhere,
        # and use the `prepare_data` or `setup` methods instead (since these
        # `configure` methods aren't using the `method` anyway.)
        method.configure(setting=self)

        # BUG: This won't work if the task schedule uses callables as the values
        # (since they aren't JSON-serializable).
        if self._new_random_task_on_reset:
            logger.info(
                "Train tasks: "
                + json.dumps(list(self.train_task_schedule.values()), indent="\t")
            )
        else:
            logger.info(
                "Train task schedule: "
                + json.dumps(self.train_task_schedule, indent="\t")
            )
        if self.config.debug:
            logger.debug(
                "Test task schedule: "
                + json.dumps(self.test_task_schedule, indent="\t")
            )

        # Run the Training loop (which is defined in IncrementalSetting).
        results = self.main_loop(method)

        logger.info("Results summary:")
        logger.info(results.to_log_dict())
        logger.info(results.summary())
        method.receive_results(self, results=results)
        return results
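As a rough usage sketch (not part of the original source), the config-resolution order implemented above is: an explicitly passed `config` wins, then a `Config` already attached to the method, and finally one parsed from the command-line arguments stored on the setting. `MySetting` and `MyMethod` below are hypothetical stand-ins for concrete `ContinualRLSetting` and `Method` subclasses, and the `debug` field is the one mentioned in the docstrings below.

# Hypothetical sketch; `MySetting` / `MyMethod` are placeholder subclasses.
setting = MySetting()
method = MyMethod()

# 1. An explicitly passed Config takes precedence:
results = setting.apply(method, config=Config(debug=True))

# 2. Otherwise, a Config already attached to the Method is reused:
method.config = Config()
results = setting.apply(method)

# 3. Otherwise, the Config is parsed from the argv used to create the setting.
results = setting.apply(method)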
Example #2
    def __init__(
        self,
        hparams: BaselineModel.HParams = None,
        config: Config = None,
        trainer_options: TrainerConfig = None,
        **kwargs,
    ):
        """ Creates a new BaselineMethod, using the provided configuration options.

        Parameters
        ----------
        hparams : BaselineModel.HParams, optional
            Hyper-parameters of the BaselineModel used by this Method. Defaults to None.

        config : Config, optional
            Configuration dataclass with options like log_dir, device, etc. Defaults to
            None.

        trainer_options : TrainerConfig, optional
            Dataclass which holds all the options for creating the `pl.Trainer` which
            will be used for training. Defaults to None.

        **kwargs :
            Any of the above arguments that are left as `None` will be created
            using the matching values from `kwargs`, if present.

        ## Examples:
        ```
        method = BaselineMethod(hparams=BaselineModel.HParams(learning_rate=0.01))
        method = BaselineMethod(learning_rate=0.01) # Same as above

        method = BaselineMethod(config=Config(debug=True))
        method = BaselineMethod(debug=True) # Same as above

        method = BaselineMethod(hparams=BaselineModel.HParams(learning_rate=0.01),
                                config=Config(debug=True))
        method = BaselineMethod(learning_rate=0.01, debug=True) # Same as above
        ```
        """
        # TODO: When creating a Method from a script, like `BaselineMethod()`,
        # should we expect the hparams to be passed? Should we create them from
        # the **kwargs? Should we parse them from the command-line?

        # Option 2: Try to use the keyword arguments to create the hparams,
        # config and trainer options.
        if kwargs:
            logger.info(
                f"Using keyword arguments {kwargs} to populate the corresponding "
                f"values in the hparams, config and trainer_options."
            )
            self.hparams = hparams or BaselineModel.HParams.from_dict(
                kwargs, drop_extra_fields=True)
            self.config = config or Config.from_dict(kwargs,
                                                     drop_extra_fields=True)
            self.trainer_options = trainer_options or TrainerConfig.from_dict(
                kwargs, drop_extra_fields=True)

        elif self._argv:
            # Option 3: Since the Method was created from the command-line,
            # parse the hparams, config and trainer options from the argv that
            # were used to create the Method.
            # assert not kwargs, "Don't pass any extra kwargs to the constructor!"
            self.hparams = hparams or BaselineModel.HParams.from_args(
                self._argv, strict=False)
            self.config = config or Config.from_args(self._argv, strict=False)
            self.trainer_options = trainer_options or TrainerConfig.from_args(
                self._argv, strict=False)

        else:
            # Option 1: Use the default values:
            self.hparams = hparams or BaselineModel.HParams()
            self.config = config or Config()
            self.trainer_options = trainer_options or TrainerConfig()
        assert self.hparams
        assert self.config
        assert self.trainer_options

        if self.config.debug:
            # Disable wandb logging if debug is True.
            self.trainer_options.no_wandb = True

        # The model and Trainer objects will be created in `self.configure`.
        # NOTE: This right here doesn't create the fields, it just gives some
        # type information for static type checking.
        self.trainer: Trainer
        self.model: BaselineModel

        self.additional_train_wrappers: List[Callable] = []
        self.additional_valid_wrappers: List[Callable] = []

        self.setting: Setting
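For illustration (not from the original source), the snippet below sketches the field-routing idea behind the `from_dict(kwargs, drop_extra_fields=True)` calls above: each dataclass keeps only the keys that match one of its own fields, so a single flat kwargs dict can populate the hparams, config and trainer options at once. The `pick_fields` helper is hypothetical, not part of the library.

from dataclasses import fields

def pick_fields(dataclass_type, kwargs: dict) -> dict:
    """Keep only the keys of `kwargs` that name a field of `dataclass_type`."""
    field_names = {f.name for f in fields(dataclass_type)}
    return {key: value for key, value in kwargs.items() if key in field_names}

# e.g. `BaselineMethod(learning_rate=0.01, debug=True)` roughly amounts to:
#   hparams = BaselineModel.HParams(**pick_fields(BaselineModel.HParams, kwargs))
#   config = Config(**pick_fields(Config, kwargs))
#   trainer_options = TrainerConfig(**pick_fields(TrainerConfig, kwargs))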
Example #3
    def apply(self,
              method: "Method",
              config: Config = None) -> "SettingABC.Results":
        """ Applies a Method on this experimental Setting to produce Results. 
 
        Defines the training/evaluation procedure specific to this Setting.
        
        The training/evaluation loop can be defined however you want, as long as
        it respects the following constraints:

        1.  This method should always return either a float or a Results object
            that indicates the "performance" of this method on this setting.

        2. More importantly: You **have** to make sure that you do not break
            compatibility with more general methods targetting a parent setting!
            It should always be the case that all methods designed for any of
            this Setting's parents should also be applicable via polymorphism,
            i.e., anything that is defined to work on the class `Animal` should
            also work on the class `Cat`!

        3. While not enforced, it is strongly encourged that you define your
            training/evaluation routines at a pretty high level, so that Methods
            that get applied to your Setting can make use of pytorch-lightning's
            `Trainer` & `LightningDataModule` API to be neat and fast.
        
        Parameters
        ----------
        method : Method
            A Method to apply on this Setting.
        
        config : Optional[Config]
            Optional configuration object with things like the log dir, the data
            dir, cuda, wandb config, etc. When None, will be parsed from the
            current command-line arguments. 

        Returns
        -------
        Results
            An object that is used to measure or quantify the performance of the
            Method on this experimental Setting.
        """
        # For illustration purposes only:
        self.config = config or Config.from_args()
        method.configure(self)
        # IDEA: Maybe instead of passing the train_dataloader or test_dataloader
        # objects, we could pass the methods of the LightningDataModule
        # themselves, so we wouldn't have to configure the 'batch_size' etc.
        # arguments, and this way we could also directly control how many times
        # the dataloader method can be called, for instance to limit the number
        # of samples that a user can have access to (the number of epochs, etc.).
        # Or the dataloader could only allow a given number of iterations!
        method.fit(
            train_env=self.train_dataloader(),
            valid_env=self.val_dataloader(),
        )

        test_metrics = []
        test_environment = self.test_dataloader()
        for observations in test_environment:
            # Get the predictions/actions:
            actions = method.get_actions(observations,
                                         test_environment.action_space)
            # Get the rewards for the given predictions.
            rewards = test_environment.send(actions)
            # Calculate the 'metrics' (TODO: This should be done in the env!)
            metrics = self.get_metrics(actions=actions, rewards=rewards)
            test_metrics.append(metrics)

        results = self.Results(test_metrics)
        # TODO: allow the method to observe a 'copy' of the results, maybe?
        method.receive_results(self, results)
        return results
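The loop above only relies on a small Method interface: `configure`, `fit`, `get_actions` and `receive_results`. The sketch below (not from the original source) shows a minimal, hypothetical method that satisfies this contract by ignoring observations and sampling random actions; `MySetting` in the usage note is likewise a placeholder for a concrete Setting subclass.

class RandomActionsMethod:
    """Hypothetical Method: ignores observations and acts at random."""

    def configure(self, setting):
        # Nothing to set up for a random policy; just remember the setting.
        self.setting = setting

    def fit(self, train_env, valid_env):
        # A random policy has nothing to learn from the training environment.
        pass

    def get_actions(self, observations, action_space):
        # Sample a random action from the environment's action space.
        return action_space.sample()

    def receive_results(self, setting, results):
        # Optionally inspect the results (e.g. to log them somewhere).
        print(results)

# Hypothetical usage, assuming a concrete Setting subclass `MySetting`:
# results = MySetting().apply(RandomActionsMethod())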