예제 #1
0
    def _get_experiment_components(
        self,
        stage: str = None
    ) -> Tuple[Model, Criterion, Optimizer, Scheduler, Device]:
        """
        Assemble the per-stage training components of the experiment.

        The global seed is set from ``self.experiment.initial_seed``
        before anything is built. The model for ``stage`` is then
        requested from the experiment, followed by the criterion,
        optimizer and scheduler for that model and stage. Finally
        everything is passed through ``utils.process_components``
        together with the experiment's ``distributed_params`` and
        ``self.device``.

        Args:
            stage (str): experiment stage name of interest
                like "pretraining" / "training" / "finetuning" / etc;
                forwarded unchanged to the experiment

        Returns:
            tuple: model, criterion, optimizer, scheduler and device
                as returned by ``utils.process_components``
        """
        experiment = self.experiment

        # Seed first, so that it happens before the model is created.
        utils.set_global_seed(experiment.initial_seed)

        stage_model = experiment.get_model(stage)
        stage_components = experiment.get_experiment_components(
            stage_model, stage
        )
        stage_criterion, stage_optimizer, stage_scheduler = stage_components

        processed = utils.process_components(
            model=stage_model,
            criterion=stage_criterion,
            optimizer=stage_optimizer,
            scheduler=stage_scheduler,
            distributed_params=experiment.distributed_params,
            device=self.device
        )
        # Unpack explicitly so that a result of unexpected length
        # still fails here rather than at the caller.
        model, criterion, optimizer, scheduler, device = processed

        return model, criterion, optimizer, scheduler, device
예제 #2
0
    def _get_experiment_components(
        experiment: IExperiment,
        stage: str = None,
        device: Device = None,
    ) -> Tuple[Model, Criterion, Optimizer, Scheduler, Device]:
        """
        Assemble the per-stage training components of an experiment.

        Requests the model, criterion, optimizer and scheduler for
        ``stage`` from ``experiment`` and passes them through
        ``utils.process_components`` together with the experiment's
        ``distributed_params`` and the requested ``device``.

        Args:
            experiment (IExperiment): experiment that provides the
                stage components and ``distributed_params``
            stage (str): experiment stage name of interest
                like "pretrain" / "train" / "finetune" / etc
            device (Device): device forwarded to
                ``utils.process_components``; ``None`` is passed through
                as is (how it is resolved is up to that helper)

        Returns:
            tuple: model, criterion, optimizer, scheduler and device
                as returned by ``utils.process_components``
        """
        raw_components = experiment.get_experiment_components(stage)
        # Exactly four items are expected; unpacking enforces that.
        raw_model, raw_criterion, raw_optimizer, raw_scheduler = raw_components

        processed = utils.process_components(
            model=raw_model,
            criterion=raw_criterion,
            optimizer=raw_optimizer,
            scheduler=raw_scheduler,
            distributed_params=experiment.distributed_params,
            device=device,
        )
        # Exactly five items are expected back, the last being the device.
        model, criterion, optimizer, scheduler, device = processed

        return model, criterion, optimizer, scheduler, device