예제 #1
0
def main(args=None):
    """Train a ConvLSTM encoder-predictor on ``MovingMnistDataset`` with Catalyst.

    Args:
        args: parsed command-line namespace. When ``None`` it is built with
            ``argument_paser()``. Attributes read: ``exp_id``, ``seed``,
            ``batch_size``, ``test_batch_size``, ``lr``, ``log_dir``,
            ``n_saved``, ``es_patience``, ``epochs``.

    Returns:
        tuple: ``(exp_id, model)``. ``exp_id`` is the experiment id, also used
        as the log sub-directory name. ``model`` is the model instance after
        ``runner.train`` has run on it.
    """
    if args is None:
        args = argument_paser()

    # Use a short random id unless the caller supplied one.
    exp_id = str(uuid.uuid4())[:8] if args.exp_id is None else args.exp_id
    print(f'Experiment Id: {exp_id}', flush=True)

    # Fix seed
    torch.manual_seed(args.seed)

    # Config gpu
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Prepare data
    dataset = MovingMnistDataset()
    # torch.manual_seed does not affect sklearn's splitter, so seed it
    # explicitly; otherwise the train/valid split differs on every run.
    # NOTE(review): numpy-backed random_state requires 0 <= seed < 2**32.
    train_index, valid_index = train_test_split(range(len(dataset)),
                                                test_size=0.3,
                                                random_state=args.seed)
    train_loader = DataLoader(Subset(dataset, train_index),
                              batch_size=args.batch_size,
                              shuffle=True)
    valid_loader = DataLoader(Subset(dataset, valid_index),
                              batch_size=args.test_batch_size,
                              shuffle=False)
    loaders = {"train": train_loader, "valid": valid_loader}

    model = ConvLSTMEncoderPredictor(image_size=(64, 64)).to(device)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 betas=(0.9, 0.999))
    criterion = nn.MSELoss()

    # NOTE(review): the runner picks its own device via catalyst, which is
    # presumably the same as `device` above — confirm if multi-GPU matters.
    runner = SupervisedRunner(device=catalyst.utils.get_device())
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=None,
        loaders=loaders,
        # model will be saved to {logdir}/checkpoints
        logdir=os.path.join(args.log_dir, exp_id),
        callbacks=[
            CheckpointCallback(save_n_best=args.n_saved),
            EarlyStoppingCallback(
                patience=args.es_patience,
                metric="loss",
                minimize=True,
            )
        ],
        num_epochs=args.epochs,
        main_metric="loss",
        minimize_metric=True,
        fp16=None,
        verbose=True)

    return exp_id, model
