Example #1
File: eval.py  Project: ceshine/oggdo
def orig(encoder, df, batch_size: int):
    embedder = encoder[0]
    model = SentencePairCosineSimilarity(encoder)

    dataset = SimilarityDataset(embedder.tokenizer, df, "sentence1",
                                "sentence2", "score")
    loader = DataLoader(
        dataset,
        sampler=SortSampler(
            dataset,
            key=lambda x: max(len(dataset.text_1[x]), len(dataset.text_2[x]))),
        collate_fn=partial(collate_pairs,
                           pad=0,
                           opening_id=embedder.cls_token_id,
                           closing_id=embedder.sep_token_id,
                           truncate_length=embedder.max_seq_length),
        batch_size=batch_size)
    preds, references = [], []
    with torch.no_grad():
        for features, labels in tqdm(loader, ncols=100):
            features = features_to_device(features, encoder.device)
            preds.append(model(**features).cpu().numpy())
            references.append(labels.cpu().numpy())

    preds = np.concatenate(preds)
    references = np.concatenate(references)

    return preds, references
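
The features_to_device helper used above is not defined in these listings. Judging from the equivalent explicit loop in Example #3 below, it presumably just moves every tensor in the collated features dict onto the encoder's device; a minimal sketch under that assumption:

def features_to_device(features, device):
    # Hypothetical re-implementation (not the project's own code): move each
    # tensor in the collated batch dict (input_ids, attention_mask, ...) to
    # the target torch.device and return the dict.
    return {name: tensor.to(device) for name, tensor in features.items()}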
Example #2
def get_loaders(embedder, args) -> Tuple[DataLoader, DataLoader]:
    ds_train = LcqmcDataset(embedder.tokenizer,
                            filename="train.txt",
                            cache_dir=CACHE_DIR)
    train_loader = DataLoader(
        ds_train,
        sampler=SortishSampler(ds_train,
                               key=pair_max_len(ds_train),
                               bs=args.batch_size),
        collate_fn=partial(collate_pairs,
                           pad=0,
                           opening_id=embedder.cls_token_id,
                           closing_id=embedder.sep_token_id,
                           truncate_length=embedder.max_seq_length),
        batch_size=args.batch_size,
        num_workers=args.workers)
    ds_valid = LcqmcDataset(embedder.tokenizer,
                            filename="dev.txt",
                            cache_dir=CACHE_DIR)
    valid_loader = DataLoader(ds_valid,
                              sampler=SortSampler(ds_valid,
                                                  key=pair_max_len(ds_valid)),
                              collate_fn=partial(
                                  collate_pairs,
                                  pad=0,
                                  opening_id=embedder.cls_token_id,
                                  closing_id=embedder.sep_token_id,
                                  truncate_length=embedder.max_seq_length),
                              batch_size=args.batch_size * 2,
                              num_workers=args.workers)
    return train_loader, valid_loader
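
The pair_max_len helper is likewise not shown. Given the inline sort key used in Examples #1 and #3 (the length of the longer of the two texts in a pair), it most likely returns exactly that key function; a sketch under that assumption:

def pair_max_len(dataset):
    # Assumed behavior (mirrors the inline lambda in Examples #1 and #3):
    # the sort key is the length of the longer text in each sentence pair.
    return lambda idx: max(len(dataset.text_1[idx]), len(dataset.text_2[idx]))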
Example #3
def orig(args, model):
    encoder = model.encoder
    embedder = model.encoder[0]

    dataset = LcqmcDataset(embedder.tokenizer, filename=args.filename)
    loader = DataLoader(
        dataset,
        sampler=SortSampler(
            dataset,
            key=lambda x: max(len(dataset.text_1[x]), len(dataset.text_2[x]))),
        collate_fn=partial(collate_pairs,
                           pad=0,
                           opening_id=embedder.cls_token_id,
                           closing_id=embedder.sep_token_id,
                           truncate_length=embedder.max_seq_length),
        batch_size=16)
    preds, references = [], []
    with torch.no_grad():
        for features, labels in tqdm(loader):
            for name in features:
                features[name] = features[name].to(encoder.device)
            preds.append(model(features).cpu().numpy())
            references.append(labels.cpu().numpy())

    preds = np.concatenate(preds)
    references = np.concatenate(references)
    spearman_score = spearmanr(preds, references)
    print(f"Spearman: {spearman_score.correlation:.4f}")

    return preds, references
Example #4
def main(
    dataset: SBertDataset, model_path: str, output_folder: str = "cache/teacher_embs/",
    batch_size: int = 32, workers: int = 2, attentions: bool = False, sample: float = -1,
    no_train: bool = False
):
    # this is designed for stsb-roberta-base; some changes might be needed for other models
    encoder = load_encoder(model_path, None, 256, do_lower_case=True, mean_pooling=True).cuda().eval()
    encoder[0].attentions = attentions
    train, valid, test = get_splits("data/", dataset)
    Path(output_folder).mkdir(exist_ok=True, parents=True)
    for name, sentences in (("train", train), ("valid", valid), ("test", test)):
        if name == "train" and sample > 0 and sample < 1:
            np.random.seed(42)
            sentences = np.random.choice(
                sentences, size=round(len(sentences) * sample),
                replace=False
            )
        if name == "train" and no_train:
            continue
        ds = SentenceDataset(encoder[0].tokenizer, sentences)
        sampler = SortSampler(
            ds,
            key=lambda x: len(ds.text[x])
        )
        loader = DataLoader(
            ds,
            sampler=sampler,
            collate_fn=partial(
                collate_singles,
                pad=0,
                opening_id=encoder[0].cls_token_id,
                closing_id=encoder[0].sep_token_id,
                truncate_length=encoder[0].max_seq_length
            ),
            batch_size=batch_size,
            num_workers=workers
        )
        print(name)
        with torch.no_grad():
            buffer = []
            for batch, _ in tqdm(loader, ncols=100):
                outputs = encoder(
                    features_to_device(batch, torch.device("cuda"))
                )
                buffer.append([
                    x.detach().cpu().numpy() if x is not None else None
                    for x in (outputs["sentence_embeddings"], outputs["attentions"])
                ])
        # SortSampler iterated the sentences in length-sorted order; argsort of
        # that order gives the permutation that restores the original ordering.
        reorder_index = np.argsort(list(iter(sampler)))
        embs = np.concatenate([x[0] for x in buffer])[reorder_index]
        if buffer[0][1] is None:
            attns = None
        else:
            attns = list(chain.from_iterable((list(x[1]) for x in buffer)))
            attns = [attns[i] for i in reorder_index]
        joblib.dump([sentences, embs, attns], Path(output_folder) / (dataset.value + "_" + name + ".jbl"))
        del attns, embs
Example #5
def main(args):
    embedder = BertWrapper(args.model_path, max_seq_length=256)
    pooler = PoolingLayer(embedder.get_word_embedding_dimension(),
                          pooling_mode_mean_tokens=True,
                          pooling_mode_cls_token=False,
                          pooling_mode_max_tokens=False,
                          layer_to_use=args.layer)
    encoder = SentenceEncoder(modules=[embedder, pooler])
    model = SentencePairCosineSimilarity(encoder, linear_transform=False)
    model.eval()

    # print("\n".join([name for name, _ in model.named_parameters()]))

    dataset = LcqmcDataset(embedder.tokenizer, filename=args.filename)
    loader = DataLoader(
        dataset,
        sampler=SortSampler(
            dataset,
            key=lambda x: max(len(dataset.text_1[x]), len(dataset.text_2[x]))),
        collate_fn=partial(collate_pairs,
                           pad=0,
                           opening_id=embedder.cls_token_id,
                           closing_id=embedder.sep_token_id,
                           truncate_length=embedder.max_seq_length),
        batch_size=16)
    preds, references = [], []
    with torch.no_grad():
        for features, labels in tqdm(loader):
            for name in features:
                features[name] = features[name].to(encoder.device)
            preds.append(model(features).cpu().numpy())
            references.append(labels.cpu().numpy())

    preds = np.concatenate(preds)
    references = np.concatenate(references)
    spearman_score = spearmanr(preds, references)
    print(f"Spearman: {spearman_score.correlation:.4f}")

    print(f"Pred Min: {np.min(preds)}, {np.max(preds)}")
    if args.threshold == -1:
        best_thres, best_acc = -1, -1
        for threshold in np.arange(0.05, 1, 0.05):
            binarized = (preds > threshold).astype("int")
            acc = (binarized == references).sum() / len(references)
            if acc > best_acc:
                best_acc = acc
                best_thres = threshold
        print(f"Best acc: {best_acc:.4f} @ {best_thres:.2f}")
    else:
        binarized = (preds > args.threshold).astype("int")
        acc = (binarized == references).sum() / len(references)
        print(f"Acc: {acc:.4f} @ {args.threshold:.2f}")
Example #6
def get_loaders(embedder, args) -> Tuple[DataLoader, DataLoader, DataLoader]:
    df_train, df_valid, df_test = get_splitted_data(args)
    ds_train = NewsClassificationDataset(embedder.tokenizer, df_train)
    train_loader = DataLoader(
        ds_train,
        sampler=SortishSampler(ds_train,
                               key=lambda x: len(ds_train.text[x]),
                               bs=args.batch_size),
        collate_fn=partial(collate_singles,
                           pad=0,
                           opening_id=embedder.cls_token_id,
                           closing_id=embedder.sep_token_id,
                           truncate_length=embedder.max_seq_length),
        batch_size=args.batch_size,
        num_workers=args.workers)
    ds_valid = NewsClassificationDataset(embedder.tokenizer, df_valid)
    valid_loader = DataLoader(
        ds_valid,
        sampler=SortSampler(ds_valid, key=lambda x: len(ds_valid.text[x])),
        collate_fn=partial(collate_singles,
                           pad=0,
                           opening_id=embedder.cls_token_id,
                           closing_id=embedder.sep_token_id,
                           truncate_length=embedder.max_seq_length),
        batch_size=args.batch_size * 2,
        num_workers=args.workers)
    ds_test = NewsClassificationDataset(embedder.tokenizer, df_test)
    test_loader = DataLoader(
        ds_test,
        sampler=SortSampler(ds_test, key=lambda x: len(ds_test.text[x])),
        collate_fn=partial(collate_singles,
                           pad=0,
                           opening_id=embedder.cls_token_id,
                           closing_id=embedder.sep_token_id,
                           truncate_length=embedder.max_seq_length),
        batch_size=args.batch_size * 2,
        num_workers=args.workers)
    return train_loader, valid_loader, test_loader
Example #7
def main(teacher_model_path: str,
         student_model_path: str = "nreimers/TinyBERT_L-4_H-312_v2",
         dataset: SBertDataset = "allnli",
         batch_size: int = 32,
         fp16: bool = False,
         workers: int = 2,
         grad_accu: int = 1,
         lr: float = 3e-5,
         epochs: int = 2,
         wd: float = 0,
         layerwise_decay: float = 0,
         attn_loss_weight: float = 1.):
    pl.seed_everything(int(os.environ.get("SEED", 42)))

    config = DistillConfig(
        model_path=student_model_path,
        teacher_model_path=teacher_model_path,
        dataset=dataset,
        data_path="",
        batch_size=batch_size,
        grad_accu=grad_accu,
        learning_rate=lr,
        fp16=fp16,
        epochs=epochs,
        loss_fn=nn.MSELoss(),
        # optimizer_cls=pls.optimizers.RAdam,
        optimizer_cls=torch.optim.AdamW,
        weight_decay=wd,
        layerwise_decay=layerwise_decay,
        attn_loss_weight=attn_loss_weight)

    teacher_encoder, student_encoder, train_ds, valid_ds = get_datasets(
        teacher_model_path, student_model_path, dataset=dataset)
    teacher_encoder.eval()
    pls.utils.set_trainable(teacher_encoder, False)
    teacher_encoder[0].attentions = True
    student_encoder[0].attentions = True
    print(len(train_ds), len(valid_ds))
    train_loader = DataLoader(
        train_ds,
        sampler=SortishSampler(train_ds,
                               key=lambda x: len(train_ds.text_1[x]),
                               bs=batch_size),
        collate_fn=partial(collate_distill,
                           pad=0,
                           opening_id=teacher_encoder[0].cls_token_id,
                           closing_id=teacher_encoder[0].sep_token_id,
                           truncate_length=teacher_encoder[0].max_seq_length),
        batch_size=batch_size,
        num_workers=workers)
    valid_loader = DataLoader(
        valid_ds,
        sampler=SortSampler(valid_ds, key=lambda x: len(valid_ds.text_1[x])),
        collate_fn=partial(collate_distill,
                           pad=0,
                           opening_id=teacher_encoder[0].cls_token_id,
                           closing_id=teacher_encoder[0].sep_token_id,
                           truncate_length=teacher_encoder[0].max_seq_length),
        batch_size=batch_size,
        num_workers=1)

    pl_module = DistillModule(config,
                              teacher_encoder,
                              student_encoder,
                              metrics=())

    checkpoints = pl.callbacks.ModelCheckpoint(
        dirpath=str(CACHE_DIR / "model_checkpoints"),
        monitor='val_loss',
        mode="min",
        filename='{step:06d}-{val_loss:.4f}',
        save_top_k=1,
        save_last=False)
    callbacks = [
        checkpoints,
        pl.callbacks.LearningRateMonitor(logging_interval='step'),
    ]

    loggers = [
        pl.loggers.TensorBoardLogger(str(CACHE_DIR / "tb_logs_distill"),
                                     name=""),
        pls.loggers.ScreenLogger(),
    ]
    if os.environ.get("WANDB_PROJ"):
        loggers.append(
            pl.loggers.WandbLogger(project=os.environ["WANDB_PROJ"]))
    trainer = pl.Trainer(
        # amp_backend="apex", amp_level='O2',
        precision=16 if config.fp16 else 32,
        gpus=1,
        val_check_interval=0.5 if dataset is SBertDataset.AllNLI else 1.,
        gradient_clip_val=10,
        max_epochs=epochs,
        # max_steps=steps,
        callbacks=callbacks,
        accumulate_grad_batches=config.grad_accu,
        # auto_scale_batch_size='power' if batch_size is None else None,
        logger=loggers,
        log_every_n_steps=100)

    trainer.fit(pl_module,
                train_dataloader=train_loader,
                val_dataloaders=valid_loader)

    pl_module.load_state_dict(
        torch.load(checkpoints.best_model_path)["state_dict"])

    output_folder = CACHE_DIR / f"{student_encoder[0].transformer.__class__.__name__}_distilled"
    student_encoder.save(str(output_folder))
Example #8
def get_loaders(embedder, t2s, workers, batch_size, sample_train) -> Tuple[DataLoader, DataLoader, DataLoader]:
    df_train, df_valid, df_test = get_splitted_data()
    if t2s:
        print("Converting traditional to simplified...")
        for df in (df_train, df_valid, df_test):
            df["text_1"] = df["text_1"].apply(convert_t2s)
            df["text_2"] = df["text_2"].apply(convert_t2s)
    print(df_valid.text_1.head(2))
    print(df_test.text_1.head(2))
    if sample_train > 0 and sample_train < 1:
        df_train = df_train.sample(frac=sample_train)
    ds_train = NewsSimilarityDataset(
        embedder.tokenizer, df_train)
    train_loader = DataLoader(
        ds_train,
        sampler=SortishSampler(
            ds_train,
            key=pair_max_len(ds_train),
            bs=batch_size
        ),
        collate_fn=partial(
            collate_pairs,
            pad=0,
            opening_id=embedder.cls_token_id,
            closing_id=embedder.sep_token_id,
            truncate_length=embedder.max_seq_length
        ),
        batch_size=batch_size,
        num_workers=workers
    )
    ds_valid = NewsSimilarityDataset(
        embedder.tokenizer, df_valid)
    valid_loader = DataLoader(
        ds_valid,
        sampler=SortSampler(
            ds_valid,
            key=pair_max_len(ds_valid)
        ),
        collate_fn=partial(
            collate_pairs,
            pad=0,
            opening_id=embedder.cls_token_id,
            closing_id=embedder.sep_token_id,
            truncate_length=embedder.max_seq_length
        ),
        batch_size=batch_size * 2,
        num_workers=0
    )
    ds_test = NewsSimilarityDataset(
        embedder.tokenizer, df_test)
    test_loader = DataLoader(
        ds_test,
        sampler=SortSampler(
            ds_test,
            key=pair_max_len(ds_test)
        ),
        collate_fn=partial(
            collate_pairs,
            pad=0,
            opening_id=embedder.cls_token_id,
            closing_id=embedder.sep_token_id,
            truncate_length=embedder.max_seq_length
        ),
        batch_size=batch_size * 2,
        num_workers=0
    )
    return train_loader, valid_loader, test_loader
Example #9
def main(model_path: str = "nreimers/TinyBERT_L-4_H-312_v2",
         dataset: SBertDataset = "allnli",
         cache_folder: str = "cache/teacher_embs/",
         batch_size: int = 32,
         fp16: bool = False,
         workers: int = 2,
         grad_accu: int = 1,
         lr: float = 3e-5,
         epochs: int = 2,
         wd: float = 0,
         layerwise_decay: float = 0):
    pl.seed_everything(int(os.environ.get("SEED", 42)))

    config = BaseConfig(
        model_path=model_path,
        data_path=cache_folder,
        batch_size=batch_size,
        grad_accu=grad_accu,
        learning_rate=lr,
        fp16=fp16,
        epochs=epochs,
        loss_fn=nn.MSELoss(),
        # optimizer_cls=pls.optimizers.RAdam,
        optimizer_cls=torch.optim.AdamW,
        weight_decay=wd,
        layerwise_decay=layerwise_decay)

    sents, embs, attns = joblib.load(
        Path(cache_folder) / f"{dataset.value}_train.jbl")
    encoder = load_encoder(model_path,
                           None,
                           256,
                           do_lower_case=True,
                           mean_pooling=True,
                           expand_to_dimension=embs.shape[1])
    # print(encoder)
    tokenizer = encoder[0].tokenizer
    train_ds = DistillDataset(tokenizer, sents, embs)
    sents, embs, attns = joblib.load(
        Path(cache_folder) / f"{dataset.value}_valid.jbl")
    valid_ds = DistillDataset(tokenizer, sents, embs)
    print(len(train_ds), len(valid_ds))
    del sents
    del embs
    train_loader = DataLoader(
        train_ds,
        sampler=SortishSampler(train_ds,
                               key=lambda x: len(train_ds.text[x]),
                               bs=batch_size),
        collate_fn=partial(collate_singles,
                           pad=0,
                           opening_id=encoder[0].cls_token_id,
                           closing_id=encoder[0].sep_token_id,
                           truncate_length=encoder[0].max_seq_length),
        batch_size=batch_size,
        num_workers=workers)
    valid_loader = DataLoader(
        valid_ds,
        sampler=SortSampler(valid_ds, key=lambda x: len(valid_ds.text[x])),
        collate_fn=partial(collate_singles,
                           pad=0,
                           opening_id=encoder[0].cls_token_id,
                           closing_id=encoder[0].sep_token_id,
                           truncate_length=encoder[0].max_seq_length),
        batch_size=batch_size,
        num_workers=1)

    pl_module = SentenceEncodingModule(config,
                                       EncoderWrapper(encoder),
                                       metrics=(),
                                       layerwise_decay=config.layerwise_decay)

    checkpoints = pl.callbacks.ModelCheckpoint(
        dirpath=str(CACHE_DIR / "model_checkpoints"),
        monitor='val_loss',
        mode="min",
        filename='{step:06d}-{val_loss:.4f}',
        save_top_k=1,
        save_last=False)
    callbacks = [
        checkpoints,
        pl.callbacks.LearningRateMonitor(logging_interval='step'),
    ]

    trainer = pl.Trainer(
        # accelerator='dp' if num_gpus > 1 else None,
        # amp_backend="apex", amp_level='O2',
        precision=16 if config.fp16 else 32,
        gpus=1,
        val_check_interval=0.5 if dataset is SBertDataset.AllNLI else 1.,
        gradient_clip_val=10,
        max_epochs=epochs,
        # max_steps=steps,
        callbacks=callbacks,
        accumulate_grad_batches=config.grad_accu,
        # auto_scale_batch_size='power' if batch_size is None else None,
        logger=[
            # pl.loggers.TensorBoardLogger(str(CACHE_DIR / "tb_logs"), name=""),
            pls.loggers.ScreenLogger(),
            # pl.loggers.WandbLogger(project="news-similarity")
        ],
        log_every_n_steps=100)

    trainer.fit(pl_module,
                train_dataloader=train_loader,
                val_dataloaders=valid_loader)

    pl_module.load_state_dict(
        torch.load(checkpoints.best_model_path)["state_dict"])

    output_folder = CACHE_DIR / f"{encoder[0].transformer.__class__.__name__}_distilled"
    encoder.save(str(output_folder))
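
As a usage note, the distilled encoder saved above can presumably be reloaded with the same load_encoder helper shown in Example #4; whether the folder written by encoder.save() is accepted directly as model_path is an assumption here, not something these listings confirm:

# Hypothetical reload of the distilled encoder for inference,
# reusing the load_encoder arguments from Example #4.
distilled = load_encoder(
    str(output_folder), None, 256,
    do_lower_case=True, mean_pooling=True
).cuda().eval()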