def from_checkpoint(cls, config, which_data_set="test"):
        r"""Create an AudioSeparator whose model is restored from the checkpoint file named in 'config'.

        The checkpoint holds the model parameters ('model_state_dict') and the configuration dictionary used
        during training ('config'). That training configuration, overridden by the entries of 'config', is used
        to rebuild the data set split and the SeparationModel.

        Args:
            config (dict): AudioSeparator parameters (see 'default_config()'). Must contain 'checkpoint_path'; its
                           entries take precedence over the values stored in the checkpoint.
            which_data_set (str): Set identifier forwarded to the 'split' method of the data set class: 'train',
                                  'test' or 'val'.

        Returns:
            An instance of 'cls' built from the data set, the restored model and the merged configuration.

        Raises:
            ValueError: if 'checkpoint_path' is not an existing file.
        """
        checkpoint_file = config["checkpoint_path"]
        if os.path.isfile(checkpoint_file):
            print("Loading model ...'{}'".format(checkpoint_file))
        else:
            raise ValueError("File " + checkpoint_file + " is not a valid file.")
        # Tensors are mapped to the CPU. torch.load unpickles the file: only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_file, 'cpu')

        # Start from the training-time configuration and let the caller's values override it
        # (the dictionary stored in the loaded checkpoint is updated in place).
        merged_config = checkpoint["config"]
        merged_config.update(config)

        # Data set holding the audio to separate.
        data_set_class = dts.find_data_set_class(merged_config["data_set_type"])
        separation_set = data_set_class.split(merged_config, which_data_set)

        # Rebuild the model with the same input shape / number of classes, then restore its weights.
        separation_model = md.SeparationModel(merged_config,
                                              separation_set.features_shape(),
                                              separation_set.n_classes())
        separation_model.load_state_dict(checkpoint["model_state_dict"])

        return cls(separation_set, separation_model, merged_config)
Exemplo n.º 2
0
def _add_config_argument(parser, key, value):
    r"""Register the option '--<key>' on 'parser', choosing how to parse it from its default 'value'.

    Args:
        parser (argparse.ArgumentParser): Parser receiving the new option.
        key (str): Name of the configuration parameter (becomes the option name).
        value: Default value of the parameter. Lists/tuples become 'nargs=*' options typed after their first
               element, booleans are parsed with 'str2bool', None defaults are parsed as strings, and any other
               value is parsed with 'type(value)'.
    """
    flag = "--{}".format(key)
    if isinstance(value, (list, tuple)):
        # An empty default has no element to infer the item type from: fall back to str instead of raising
        # IndexError. Boolean items go through str2bool because bool("False") is True.
        item_type = type(value[0]) if value else str
        if item_type is bool:
            item_type = str2bool
        parser.add_argument(flag, default=value, nargs='*', type=item_type)
    elif isinstance(value, bool):
        # const=True so that a bare '--flag' (no value) switches the option on instead of yielding None.
        parser.add_argument(flag, default=value, nargs='?', const=True, type=str2bool)
    elif value is None:
        # type(None) cannot be called with a string, so using it would make argparse reject every value the
        # user passes for this option. Parse such options as plain strings.
        parser.add_argument(flag, default=value, type=str)
    else:
        parser.add_argument(flag, default=value, type=type(value))