예제 #2
0
    def get_callbacks(self) -> "OrderedDict[str, Callback]":
        """Returns the callbacks for the experiment.

        The user callbacks (``self._callbacks``) are sorted by order. A default
        callback is then added for each enabled flag, unless a callback of
        that type is already present:

        * ``"_verbose"`` -> ``TqdmCallback`` when ``self._verbose``
        * ``"_timer"`` -> ``TimerCallback`` when ``self._timeit``
        * ``"_check"`` -> ``CheckRunCallback`` when ``self._check``
        * ``"_overfit"`` -> ``BatchOverfitCallback`` when ``self._overfit``
        * ``"_profile"`` -> ``ProfilerCallback`` when ``self._profile``
        * ``"_checkpoint"`` -> ``CheckpointCallback`` when ``self._logdir``
          is not None and no ``ICheckpointCallback`` is present

        Returns:
            mapping of callback name to callback.
        """
        callbacks = sort_callbacks_by_order(self._callbacks)

        def callback_exists(callback_fn) -> bool:
            # Scans the current mapping, so defaults added above are included.
            return any(
                callback_isinstance(x, callback_fn) for x in callbacks.values()
            )

        if self._verbose and not callback_exists(TqdmCallback):
            callbacks["_verbose"] = TqdmCallback()
        if self._timeit and not callback_exists(TimerCallback):
            callbacks["_timer"] = TimerCallback()
        if self._check and not callback_exists(CheckRunCallback):
            callbacks["_check"] = CheckRunCallback()
        if self._overfit and not callback_exists(BatchOverfitCallback):
            callbacks["_overfit"] = BatchOverfitCallback()
        if self._profile and not callback_exists(ProfilerCallback):
            profiler_kwargs = {
                "activities": [
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                "with_stack": True,
                "with_flops": True,
            }
            # `self._logdir` may be None (see the checkpoint branch below);
            # os.path.join(None, ...) would raise TypeError, so the
            # TensorBoard trace export is only configured when a logdir exists.
            tensorboard_path = None
            if self._logdir is not None:
                tensorboard_path = os.path.join(self._logdir, "tb_profile")
                profiler_kwargs["on_trace_ready"] = (
                    torch.profiler.tensorboard_trace_handler(tensorboard_path)
                )
            callbacks["_profile"] = ProfilerCallback(
                tensorboard_path=tensorboard_path,
                profiler_kwargs=profiler_kwargs,
            )

        if self._logdir is not None and not callback_exists(ICheckpointCallback):
            callbacks["_checkpoint"] = CheckpointCallback(
                logdir=os.path.join(self._logdir, "checkpoints"),
                loader_key=self._valid_loader,
                metric_key=self._valid_metric,
                minimize=self._minimize_valid_metric,
                load_best_on_end=self._load_best_on_end,
            )
        return callbacks
예제 #3
0
    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
        """Build the callback mapping for ``stage``.

        Starts from ``self._callbacks`` sorted by order. For every enabled
        feature flag it adds the matching default callback, unless one of that
        type is already present. A ``CheckpointCallback`` writing under
        ``<logdir>/checkpoints`` is added last, when ``self._logdir`` is set
        and no ``ICheckpointCallback`` is registered. ``stage`` is not
        consulted.

        Returns:
            mapping of callback name to callback.
        """
        callbacks = sort_callbacks_by_order(self._callbacks)

        def _already_has(callback_type) -> bool:
            # Looks at the live mapping, including defaults added in the loop.
            for existing in callbacks.values():
                if callback_isinstance(existing, callback_type):
                    return True
            return False

        # (enabled flag, mapping key, callback class), in insertion order.
        defaults = (
            (self._verbose, "_verbose", TqdmCallback),
            (self._timeit, "_timer", TimerCallback),
            (self._check, "_check", CheckRunCallback),
            (self._overfit, "_overfit", BatchOverfitCallback),
        )
        for enabled, key, callback_cls in defaults:
            if enabled and not _already_has(callback_cls):
                callbacks[key] = callback_cls()

        if self._logdir is None or _already_has(ICheckpointCallback):
            return callbacks

        checkpoint_dir = os.path.join(self._logdir, "checkpoints")
        callbacks["_checkpoint"] = CheckpointCallback(
            logdir=checkpoint_dir,
            loader_key=self._valid_loader,
            metric_key=self._valid_metric,
            minimize=self._minimize_valid_metric,
        )
        return callbacks
예제 #4
0
def create_callbacks(args, criterion_names):
    """Build the list of Catalyst callbacks for a training run.

    The list contains, in order:

    * an ``IoUMetricsCallback``;
    * a ``CheckpointCallback``;
    * an ``EarlyStoppingCallback`` on ``args.eval_metric``;
    * one ``CriterionCallback`` per criterion, logged as ``loss_<name>``;
    * a ``MetricAggregationCallback`` that sums those losses (weight 1.0
      each) into ``loss``.

    Args:
        args: namespace providing ``dice_mode``, ``input_target_key``,
            ``class_names`` (comma-separated string or falsy), ``save_n_best``,
            ``patience`` and ``eval_metric``.
        criterion_names: iterable of criterion keys; iterated once.

    Returns:
        list of callbacks.
    """
    class_names = args.class_names.split(',') if args.class_names else None
    callbacks = [
        IoUMetricsCallback(mode=args.dice_mode,
                           input_key=args.input_target_key,
                           class_names=class_names),
        CheckpointCallback(save_n_best=args.save_n_best),
        EarlyStoppingCallback(
            patience=args.patience,
            metric=args.eval_metric,
            # Only the loss is minimized; any other metric is maximized.
            minimize=args.eval_metric == 'loss'),
    ]
    metrics_weights = {}
    for cn in criterion_names:
        prefix = f"loss_{cn}"
        callbacks.append(
            CriterionCallback(input_key=args.input_target_key,
                              prefix=prefix,
                              criterion_key=cn))
        # Every criterion contributes equally to the aggregated "loss".
        metrics_weights[prefix] = 1.0
    callbacks.append(
        MetricAggregationCallback(prefix="loss",
                                  mode="weighted_sum",
                                  metrics=metrics_weights))
    return callbacks
예제 #5
0
    def train(
        self,
        *,
        model: Model,
        criterion: Criterion = None,
        optimizer: Optimizer = None,
        scheduler: Scheduler = None,
        datasets: "OrderedDict[str, Union[Dataset, Dict, Any]]" = None,
        loaders: "OrderedDict[str, DataLoader]" = None,
        callbacks: "Union[List[Callback], OrderedDict[str, Callback]]" = None,
        logdir: str = None,
        resume: str = None,
        num_epochs: int = 1,
        valid_loader: str = "valid",
        main_metric: str = "loss",
        minimize_metric: bool = True,
        verbose: bool = False,
        stage_kwargs: Dict = None,
        checkpoint_data: Dict = None,
        fp16: Union[Dict, bool] = None,
        distributed: bool = False,
        check: bool = False,
        overfit: bool = False,
        timeit: bool = False,
        load_best_on_end: bool = False,
        initial_seed: int = 42,
        state_kwargs: Dict = None,
    ) -> None:
        """
        Run the ``train`` stage for ``model``.

        Builds an experiment with ``self._experiment_fn(stage="train", ...)``
        and stores it on ``self.experiment``. It then launches
        ``self.run_experiment`` through ``distributed_cmd_run``.

        Args:
            model: model to train
            criterion: loss function used during training
            optimizer: optimizer used during training
            scheduler: scheduler used during training
            datasets (OrderedDict[str, Union[Dataset, Dict, Any]]): mapping of
                one or more ``torch.utils.data.Dataset`` objects (training,
                validation or inference). Loaders are created from them
                automatically; this is the preferred option for distributed
                training.
            loaders (OrderedDict[str, DataLoader]): mapping of one or more
                ``torch.utils.data.DataLoader`` objects for training,
                validation or inference
            callbacks (Union[List[Callback], OrderedDict[str, Callback]]):
                Catalyst callbacks, as a list or a name -> callback mapping
            logdir: output directory
            resume: path of a model checkpoint to resume from
            num_epochs: number of training epochs
            valid_loader: name of the loader whose metrics drive metric
                computation and checkpointing (e.g. pass ``train`` to use
                the training loader)
            main_metric: name of the metric used to select checkpoints
            minimize_metric: whether ``main_metric`` should be minimized
            verbose: if ``True``, print training progress to the console
            stage_kwargs: extra stage parameters
            checkpoint_data: extra data stored in checkpoints, for example
                ``class_names`` or ``date_of_training``
            fp16 (Union[Dict, bool]): enables FP16 training when not None.
                Native amp: ``{"amp": True}``.
                Apex: ``{"apex": True, "opt_level": "O1", ...}``
                (see https://nvidia.github.io/apex/amp.html#properties).
                ``fp16=True`` means ``{"amp": True}`` on torch>=1.6.0 and
                ``{"apex": True, "opt_level": "O1", ...}`` on older torch.
            distributed: if ``True``, start training in distributed mode
                (python scripts only, no jupyter support)
            check: if ``True``, only verify that the pipeline runs
                (3 epochs with 3 batches per loader)
            overfit: if ``True``, use a single batch per loader to overfit
                the model; see ``BatchOverfitCallback`` for advanced usage
            timeit: if ``True``, measure and print the training time
            load_best_on_end: if ``True``, load the best checkpoint state
                (model, optimizer, etc.) by validation metrics at the end.
                Requires ``logdir``.
            initial_seed: initial seed of the experiment
            state_kwargs: deprecated alias of ``stage_kwargs``

        Raises:
            AssertionError: if both ``stage_kwargs`` and ``state_kwargs`` are
                given, or ``load_best_on_end`` is set without ``logdir``
            NotImplementedError: if ``resume`` or ``load_best_on_end`` is set
                while ``callbacks`` already holds a ``CheckpointCallback``
        """
        assert state_kwargs is None or stage_kwargs is None

        fp16 = _resolve_bool_fp16(fp16)

        needs_checkpoint_loader = resume is not None or load_best_on_end
        if needs_checkpoint_loader:
            if load_best_on_end:
                assert logdir is not None, (
                    "For ``load_best_on_end`` feature "
                    "you need to specify ``logdir``"
                )
                stage_end_state = "best_full"
            else:
                stage_end_state = None

            callbacks = sort_callbacks_by_order(callbacks)
            has_checkpoint = any(
                isinstance(cb, CheckpointCallback) for cb in callbacks.values()
            )
            if has_checkpoint:
                raise NotImplementedError("CheckpointCallback already exist")
            callbacks["_loader"] = CheckpointCallback(
                resume=resume,
                load_on_stage_end=stage_end_state,
            )

        self.experiment = self._experiment_fn(
            stage="train",
            model=model,
            datasets=datasets,
            loaders=loaders,
            callbacks=callbacks,
            logdir=logdir,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            num_epochs=num_epochs,
            valid_loader=valid_loader,
            main_metric=main_metric,
            minimize_metric=minimize_metric,
            verbose=verbose,
            check_time=timeit,
            check_run=check,
            overfit=overfit,
            stage_kwargs=stage_kwargs or state_kwargs,
            checkpoint_data=checkpoint_data,
            distributed_params=fp16,
            initial_seed=initial_seed,
        )
        distributed_cmd_run(self.run_experiment, distributed)
예제 #6
0
    def infer(
        self,
        *,
        model: Model,
        datasets: "OrderedDict[str, Union[Dataset, Dict, Any]]" = None,
        loaders: "OrderedDict[str, DataLoader]" = None,
        callbacks: "Union[List[Callback], OrderedDict[str, Callback]]" = None,
        logdir: str = None,
        resume: str = None,
        verbose: bool = False,
        stage_kwargs: Dict = None,
        fp16: Union[Dict, bool] = None,
        check: bool = False,
        timeit: bool = False,
        initial_seed: int = 42,
        state_kwargs: Dict = None,
    ) -> None:
        """
        Starts the inference stage of the model.

        Builds an experiment with ``self._experiment_fn(stage="infer", ...)``
        and runs it with ``self.run_experiment``.

        Args:
            model: model for inference
            datasets (OrderedDict[str, Union[Dataset, Dict, Any]]): dictionary
                with one or several  ``torch.utils.data.Dataset``
                for training, validation or inference
                used for Loaders automatic creation
                preferred way for distributed training setup
            loaders (OrderedDict[str, DataLoader]): dictionary
                with one or several ``torch.utils.data.DataLoader``
                for training, validation or inference
            callbacks (Union[List[Callback], OrderedDict[str, Callback]]):
                list or dictionary with Catalyst callbacks
            logdir: path to output directory
            resume: path to checkpoint to use for resume
            verbose: if `True`, it displays the status of the training
                to the console.
            stage_kwargs: additional stage params
            fp16 (Union[Dict, bool]): fp16 settings (same as in `train`)
            check: if True, then only checks that pipeline is working
                (3 epochs only)
            timeit: if True, computes the execution time
                of training process and displays it to the console.
            initial_seed: experiment's initial seed value
            state_kwargs: deprecated, use `stage_kwargs` instead

        Raises:
            AssertionError: if both `stage_kwargs` and `state_kwargs` are given
            NotImplementedError: if both `resume` and `CheckpointCallback`
                already exist
        """
        assert state_kwargs is None or stage_kwargs is None

        fp16 = _resolve_bool_fp16(fp16)

        if resume is not None:
            callbacks = sort_callbacks_by_order(callbacks)
            checkpoint_callback_flag = any(
                isinstance(x, CheckpointCallback) for x in callbacks.values())
            if checkpoint_callback_flag:
                raise NotImplementedError("CheckpointCallback already exist")
            # Internal callbacks use an underscore-prefixed key (as in `train`)
            # so they cannot overwrite a user callback named "loader".
            callbacks["_loader"] = CheckpointCallback(resume=resume)

        experiment = self._experiment_fn(
            stage="infer",
            model=model,
            datasets=datasets,
            loaders=loaders,
            callbacks=callbacks,
            logdir=logdir,
            verbose=verbose,
            check_time=timeit,
            check_run=check,
            stage_kwargs=stage_kwargs or state_kwargs,
            distributed_params=fp16,
            initial_seed=initial_seed,
        )
        self.run_experiment(experiment)