Example No. 1
    def __init__(self, ckpt: Path, prominence: float):
        logging.info(f"running inferece using ckpt: {ckpt}")

        ckpt = torch.load(str(ckpt), map_location=lambda storage, loc: storage)
        hp = Namespace(**dict(ckpt["hparams"]))

        # load weights and peak detection params
        self.model = NextFrameClassifier(hp)
        weights = ckpt["state_dict"]
        weights = {k.replace("NFC.", ""): v for k, v in weights.items()}
        self.model.load_state_dict(weights)
        self.peak_detection_params = dill.loads(ckpt['peak_detection_params'])['cpc_1']
        if prominence is not None:
            logging.info(f"overriding prominence with {prominence}")
            self.peak_detection_params["prominence"] = prominence
Example No. 2
def main(wav, ckpt, prominence, outpath):
    print(f"running inference on: {wav}")
    print(f"running inferece using ckpt: {ckpt}")
    print("\n\n", 90 * "-")

    ckpt = torch.load(ckpt, map_location=lambda storage, loc: storage)
    hp = Namespace(**dict(ckpt["hparams"]))

    # load weights and peak detection params
    model = NextFrameClassifier(hp)
    weights = ckpt["state_dict"]
    weights = {k.replace("NFC.", ""): v for k, v in weights.items()}
    model.load_state_dict(weights)
    peak_detection_params = dill.loads(ckpt['peak_detection_params'])['cpc_1']
    if prominence is not None:
        print(f"overriding prominence with {prominence}")
        peak_detection_params["prominence"] = prominence

    # load data
    wav_name = os.path.basename(wav)
    audio, sr = torchaudio.load(wav)
    assert sr == 16000, "model was trained with audio sampled at 16khz, please downsample."
    audio = audio[0]  # take the first channel
    audio = audio.unsqueeze(0)  # add a batch dimension

    # run inference
    preds = model(audio)  # get scores
    preds = preds[1][0]  # get scores of positive pairs
    num_features = preds.size(1)
    preds = replicate_first_k_frames(preds, k=1, dim=1)  # padding
    preds = 1 - max_min_norm(preds)  # normalize scores (good for visualizations)
    
    preds = detect_peaks(x=preds,
                         lengths=[preds.shape[1]],
                         prominence=peak_detection_params["prominence"],
                         width=peak_detection_params["width"],
                         distance=peak_detection_params["distance"])  # run peak detection on scores

    mult = audio.size(1) / num_features  # samples per model output frame
    preds = preds[0] * mult / sr  # transform frame indexes to seconds

    os.makedirs(outpath, exist_ok=True)
    create_textgrid(preds, audio.size(1) / sr,
                    os.path.join(outpath, wav_name.replace(".wav", ".TextGrid")))
    print("predicted boundaries (in seconds):")
    print(preds)
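
A hypothetical direct invocation of main (the repository's command-line wrapper is not shown here; the paths and the prominence value are placeholders):

if __name__ == "__main__":
    main(wav="sample.wav",
         ckpt="pretrained.ckpt",
         prominence=0.05,
         outpath="out")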
Example No. 3
class UnsupSegPredictor:

    def __init__(self, ckpt: Path, prominence: float):
        logging.info(f"running inferece using ckpt: {ckpt}")

        ckpt = torch.load(str(ckpt), map_location=lambda storage, loc: storage)
        hp = Namespace(**dict(ckpt["hparams"]))

        # load weights and peak detection params
        self.model = NextFrameClassifier(hp)
        weights = ckpt["state_dict"]
        weights = {k.replace("NFC.", ""): v for k, v in weights.items()}
        self.model.load_state_dict(weights)
        self.peak_detection_params = dill.loads(ckpt['peak_detection_params'])['cpc_1']
        if prominence is not None:
            logging.info(f"overriding prominence with {prominence}")
            self.peak_detection_params["prominence"] = prominence

    def predict(self, wav_path: Path) -> np.ndarray:
        logging.debug(f"running inference on: {wav_filepath}")
        audio, sr = torchaudio.load(str(wav_path))
        assert sr == 16000, "model was trained with audio sampled at 16khz, please downsample."
        audio = audio[0]  # take the first channel
        audio = audio.unsqueeze(0)  # add a batch dimension

        # run inference
        preds = self.model(audio)  # get scores
        preds = preds[1][0]  # get scores of positive pairs
        preds = replicate_first_k_frames(preds, k=1, dim=1)  # padding
        preds = 1 - max_min_norm(preds)  # normalize scores (good for visualizations)
        preds = detect_peaks(x=preds,
                             lengths=[preds.shape[1]],
                             prominence=self.peak_detection_params["prominence"],
                             width=self.peak_detection_params["width"],
                             distance=self.peak_detection_params["distance"])  # run peak detection on scores
        preds = preds[0] * 160 / sr  # 160 samples per frame (10 ms at 16 kHz): transform frame indexes to seconds

        logging.debug("predicted boundaries (in seconds):")
        return preds
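
A hypothetical usage of UnsupSegPredictor (paths are placeholders; passing prominence=None keeps the value stored in the checkpoint):

from pathlib import Path

predictor = UnsupSegPredictor(ckpt=Path("pretrained.ckpt"), prominence=None)
boundaries = predictor.predict(Path("sample.wav"))
print(boundaries)  # boundary times in seconds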
Example No. 4
class Solver(LightningModule):
    def __init__(self, hparams):
        super(Solver, self).__init__()
        hp = hparams
        self.hp = hp
        self.hparams = hp
        self.peak_detection_params = defaultdict(lambda: {
            "prominence": 0.05,
            "width": None,
            "distance": None
        })
        self.pr = defaultdict(
            lambda: {
                "train": PrecisionRecallMetric(),
                "val": PrecisionRecallMetric(),
                "test": PrecisionRecallMetric()
            })
        self.best_rval = defaultdict(lambda: {
            "train": (0, 0),
            "val": (0, 0),
            "test": (0, 0)
        })
        self.overall_best_rval = 0
        self.stats = defaultdict(lambda: {
            "train": StatsMeter(),
            "val": StatsMeter(),
            "test": StatsMeter()
        })

        wandb.init(project=self.hp.project,
                   name=hp.exp_name,
                   config=vars(hp),
                   tags=[hp.tag])
        self.build_model()

    def prepare_data(self):
        # setup training set
        if "timit" in self.hp.data:
            train, val, test = TrainTestDataset.get_datasets(
                path=self.hp.timit_path)
        elif "buckeye" in self.hp.data:
            train, val, test = TrainValTestDataset.get_datasets(
                path=self.hp.buckeye_path, percent=self.hp.buckeye_percent)
        else:
            raise Exception("no such training data!")

        if "libri" in self.hp.data:
            libri_train = LibriSpeechDataset(path=self.hp.libri_path,
                                             subset=self.hp.libri_subset,
                                             percent=self.hp.libri_percent)
            train = ConcatDataset([train, libri_train])
            train.path = "\n\t+".join(
                [dataset.path for dataset in train.datasets])
            print(f"added libri ({len(libri_train)} examples)")

        self.train_dataset = train
        self.valid_dataset = val
        self.test_dataset = test

        line()
        print("DATA:")
        print(f"train: {self.train_dataset.path} ({len(self.train_dataset)})")
        print(f"valid: {self.valid_dataset.path} ({len(self.valid_dataset)})")
        print(f"test: {self.test_dataset.path} ({len(self.test_dataset)})")
        line()

    @pl.data_loader
    def train_dataloader(self):
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.hp.batch_size,
            shuffle=True,
            collate_fn=collate_fn_padd,
            num_workers=self.hp.dataloader_n_workers)
        return self.train_loader

    @pl.data_loader
    def val_dataloader(self):
        self.valid_loader = DataLoader(
            self.valid_dataset,
            batch_size=self.hp.batch_size,
            shuffle=False,
            collate_fn=collate_fn_padd,
            num_workers=self.hp.dataloader_n_workers)
        return self.valid_loader

    @pl.data_loader
    def test_dataloader(self):
        self.test_loader = DataLoader(self.test_dataset,
                                      batch_size=self.hp.batch_size,
                                      shuffle=False,
                                      collate_fn=collate_fn_padd,
                                      num_workers=self.hp.dataloader_n_workers)
        return self.test_loader

    def build_model(self):
        print("MODEL:")
        self.NFC = NextFrameClassifier(self.hp)
        line()

    def forward(self, data_batch, batch_i, mode):
        loss = 0

        # TRAIN
        audio, seg, phonemes, length, fname = data_batch
        preds = self.NFC(audio)
        NFC_loss = self.NFC.loss(preds, length)
        self.stats['nfc_loss'][mode].update(NFC_loss.item())
        loss += NFC_loss

        # INFERENCE
        if mode == "test" or (mode == "val"
                              and self.hp.early_stop_metric == "val_max_rval"):
            positives = 0
            for t in self.NFC.pred_steps:
                p = preds[t][0]
                p = replicate_first_k_frames(p, k=t, dim=1)
                positives += p
            positives = 1 - max_min_norm(positives)
            self.pr[f'cpc_{t}'][mode].update(seg, positives, length)  # t holds the last prediction step from the loop above

        loss_key = "loss" if mode == "train" else f"{mode}_loss"
        return OrderedDict({loss_key: loss})

    def generic_eval_end(self, outputs, mode):
        metrics = {}
        data = self.hp.data

        for k, v in self.stats.items():
            metrics[f"train_{k}"] = self.stats[k]["train"].get_stats()
            metrics[f"{mode}_{k}"] = self.stats[k][mode].get_stats()

        epoch = self.current_epoch + 1
        metrics['epoch'] = epoch
        metrics['current_lr'] = self.opt.param_groups[0]['lr']

        line()
        for pred_type in self.pr.keys():
            if mode == "val":
                (precision, recall, f1,
                 rval), (width, prominence,
                         distance) = self.pr[pred_type][mode].get_stats()
                if rval > self.best_rval[pred_type][mode][0]:
                    self.best_rval[pred_type][mode] = rval, self.current_epoch
                    self.peak_detection_params[pred_type]["width"] = width
                    self.peak_detection_params[pred_type][
                        "prominence"] = prominence
                    self.peak_detection_params[pred_type][
                        "distance"] = distance
                    self.peak_detection_params[pred_type][
                        "epoch"] = self.current_epoch
                    print(
                        f"saving for test - {pred_type} - {self.peak_detection_params[pred_type]}"
                    )
            else:
                print(f"using pre-defined peak detection values - {pred_type} - "
                      f"{self.peak_detection_params[pred_type]}")
                (precision, recall, f1, rval), _ = self.pr[pred_type][mode].get_stats(
                    width=self.peak_detection_params[pred_type]["width"],
                    prominence=self.peak_detection_params[pred_type]["prominence"],
                    distance=self.peak_detection_params[pred_type]["distance"],
                )
                # test has only one epoch so set it as best
                # this is to get the overall best pred_type later
                self.best_rval[pred_type][mode] = rval, self.current_epoch
            metrics[f'{data}_{mode}_{pred_type}_f1'] = f1
            metrics[f'{data}_{mode}_{pred_type}_precision'] = precision
            metrics[f'{data}_{mode}_{pred_type}_recall'] = recall
            metrics[f'{data}_{mode}_{pred_type}_rval'] = rval
            metrics[f"{data}_{mode}_{pred_type}_max_rval"] = self.best_rval[
                pred_type][mode][0]
            metrics[
                f"{data}_{mode}_{pred_type}_max_rval_epoch"] = self.best_rval[
                    pred_type][mode][1]

        # get best rval from all rval types and all epochs
        best_overall_rval = -float("inf")
        for pred_type, rvals in self.best_rval.items():
            if rvals[mode][0] > best_overall_rval:
                best_overall_rval = rvals[mode][0]
        metrics[f'{mode}_max_rval'] = best_overall_rval

        for k, v in metrics.items():
            print(f"\t{k:<30} -- {v}")
        line()
        wandb.log(metrics)

        output = OrderedDict({'log': metrics})

        return output

    def training_step(self, data_batch, batch_i):
        return self.forward(data_batch, batch_i, 'train')

    def validation_step(self, data_batch, batch_i):
        return self.forward(data_batch, batch_i, 'val')

    def test_step(self, data_batch, batch_i):
        return self.forward(data_batch, batch_i, 'test')

    def validation_end(self, outputs):
        return self.generic_eval_end(outputs, 'val')

    def test_end(self, outputs):
        return self.generic_eval_end(outputs, 'test')

    def configure_optimizers(self):
        parameters = filter(lambda p: p.requires_grad, self.parameters())
        if self.hp.optimizer == "sgd":
            self.opt = optim.SGD(parameters,
                                 lr=self.hparams.lr,
                                 momentum=0.9,
                                 weight_decay=5e-4)
        elif self.hp.optimizer == "adam":
            self.opt = optim.Adam(parameters,
                                  lr=self.hparams.lr,
                                  weight_decay=5e-4)
        elif self.hp.optimizer == "ranger":
            self.opt = optim_extra.Ranger(parameters,
                                          lr=self.hparams.lr,
                                          alpha=0.5,
                                          k=6,
                                          N_sma_threshhold=5,
                                          betas=(.95, 0.999),
                                          eps=1e-5,
                                          weight_decay=0)
        else:
            raise Exception("unknown optimizer")
        print(f"optimizer: {self.opt}")
        line()
        self.scheduler = optim.lr_scheduler.StepLR(
            self.opt,
            step_size=self.hp.lr_anneal_step,
            gamma=self.hp.lr_anneal_gamma)
        return [self.opt]

    def on_epoch_end(self):
        self.scheduler.step()

    def on_save_checkpoint(self, ckpt):
        ckpt['peak_detection_params'] = dill.dumps(self.peak_detection_params)

    def on_load_checkpoint(self, ckpt):
        self.peak_detection_params = dill.loads(ckpt['peak_detection_params'])

    def get_ckpt_path(self):
        return glob.glob(self.hp.wd + "/*.ckpt")[0]
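
A minimal training sketch, assuming the legacy pytorch-lightning API that matches the @pl.data_loader and validation_end hooks used above; every hparams field below is a placeholder inferred from the class body, and version-specific Trainer arguments are omitted:

from argparse import Namespace
from pytorch_lightning import Trainer

hparams = Namespace(
    project="unsup-seg", exp_name="demo", tag="dev",  # wandb bookkeeping (placeholders)
    data="timit", timit_path="/data/timit",  # dataset selection (placeholders)
    batch_size=32, dataloader_n_workers=4,
    optimizer="adam", lr=1e-4,
    lr_anneal_step=10, lr_anneal_gamma=0.5,
    early_stop_metric="val_max_rval",
)
solver = Solver(hparams)
Trainer().fit(solver)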
Example No. 5
    def build_model(self):
        print("MODEL:")
        self.NFC = NextFrameClassifier(self.hp)
        line()