Code example #1
                 format(ARGS.checkpoint_path))
        checkpoint = torch.load(ARGS.checkpoint_path)
        model = UNet(1, 1, bilinear=False)
        model.load_state_dict(checkpoint["modelState"])
        log.warning(
            "Using default preprocessing options. Provide a model file if they were changed"
        )
        dataOpts = DefaultSpecDatasetOps
    else:
        if ARGS.jit_load:
            extra_files = {}
            extra_files['dataOpts'] = ''
            model = torch.jit.load(ARGS.model_path, _extra_files=extra_files)
            unetState = model.state_dict()
            dataOpts = eval(extra_files['dataOpts'])
            log.debug("Model successfully loaded via torch.jit: " +
                      str(ARGS.model_path))
        else:
            model_dict = torch.load(ARGS.model_path)
            model = UNet(1, 1, bilinear=False)
            model.load_state_dict(model_dict["unetState"])
            model = nn.Sequential(OrderedDict([("denoiser", model)]))
            dataOpts = model_dict["dataOpts"]
            log.debug("Model successfully loaded via torch.load: " +
                      str(ARGS.model_path))

    log.info(model)

    if ARGS.visualize:
        sp = signal.signal_proc()
    else:
        sp = None
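For orientation, a minimal sketch of how a denoiser wrapped this way could be applied to a spectrogram tensor. The nn.Sequential/"denoiser" wrapping is taken from the snippet above and the (batch, channels, n_freq_bins, time) layout matches the input_shape in code example #2 below; the nn.Identity stand-in for the U-Net and the tensor sizes are illustrative assumptions, used only to keep the sketch runnable.

from collections import OrderedDict

import torch
import torch.nn as nn

# stand-in for UNet(1, 1, bilinear=False); in practice this is the loaded model
denoiser = nn.Identity()
model = nn.Sequential(OrderedDict([("denoiser", denoiser)]))
model.eval()

# hypothetical input: one spectrogram shaped (batch, channels, n_freq_bins, time)
spectrogram = torch.rand(1, 1, 256, 128)

with torch.no_grad():
    denoised = model(spectrogram)

print(denoised.shape)  # torch.Size([1, 1, 256, 128]) - same shape in and out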
Code example #2
File: main.py Project: ChristianBergler/ORCA-SPOT
    for arg, value in vars(ARGS).items():
        if arg in encoderOpts and value is not None:
            encoderOpts[arg] = value
        if arg in classifierOpts and value is not None:
            classifierOpts[arg] = value
        if arg in dataOpts and value is not None:
            dataOpts[arg] = value

    ARGS.lr *= ARGS.batch_size

    patience_lr = math.ceil(ARGS.lr_patience_epochs / ARGS.epochs_per_eval)

    patience_lr = int(max(1, patience_lr))

    log.debug("dataOpts: " + json.dumps(dataOpts, indent=4))

    sequence_len = int(
        float(ARGS.sequence_len) / 1000 * dataOpts["sr"] /
        dataOpts["hop_length"])
    log.debug("Training with sequence length: {}".format(sequence_len))
    input_shape = (ARGS.batch_size, 1, dataOpts["n_freq_bins"], sequence_len)

    log.info("Setting up model")

    encoder = Encoder(encoderOpts)
    log.debug("Encoder: " + str(encoder))
    encoder_out_ch = 512 * encoder.block_type.expansion

    classifierOpts["num_classes"] = 2
    classifierOpts["input_channels"] = encoder_out_ch
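To make the unit conversion above concrete, here is a small, self-contained sketch of the same computation. The numbers (44.1 kHz sample rate, 441-sample hop, 256 frequency bins, a 1280 ms window, batch size 16) are illustrative assumptions, not values taken from the project.

# illustrative preprocessing options, not the project's defaults
sr = 44100            # samples per second
hop_length = 441      # samples between consecutive STFT frames (100 frames per second)
n_freq_bins = 256
sequence_len_ms = 1280
batch_size = 16

# milliseconds -> seconds -> samples -> spectrogram frames
sequence_len = int(float(sequence_len_ms) / 1000 * sr / hop_length)
print(sequence_len)   # 128 frames

input_shape = (batch_size, 1, n_freq_bins, sequence_len)
print(input_shape)    # (16, 1, 256, 128)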
Code example #3
class CsvSplit(object):
    def __init__(
        self,
        split_fracs: Dict[str, float],
        working_dir: str = None,
        seed: int = None,
        split_per_dir=False,
    ):
        if not np.isclose(np.sum([p for _, p in split_fracs.items()]), 1.):
            raise ValueError("Split probabilities have to sum up to 1.")
        self.split_fracs = split_fracs
        self.working_dir = working_dir
        self.seed = seed
        self.split_per_dir = split_per_dir
        self.splits = defaultdict(list)
        self._logger = Logger("CSVSPLIT")

    """
    Return the split for the given partition. If a valid CSV split already exists, it is loaded and
    reused; otherwise a new CSV split is generated.
    """

    def load(self, split: str, files: List[Any] = None):

        if split not in self.split_fracs:
            raise ValueError(
                "Provided split '{}' is not in `self.split_fracs`.".format(
                    split))

        if self.splits[split]:
            return self.splits[split]
        if self.working_dir is None:
            self.splits = self._split_with_seed(files)
            return self.splits[split]
        if self.can_load_from_csv():
            if not self.split_per_dir:
                csv_split_files = {
                    split_: (os.path.join(self.working_dir, split_ + ".csv"), )
                    for split_ in self.split_fracs.keys()
                }
            else:
                csv_split_files = {}
                for split_ in self.split_fracs.keys():
                    split_file = os.path.join(self.working_dir, split_)
                    csv_split_files[split_] = []
                    with open(split_file, "r") as f:
                        for line in f.readlines():
                            csv_split_files[split_].append(line.strip())

            for split_ in self.split_fracs.keys():
                for csv_file in csv_split_files[split_]:
                    if not csv_file or csv_file.startswith(r"#"):
                        continue
                    csv_file_path = os.path.join(self.working_dir, csv_file)
                    with open(csv_file_path, "r") as f:
                        reader = csv.reader(f)
                        for item in reader:
                            file_ = os.path.basename(item[0])
                            file_ = os.path.join(os.path.dirname(csv_file),
                                                 file_)
                            self.splits[split_].append(file_)
            return self.splits[split]

        if not self.split_per_dir:
            working_dirs = (self.working_dir, )
        else:
            f_d_map = self._get_f_d_map(files)
            working_dirs = [
                os.path.join(self.working_dir, p) for p in f_d_map.keys()
            ]
        for working_dir in working_dirs:
            splits = self._split_with_seed(
                files if not self.split_per_dir else f_d_map[working_dir])
            for split_ in splits.keys():
                csv_file = os.path.join(working_dir, split_ + ".csv")
                self._logger.debug("Generating {}".format(csv_file))
                if self.split_per_dir:
                    with open(os.path.join(self.working_dir, split_),
                              "a") as f:
                        p = pathlib.Path(csv_file).relative_to(
                            self.working_dir)
                        f.write(str(p) + "\n")
                if len(splits[split_]) == 0:
                    raise ValueError(
                        "Error splitting dataset. Split '{}' has 0 entries".
                        format(split_))
                with open(csv_file, "w", newline="") as fh:
                    writer = csv.writer(fh)
                    for item in splits[split_]:
                        writer.writerow([item])
                self.splits[split_].extend(splits[split_])
        return self.splits[split]

    """
    Check whether the split information can be loaded correctly from existing CSV files.
    """

    def can_load_from_csv(self):
        if not self.working_dir:
            return False
        if self.split_per_dir:
            for split in self.split_fracs.keys():
                split_file = os.path.join(self.working_dir, split)
                if not os.path.isfile(split_file):
                    return False
                self._logger.debug(
                    "Found dataset split file {}".format(split_file))
                with open(split_file, "r") as f:
                    for line in f.readlines():
                        csv_file = line.strip()
                        if not csv_file or csv_file.startswith(r"#"):
                            continue
                        if not os.path.isfile(
                                os.path.join(self.working_dir, csv_file)):
                            self._logger.error(
                                "File not found: {}".format(csv_file))
                            raise ValueError(
                                "Split file found, but csv files are missing. "
                                "Aborting...")
        else:
            for split in self.split_fracs.keys():
                csv_file = os.path.join(self.working_dir, split + ".csv")
                if not os.path.isfile(csv_file):
                    return False
                self._logger.debug("Found csv file {}".format(csv_file))
        return True

    """
    Create a mapping from each directory to the files it contains.
    """

    def _get_f_d_map(self, files: List[Any]):

        f_d_map = defaultdict(list)
        if self.working_dir is not None:
            for f in files:
                f_d_map[str(pathlib.Path(
                    self.working_dir).joinpath(f).parent)].append(f)
        else:
            for f in files:
                f_d_map[str(
                    pathlib.Path(".").resolve().joinpath(f).parent)].append(f)
        return f_d_map

    """
    Randomly split the dataset using the given seed.
    """

    def _split_with_seed(self, files: List[Any]):
        if not files:
            raise ValueError("Provided list `files` is empty or `None`.")
        if self.seed:
            random.seed(self.seed)
        return self.split_fn(files)

    """
    Split the shuffled file list into the configured partitions and return them as a dict keyed by split name.
    """

    def split_fn(self, files: List[Any]):
        shuffled = random.sample(files, len(files))
        # np.split expects cumulative split points, e.g. fracs (.7, .15, .15)
        # over 100 files give the indices [70, 85]
        split_indices = np.cumsum([
            int(p * len(files)) for _, p in self.split_fracs.items()
        ])[:-1]
        _splits = np.split(ary=shuffled, indices_or_sections=split_indices)
        splits = dict()
        for i, key in enumerate(self.split_fracs.keys()):
            splits[key] = list(_splits[i])
        return splits
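A minimal usage sketch for the CsvSplit class above, assuming it and its dependencies are available in the current module. The file names, fractions, and seed are made up for illustration; with working_dir=None the split is generated in memory only, while a real working directory would make load() write and later reuse train.csv, val.csv, and test.csv.

split_fracs = {"train": 0.7, "val": 0.15, "test": 0.15}  # must sum to 1
files = ["call_{:03d}.wav".format(i) for i in range(100)]  # hypothetical file list

splitter = CsvSplit(split_fracs, working_dir=None, seed=42)
train_files = splitter.load("train", files)
val_files = splitter.load("val", files)
print(len(train_files), len(val_files))  # 70 15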
Code example #4
"""
if __name__ == "__main__":

    dataOpts = DefaultSpecDatasetOps

    for arg, value in vars(ARGS).items():
        if arg in dataOpts and value is not None:
            dataOpts[arg] = value

    ARGS.lr *= ARGS.batch_size

    patience_lr = math.ceil(ARGS.lr_patience_epochs / ARGS.epochs_per_eval)

    patience_lr = int(max(1, patience_lr))

    log.debug("dataOpts: " + json.dumps(dataOpts, indent=4))

    sequence_len = int(
        float(ARGS.sequence_len) / 1000 * dataOpts["sr"] /
        dataOpts["hop_length"])
    log.debug("Training with sequence length: {}".format(sequence_len))
    input_shape = (ARGS.batch_size, 1, dataOpts["n_freq_bins"], sequence_len)

    log.info("Setting up model")

    unet = UNet(n_channels=1, n_classes=1, bilinear=False)

    log.debug("Model: " + str(unet))
    model = nn.Sequential(OrderedDict([("unet", unet)]))

    split_fracs = {"train": .7, "val": .15, "test": .15}
Code example #5
        "decoderOpts": decoderOpts,
        "dataOpts": dataOpts,
        "encoderState": encoder_state_dict,
        "decoderState": decoder_state_dict,
    }
    if not os.path.isdir(ARGS.model_dir):
        os.makedirs(ARGS.model_dir)
    torch.save(save_dict, path)


if __name__ == "__main__":
    """------------- model-related preparation -------------"""
    # load and update the options for setting up the model
    if ARGS.model is not None:
        if ARGS.model == "plain_ae":
            log.debug("Plain autoencoder is chosen. n_bottleneck: {}".format(
                ARGS.n_bottleneck))
        elif ARGS.model == "conv_ae":
            encoderOpts = DefaultEncoderOpts
            decoderOpts = DefaultDecoderOpts

            # update the respective parameters if given in terminal
            for arg, value in vars(ARGS).items():
                if arg in encoderOpts and value is not None:
                    encoderOpts[arg] = value
                if arg in decoderOpts and value is not None:
                    decoderOpts[arg] = value
        else:
            raise ValueError(
                "Expected plain_ae or conv_ae as model but received: {}".
                format(ARGS.model))
    else:
Code example #6
    sr = dataOpts['sr']
    hop_length = dataOpts["hop_length"]
    n_fft = dataOpts["n_fft"]

    try:
        n_freq_bins = dataOpts["num_mels"]
    except KeyError:
        n_freq_bins = dataOpts["n_freq_bins"]

    fmin = dataOpts["fmin"]
    fmax = dataOpts["fmax"]
    freq_cmpr = dataOpts["freq_compression"]
    DefaultSpecDatasetOps["min_level_db"] = dataOpts["min_level_db"]
    DefaultSpecDatasetOps["ref_level_db"] = dataOpts["ref_level_db"]

    log.debug("dataOpts: " + str(dataOpts))

    if ARGS.min_max_norm:
        log.debug("Init min-max-normalization activated")
    else:
        log.debug("Init 0/1-dB-normalization activated")

    sequence_len = int(ceil(ARGS.sequence_len * sr))
    hop = int(ceil(ARGS.hop * sr))

    log.info("Predicting {} files".format(len(ARGS.audio_files)))

    for file_name in ARGS.audio_files:
        log.info(file_name)
        dataset = StridedAudioDataset(file_name.strip(),
                                      sequence_len=sequence_len,
Code example #7
    model.eval()

    sr = dataOpts["sr"]  # modified, s.t. not hard-coded
    hop_length = dataOpts["hop_length"]
    n_fft = dataOpts["n_fft"]

    try:
        n_freq_bins = dataOpts["num_mels"]
    except KeyError:
        n_freq_bins = dataOpts["n_freq_bins"]

    freq_compression = dataOpts[
        "freq_compression"]  # added, missing in orig master; is freq compression not needed during inference?
    fmin = dataOpts["fmin"]
    fmax = dataOpts["fmax"]
    log.debug("dataOpts: " + str(dataOpts))
    #sequence_len = int(ceil(ARGS.sequence_len * sr))
    sequence_len = int(
        float(ARGS.sequence_len) / 1000 * dataOpts["sr"] /
        dataOpts["hop_length"])
    hop = int(ceil(ARGS.hop * sr))

    log.info("Predicting {} files".format(len(audio_files)))

    #for file_name in audio_files:
    #    log.info(file_name)
    #    dataset = StridedAudioDataset(
    #        os.path.join(ARGS.data_dir, file_name.strip()),
    #        sequence_len=sequence_len,
    #        hop=hop,
    #        sr=sr,
Code example #8
File: predict.py Project: sness23/ORCA-SPOT
    if torch.cuda.is_available() and ARGS.cuda:
        model = model.cuda()
    model.eval()

    sr = dataOpts['sr']
    hop_length = dataOpts["hop_length"]
    n_fft = dataOpts["n_fft"]

    try:
        n_freq_bins = dataOpts["num_mels"]
    except KeyError:
        n_freq_bins = dataOpts["n_freq_bins"]

    fmin = dataOpts["fmin"]
    fmax = dataOpts["fmax"]
    log.debug("dataOpts: " + str(dataOpts))
    sequence_len = int(ceil(ARGS.sequence_len * sr))
    hop = int(ceil(ARGS.hop * sr))

    log.info("Predicting {} files".format(len(ARGS.audio_files)))

    for file_name in ARGS.audio_files:
        log.info(file_name)
        dataset = StridedAudioDataset(
            file_name.strip(),
            sequence_len=sequence_len,
            hop=hop,
            sr=sr,
            fft_size=n_fft,
            fft_hop=hop_length,
            n_freq_bins=n_freq_bins,