def parse_arguments():
    r"""Parse user arguments to determine the execution mode to use. Then parse the arguments required for this mode.

        The framework can perform 3 tasks (see main()): 'train', 'evaluate' and 'separate'. First the method parses
        the command line arguments to detect which task should be done.
        Then the method parses the arguments which are specific to this task.
        This is done using the following trick:
            Most objects in the code (models, training_manager, AudioSeparator, AudioDataSets, etc...) have a method
            'default_config()'. This method returns a dictionary that contains the tunable parameters for this object
            along with their default value.
            The argument parser first aggregates all the parameters dictionaries for all the objects.
            Then it checks whether the user entered an argument corresponding to any parameter in this
            aggregated dictionary. If an argument is passed, the default value of this parameter is replaced with
            the received value.
        This aggregated and updated dictionary will be passed to all objects in the code so that they can receive the
        user arguments.

    Returns:
        exec_mode, config: Execution mode identifier expected by 'main()' and the configuration dictionary required
        for this mode. When the resulting configuration has a truthy 'checkpoint_path', 'config' only contains the
        values that differ from the defaults (the others are expected to be read from the checkpoint); otherwise it
        is the full, updated configuration dictionary.

    Raises:
        NotImplementedError: if '--mode' is not 'train', 'evaluate' or 'separate'.
    """

    # First parse the execution mode
    parser = argparse.ArgumentParser(
        allow_abbrev=False,
        description="Audio separation framework in pytorch. \n "
        "See README.md for information about the command line arguments.")
    parser.add_argument(
        "--mode",
        type=str,
        required=True,
        help=
        "Which mode of the script to execute: train for training a model, evaluate for evaluating "
        "a model, and separate to generate source separated audio files")
    # parse_known_args: the mode-specific options are not registered yet and must not trigger an error.
    initial_args = vars(parser.parse_known_args()[0])
    exec_mode = initial_args.pop("mode")

    # Now gather the possible arguments for the input execution mode
    if exec_mode == "train":
        # In order to get the arguments for the models and dataset objects, we need to know the type of these objects
        parser.add_argument(
            "-m",
            "--mask_model_type",
            type=str,
            required=True,
            help=
            "Identifier for the class of the model bulding the segmentation masks."
            "See 'find_model_class' in separation_model.py")
        parser.add_argument(
            "-c",
            "--classifier_model_type",
            type=str,
            required=True,
            help="Identifier for the class of the classifier model. "
            "See 'find_model_class' in separation_model.py")
        parser.add_argument(
            "-d",
            "--data_set_type",
            type=str,
            required=True,
            help=
            "Identifier of the class of the data set. See 'find_data_set_class' in data_set.py"
        )
        args = vars(parser.parse_known_args()[0])
        initial_args.update(args)

        # Get all the arguments for the data_set, the model and the training_manager.
        data_set_default_config = ds.find_data_set_class(
            args["data_set_type"]).default_config()
        model_default_config = mod.SeparationModel.default_config(
            args["mask_model_type"], args["classifier_model_type"])
        training_default_config = tr.TrainingManager.default_config()
        # merge all dictionaries together
        default_config = {
            **data_set_default_config,
            **model_default_config,
            **training_default_config
        }

    # If execution mode is evaluate, all we need to know is the checkpoint to the model to evaluate.
    # The configuration for the model and data set are saved in the checkpoint.
    elif exec_mode == "evaluate":
        # get the checkpoint path
        parser.add_argument("--checkpoint_path",
                            type=str,
                            required=True,
                            help="Path to the saved model checkpoint.")
        args = vars(parser.parse_known_args()[0])
        # Read the configuration from checkpoint (torch.load unpickles the file: only load trusted checkpoints).
        state = torch.load(args["checkpoint_path"], 'cpu')
        default_config = state["config"]

    # If execution mode is separate, we just need the AudioSeparator configuration.
    # Model and dataset configuration will be read from checkpoint.
    elif exec_mode == "separate":
        default_config = sep.AudioSeparator.default_config()

    else:
        raise NotImplementedError("Mode " + exec_mode + " is not Implemented")

    # Now parse the arguments for parameters in the default_config()
    full_parser = argparse.ArgumentParser(allow_abbrev=False)
    for key, value in default_config.items():
        _add_config_argument(full_parser, key, value)

    parsed_args = vars(full_parser.parse_known_args()[0])
    # The arguments parsed in the first stage take precedence over both the parsed and the default values.
    parsed_args.update(initial_args)
    default_config.update(initial_args)
    # If a model will be loaded from checkpoint, we do not want to over-write the values in its config dict with the
    # default values from default_config(). Therefore we only update the values if they were explicitly passed by user.
    # .get(): the configuration of a mode does not necessarily define 'checkpoint_path'.
    if parsed_args.get("checkpoint_path"):
        # Return only the values that differ from the defaults. Other values should be loaded from checkpoint.
        # NOTE(review): a value passed explicitly but equal to its default is dropped here too, so a value stored
        # elsewhere (e.g. in the checkpoint) may take precedence over it.
        new_args = {
            key: value
            for key, value in parsed_args.items()
            if value != default_config[key]
        }
        return exec_mode, new_args

    # update the values in config with the passed arguments or the default arguments
    default_config.update(parsed_args)
    return exec_mode, default_config
Exemplo n.º 3
0
    def __init__(self, config):
        r"""Set up everything needed for training: data sets, model, optimizer, scheduler and loss functions.

        Args:
            config (dict): Configuration dictionary with the tunable training parameters. A shallow copy is stored
                           in 'self.config'.

        Raises:
            NotImplementedError: if 'optimizer', 'scheduler_type' or 'loss_f' names an unsupported option.
        """
        self.config = dict(config)

        if self.config["use_gpu"]:
            self.device = torch.device("cuda:" + str(self.config["gpu_no"]))
        else:
            self.device = torch.device("cpu")

        # Train / test / validation splits of the configured data set class.
        data_set_class = dts.find_data_set_class(self.config["data_set_type"])
        self.train_set, self.test_set, self.val_set = data_set_class.split(self.config)

        # Shift and scale the features of the data sets.
        self.shift_scale_data_sets()

        # The model is sized from the training set.
        # NOTE(review): it receives the 'config' argument, not the 'self.config' copy.
        input_shape = self.train_set.features_shape()
        class_count = self.train_set.n_classes()
        self.model = md.SeparationModel(config, input_shape, class_count)

        # Optimizer: only Adam is supported.
        optimizer_name = self.config["optimizer"]
        if optimizer_name != "Adam":
            raise NotImplementedError('The optimizer ' + optimizer_name + ' is not available.')
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.config["learning_rate"],
                                          weight_decay=self.config["weight_decay"])

        # Learning rate scheduler; a falsy 'scheduler_type' disables scheduling.
        scheduler_type = self.config["scheduler_type"]
        if scheduler_type == "stepLR":
            # Decay the lr every 'scheduler_step_size' epochs.
            self.scheduler = torch.optim.lr_scheduler.StepLR(optimizer=self.optimizer,
                                                             step_size=self.config["scheduler_step_size"],
                                                             gamma=self.config["scheduler_gamma"])
        elif scheduler_type == "multiStepLR":
            # Decay the lr when an epoch listed in 'scheduler_milestones' is reached.
            self.scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=self.optimizer,
                                                                  milestones=self.config["scheduler_milestones"],
                                                                  gamma=self.config["scheduler_gamma"])
        elif scheduler_type == "reduceLROnPlateau":
            # Decay the lr when the monitored value stops improving for 'scheduler_patience' epochs.
            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=self.optimizer,
                                                                        patience=self.config["scheduler_patience"],
                                                                        factor=self.config["scheduler_gamma"])
        elif not scheduler_type:
            self.scheduler = None
        else:
            raise NotImplementedError("Learning rate scheduler " + scheduler_type + " is not available.")

        # Main loss function.
        loss_name = self.config["loss_f"]
        if loss_name == "BCE":
            self.loss_f = torch.nn.BCELoss()
        elif loss_name == "MultiLabelSoftMarginLoss":
            self.loss_f = torch.nn.MultiLabelSoftMarginLoss()
        else:
            raise NotImplementedError("Loss function " + loss_name + " is not available.")

        # L1 loss, meant to penalize mask activations where they should be 0, and its coefficient.
        self.l1_loss_f = torch.nn.L1Loss()
        self.l1_loss_lambda = self.config["l1_loss_lambda"]

        # Per-epoch history of loss and metric values for each split.
        self.train_losses = []
        self.test_losses = []
        self.val_losses = []
        self.train_metrics = []
        self.test_metrics = []
        self.val_metrics = []

        # Trainable PCEN parameters saved at each epoch (if any).
        self.pcen_parameters = []