def orig(encoder, df, batch_size: int):
    embedder = encoder[0]
    model = SentencePairCosineSimilarity(encoder)
    dataset = SimilarityDataset(
        embedder.tokenizer, df, "sentence1", "sentence2", "score")
    loader = DataLoader(
        dataset,
        # Sort by the longer of the two texts so batches have similar lengths.
        sampler=SortSampler(
            dataset,
            key=lambda x: max(len(dataset.text_1[x]), len(dataset.text_2[x]))),
        collate_fn=partial(collate_pairs, pad=0,
                           opening_id=embedder.cls_token_id,
                           closing_id=embedder.sep_token_id,
                           truncate_length=embedder.max_seq_length),
        batch_size=batch_size)
    preds, references = [], []
    with torch.no_grad():
        for features, labels in tqdm(loader, ncols=100):
            features = features_to_device(features, encoder.device)
            preds.append(model(**features).cpu().numpy())
            references.append(labels.cpu().numpy())
    preds = np.concatenate(preds)
    references = np.concatenate(references)
    return preds, references
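# `features_to_device` is referenced above but not defined in this section; a
# minimal sketch, assuming `features` is a flat dict of tensors (the helper
# name and semantics are inferred from the call site only):
def features_to_device(features, device):
    # Move every tensor in the feature dict onto the target device.
    return {name: tensor.to(device) for name, tensor in features.items()}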
def get_loaders(embedder, args) -> Tuple[DataLoader, DataLoader]:
    ds_train = LcqmcDataset(embedder.tokenizer, filename="train.txt",
                            cache_dir=CACHE_DIR)
    train_loader = DataLoader(
        ds_train,
        sampler=SortishSampler(ds_train, key=pair_max_len(ds_train),
                               bs=args.batch_size),
        collate_fn=partial(collate_pairs, pad=0,
                           opening_id=embedder.cls_token_id,
                           closing_id=embedder.sep_token_id,
                           truncate_length=embedder.max_seq_length),
        batch_size=args.batch_size,
        num_workers=args.workers)
    ds_valid = LcqmcDataset(embedder.tokenizer, filename="dev.txt",
                            cache_dir=CACHE_DIR)
    valid_loader = DataLoader(
        ds_valid,
        sampler=SortSampler(ds_valid, key=pair_max_len(ds_valid)),
        collate_fn=partial(collate_pairs, pad=0,
                           opening_id=embedder.cls_token_id,
                           closing_id=embedder.sep_token_id,
                           truncate_length=embedder.max_seq_length),
        batch_size=args.batch_size * 2,
        num_workers=args.workers)
    return train_loader, valid_loader
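# `pair_max_len` mirrors the inline lambda used elsewhere in this section; a
# minimal sketch, assuming the dataset exposes `text_1` and `text_2` sequences:
def pair_max_len(dataset):
    # Return a sort key mapping an index to the length of its longer text,
    # so that length-sorted batches waste less padding.
    def key(idx):
        return max(len(dataset.text_1[idx]), len(dataset.text_2[idx]))
    return key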
def orig(args, model):
    encoder = model.encoder
    embedder = model.encoder[0]
    dataset = LcqmcDataset(embedder.tokenizer, filename=args.filename)
    loader = DataLoader(
        dataset,
        sampler=SortSampler(
            dataset,
            key=lambda x: max(len(dataset.text_1[x]), len(dataset.text_2[x]))),
        collate_fn=partial(collate_pairs, pad=0,
                           opening_id=embedder.cls_token_id,
                           closing_id=embedder.sep_token_id,
                           truncate_length=embedder.max_seq_length),
        batch_size=16)
    preds, references = [], []
    with torch.no_grad():
        for features, labels in tqdm(loader):
            for name in features:
                features[name] = features[name].to(encoder.device)
            preds.append(model(features).cpu().numpy())
            references.append(labels.cpu().numpy())
    preds = np.concatenate(preds)
    references = np.concatenate(references)
    # scipy.stats.spearmanr returns a result object exposing `.correlation`.
    spearman_score = spearmanr(preds, references)
    print(f"Spearman: {spearman_score.correlation:.4f}")
    return preds, references
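# A hedged usage sketch: this `orig` only reads `args.filename`, so it can be
# exercised with a plain namespace (the model construction is project-specific
# and elided here):
#
#     from types import SimpleNamespace
#     preds, references = orig(SimpleNamespace(filename="dev.txt"), model)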
def main(
    dataset: SBertDataset,
    model_path: str,
    output_folder: str = "cache/teacher_embs/",
    batch_size: int = 32,
    workers: int = 2,
    attentions: bool = False,
    sample: float = -1,
    no_train: bool = False
):
    # This is designed for stsb-roberta-base; some changes might be needed
    # for other models.
    encoder = load_encoder(
        model_path, None, 256, do_lower_case=True, mean_pooling=True
    ).cuda().eval()
    encoder[0].attentions = attentions
    train, valid, test = get_splits("data/", dataset)
    Path(output_folder).mkdir(exist_ok=True, parents=True)
    for name, sentences in (("train", train), ("valid", valid), ("test", test)):
        if name == "train" and 0 < sample < 1:
            # Subsample the training sentences deterministically.
            np.random.seed(42)
            sentences = np.random.choice(
                sentences, size=round(len(sentences) * sample), replace=False)
        if name == "train" and no_train:
            continue
        ds = SentenceDataset(encoder[0].tokenizer, sentences)
        sampler = SortSampler(ds, key=lambda x: len(ds.text[x]))
        loader = DataLoader(
            ds,
            sampler=sampler,
            collate_fn=partial(
                collate_singles, pad=0,
                opening_id=encoder[0].cls_token_id,
                closing_id=encoder[0].sep_token_id,
                truncate_length=encoder[0].max_seq_length),
            batch_size=batch_size,
            num_workers=workers)
        print(name)
        with torch.no_grad():
            buffer = []
            for batch, _ in tqdm(loader, ncols=100):
                outputs = encoder(
                    features_to_device(batch, torch.device("cuda")))
                buffer.append([
                    x.detach().cpu().numpy() if x is not None else None
                    for x in (outputs["sentence_embeddings"],
                              outputs["attentions"])
                ])
        # The sampler yields indices sorted by length, so invert that order
        # to restore the original dataset order.
        reorder_index = np.argsort(list(sampler))
        embs = np.concatenate([x[0] for x in buffer])[reorder_index]
        if buffer[0][1] is None:
            attns = None
        else:
            attns = list(chain.from_iterable(list(x[1]) for x in buffer))
            attns = [attns[i] for i in reorder_index]
        joblib.dump(
            [sentences, embs, attns],
            Path(output_folder) / (dataset.value + "_" + name + ".jbl"))
        del attns, embs
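# The `reorder_index` trick above relies on `np.argsort` inverting the
# sampler's length-sorted order; a self-contained illustration of the assumed
# semantics:
import numpy as np

order = np.array([2, 0, 1])           # dataset indices in the order the sampler yields them
outputs = np.array(["c", "a", "b"])   # per-example outputs, in sampler order
restored = outputs[np.argsort(order)]
assert restored.tolist() == ["a", "b", "c"]  # back in original dataset order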
def main(args):
    embedder = BertWrapper(args.model_path, max_seq_length=256)
    pooler = PoolingLayer(
        embedder.get_word_embedding_dimension(),
        pooling_mode_mean_tokens=True,
        pooling_mode_cls_token=False,
        pooling_mode_max_tokens=False,
        layer_to_use=args.layer)
    encoder = SentenceEncoder(modules=[embedder, pooler])
    model = SentencePairCosineSimilarity(encoder, linear_transform=False)
    model.eval()
    # print("\n".join([name for name, _ in model.named_parameters()]))
    dataset = LcqmcDataset(embedder.tokenizer, filename=args.filename)
    loader = DataLoader(
        dataset,
        sampler=SortSampler(
            dataset,
            key=lambda x: max(len(dataset.text_1[x]), len(dataset.text_2[x]))),
        collate_fn=partial(collate_pairs, pad=0,
                           opening_id=embedder.cls_token_id,
                           closing_id=embedder.sep_token_id,
                           truncate_length=embedder.max_seq_length),
        batch_size=16)
    preds, references = [], []
    with torch.no_grad():
        for features, labels in tqdm(loader):
            for name in features:
                features[name] = features[name].to(encoder.device)
            preds.append(model(features).cpu().numpy())
            references.append(labels.cpu().numpy())
    preds = np.concatenate(preds)
    references = np.concatenate(references)
    spearman_score = spearmanr(preds, references)
    print(f"Spearman: {spearman_score.correlation:.4f}")
    print(f"Pred min/max: {np.min(preds)}, {np.max(preds)}")
    if args.threshold == -1:
        # Grid-search the binarization threshold on this split.
        best_thres, best_acc = -1, -1
        for threshold in np.arange(0.05, 1, 0.05):
            binarized = (preds > threshold).astype("int")
            acc = (binarized == references).sum() / len(references)
            if acc > best_acc:
                best_acc = acc
                best_thres = threshold
        print(f"Best acc: {best_acc:.4f} @ {best_thres:.2f}")
    else:
        binarized = (preds > args.threshold).astype("int")
        acc = (binarized == references).sum() / len(references)
        print(f"Acc: {acc:.4f} @ {args.threshold:.2f}")
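# The threshold grid search above can also be written with NumPy broadcasting;
# an equivalent sketch (same thresholds, same accuracy definition):
def best_threshold(preds, references, thresholds=np.arange(0.05, 1, 0.05)):
    # Shape (T, N): one row of binarized predictions per candidate threshold.
    accs = ((preds[None, :] > thresholds[:, None])
            == references[None, :]).mean(axis=1)
    best = accs.argmax()
    return thresholds[best], accs[best]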
def get_loaders(embedder, args) -> Tuple[DataLoader, DataLoader, DataLoader]:
    df_train, df_valid, df_test = get_splitted_data(args)
    ds_train = NewsClassificationDataset(embedder.tokenizer, df_train)
    train_loader = DataLoader(
        ds_train,
        sampler=SortishSampler(ds_train, key=lambda x: len(ds_train.text[x]),
                               bs=args.batch_size),
        collate_fn=partial(collate_singles, pad=0,
                           opening_id=embedder.cls_token_id,
                           closing_id=embedder.sep_token_id,
                           truncate_length=embedder.max_seq_length),
        batch_size=args.batch_size,
        num_workers=args.workers)
    ds_valid = NewsClassificationDataset(embedder.tokenizer, df_valid)
    valid_loader = DataLoader(
        ds_valid,
        sampler=SortSampler(ds_valid, key=lambda x: len(ds_valid.text[x])),
        collate_fn=partial(collate_singles, pad=0,
                           opening_id=embedder.cls_token_id,
                           closing_id=embedder.sep_token_id,
                           truncate_length=embedder.max_seq_length),
        batch_size=args.batch_size * 2,
        num_workers=args.workers)
    ds_test = NewsClassificationDataset(embedder.tokenizer, df_test)
    test_loader = DataLoader(
        ds_test,
        sampler=SortSampler(ds_test, key=lambda x: len(ds_test.text[x])),
        collate_fn=partial(collate_singles, pad=0,
                           opening_id=embedder.cls_token_id,
                           closing_id=embedder.sep_token_id,
                           truncate_length=embedder.max_seq_length),
        batch_size=args.batch_size * 2,
        num_workers=args.workers)
    return train_loader, valid_loader, test_loader
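# `collate_singles` is assumed throughout this section; a minimal sketch,
# assuming each dataset item is a (token_ids, label) pair and batches are
# padded to the longest sequence (the real implementation may differ):
def collate_singles(batch, pad, opening_id, closing_id, truncate_length):
    sequences, labels = [], []
    for token_ids, label in batch:
        # Reserve two positions for the opening ([CLS]) and closing ([SEP]) ids.
        body = list(token_ids)[:truncate_length - 2]
        sequences.append([opening_id] + body + [closing_id])
        labels.append(label)
    max_len = max(len(seq) for seq in sequences)
    input_ids = torch.full((len(sequences), max_len), pad, dtype=torch.long)
    attention_mask = torch.zeros((len(sequences), max_len), dtype=torch.long)
    for i, seq in enumerate(sequences):
        input_ids[i, :len(seq)] = torch.tensor(seq)
        attention_mask[i, :len(seq)] = 1
    return ({"input_ids": input_ids, "attention_mask": attention_mask},
            torch.tensor(labels))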
def main(teacher_model_path: str,
         student_model_path: str = "nreimers/TinyBERT_L-4_H-312_v2",
         dataset: SBertDataset = "allnli",
         batch_size: int = 32,
         fp16: bool = False,
         workers: int = 2,
         grad_accu: int = 1,
         lr: float = 3e-5,
         epochs: int = 2,
         wd: float = 0,
         layerwise_decay: float = 0,
         attn_loss_weight: float = 1.):
    pl.seed_everything(int(os.environ.get("SEED", 42)))
    config = DistillConfig(
        model_path=student_model_path,
        teacher_model_path=teacher_model_path,
        dataset=dataset,
        data_path="",
        batch_size=batch_size,
        grad_accu=grad_accu,
        learning_rate=lr,
        fp16=fp16,
        epochs=epochs,
        loss_fn=nn.MSELoss(),
        # optimizer_cls=pls.optimizers.RAdam,
        optimizer_cls=torch.optim.AdamW,
        weight_decay=wd,
        layerwise_decay=layerwise_decay,
        attn_loss_weight=attn_loss_weight)
    teacher_encoder, student_encoder, train_ds, valid_ds = get_datasets(
        teacher_model_path, student_model_path, dataset=dataset)
    # Freeze the teacher and expose attention maps on both encoders.
    teacher_encoder.eval()
    pls.utils.set_trainable(teacher_encoder, False)
    teacher_encoder[0].attentions = True
    student_encoder[0].attentions = True
    print(len(train_ds), len(valid_ds))
    train_loader = DataLoader(
        train_ds,
        sampler=SortishSampler(train_ds,
                               key=lambda x: len(train_ds.text_1[x]),
                               bs=batch_size),
        collate_fn=partial(collate_distill, pad=0,
                           opening_id=teacher_encoder[0].cls_token_id,
                           closing_id=teacher_encoder[0].sep_token_id,
                           truncate_length=teacher_encoder[0].max_seq_length),
        batch_size=batch_size,
        num_workers=workers)
    valid_loader = DataLoader(
        valid_ds,
        sampler=SortSampler(valid_ds, key=lambda x: len(valid_ds.text_1[x])),
        collate_fn=partial(collate_distill, pad=0,
                           opening_id=teacher_encoder[0].cls_token_id,
                           closing_id=teacher_encoder[0].sep_token_id,
                           truncate_length=teacher_encoder[0].max_seq_length),
        batch_size=batch_size,
        num_workers=1)
    pl_module = DistillModule(config, teacher_encoder, student_encoder,
                              metrics=())
    checkpoints = pl.callbacks.ModelCheckpoint(
        dirpath=str(CACHE_DIR / "model_checkpoints"),
        monitor='val_loss',
        mode="min",
        filename='{step:06d}-{val_loss:.4f}',
        save_top_k=1,
        save_last=False)
    callbacks = [
        checkpoints,
        pl.callbacks.LearningRateMonitor(logging_interval='step'),
    ]
    loggers = [
        pl.loggers.TensorBoardLogger(str(CACHE_DIR / "tb_logs_distill"),
                                     name=""),
        pls.loggers.ScreenLogger(),
    ]
    if os.environ.get("WANDB_PROJ"):
        loggers.append(
            pl.loggers.WandbLogger(project=os.environ["WANDB_PROJ"]))
    trainer = pl.Trainer(
        # amp_backend="apex", amp_level='O2',
        precision=16 if config.fp16 else 32,
        gpus=1,
        val_check_interval=0.5 if dataset is SBertDataset.AllNLI else 1.,
        gradient_clip_val=10,
        max_epochs=epochs,
        # max_steps=steps,
        callbacks=callbacks,
        accumulate_grad_batches=config.grad_accu,
        # auto_scale_batch_size='power' if batch_size is None else None,
        logger=loggers,
        log_every_n_steps=100)
    trainer.fit(pl_module, train_dataloader=train_loader,
                val_dataloaders=valid_loader)
    # Restore the best checkpoint before exporting the student encoder.
    pl_module.load_state_dict(
        torch.load(checkpoints.best_model_path)["state_dict"])
    output_folder = (
        CACHE_DIR /
        f"{student_encoder[0].transformer.__class__.__name__}_distilled")
    student_encoder.save(str(output_folder))
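# The objective implied by `attn_loss_weight` in the config above can be
# sketched as a weighted sum of embedding and attention losses; a hypothetical
# outline (DistillModule's actual internals are not shown in this section, and
# matched attention tensor shapes between teacher and student are assumed):
def distill_loss(student_out, teacher_out, loss_fn, attn_loss_weight=1.0):
    # Match the student's sentence embeddings to the frozen teacher's.
    loss = loss_fn(student_out["sentence_embeddings"],
                   teacher_out["sentence_embeddings"].detach())
    if student_out.get("attentions") is not None:
        # Optionally match attention maps as well, weighted by attn_loss_weight.
        loss = loss + attn_loss_weight * loss_fn(
            student_out["attentions"], teacher_out["attentions"].detach())
    return loss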
def get_loaders(embedder, t2s, workers, batch_size,
                sample_train) -> Tuple[DataLoader, DataLoader, DataLoader]:
    df_train, df_valid, df_test = get_splitted_data()
    if t2s:
        print("Converting traditional to simplified...")
        for df in (df_train, df_valid, df_test):
            df["text_1"] = df["text_1"].apply(convert_t2s)
            df["text_2"] = df["text_2"].apply(convert_t2s)
        print(df_valid.text_1.head(2))
        print(df_test.text_1.head(2))
    if 0 < sample_train < 1:
        df_train = df_train.sample(frac=sample_train)
    ds_train = NewsSimilarityDataset(embedder.tokenizer, df_train)
    train_loader = DataLoader(
        ds_train,
        sampler=SortishSampler(ds_train, key=pair_max_len(ds_train),
                               bs=batch_size),
        collate_fn=partial(collate_pairs, pad=0,
                           opening_id=embedder.cls_token_id,
                           closing_id=embedder.sep_token_id,
                           truncate_length=embedder.max_seq_length),
        batch_size=batch_size,
        num_workers=workers)
    ds_valid = NewsSimilarityDataset(embedder.tokenizer, df_valid)
    valid_loader = DataLoader(
        ds_valid,
        sampler=SortSampler(ds_valid, key=pair_max_len(ds_valid)),
        collate_fn=partial(collate_pairs, pad=0,
                           opening_id=embedder.cls_token_id,
                           closing_id=embedder.sep_token_id,
                           truncate_length=embedder.max_seq_length),
        batch_size=batch_size * 2,
        num_workers=0)
    ds_test = NewsSimilarityDataset(embedder.tokenizer, df_test)
    test_loader = DataLoader(
        ds_test,
        sampler=SortSampler(ds_test, key=pair_max_len(ds_test)),
        collate_fn=partial(collate_pairs, pad=0,
                           opening_id=embedder.cls_token_id,
                           closing_id=embedder.sep_token_id,
                           truncate_length=embedder.max_seq_length),
        batch_size=batch_size * 2,
        num_workers=0)
    return train_loader, valid_loader, test_loader
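# `convert_t2s` is assumed above; a minimal sketch using the OpenCC package
# (the actual helper may wrap a different converter):
from opencc import OpenCC

_t2s_converter = OpenCC("t2s")

def convert_t2s(text: str) -> str:
    # Convert Traditional Chinese text to Simplified Chinese.
    return _t2s_converter.convert(text)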
def main(model_path: str = "nreimers/TinyBERT_L-4_H-312_v2",
         dataset: SBertDataset = "allnli",
         cache_folder: str = "cache/teacher_embs/",
         batch_size: int = 32,
         fp16: bool = False,
         workers: int = 2,
         grad_accu: int = 1,
         lr: float = 3e-5,
         epochs: int = 2,
         wd: float = 0,
         layerwise_decay: float = 0):
    pl.seed_everything(int(os.environ.get("SEED", 42)))
    config = BaseConfig(
        model_path=model_path,
        data_path=cache_folder,
        batch_size=batch_size,
        grad_accu=grad_accu,
        learning_rate=lr,
        fp16=fp16,
        epochs=epochs,
        loss_fn=nn.MSELoss(),
        # optimizer_cls=pls.optimizers.RAdam,
        optimizer_cls=torch.optim.AdamW,
        weight_decay=wd,
        layerwise_decay=layerwise_decay)
    sents, embs, attns = joblib.load(
        Path(cache_folder) / f"{dataset.value}_train.jbl")
    encoder = load_encoder(model_path, None, 256, do_lower_case=True,
                           mean_pooling=True,
                           expand_to_dimension=embs.shape[1])
    # print(encoder)
    tokenizer = encoder[0].tokenizer
    train_ds = DistillDataset(tokenizer, sents, embs)
    sents, embs, attns = joblib.load(
        Path(cache_folder) / f"{dataset.value}_valid.jbl")
    valid_ds = DistillDataset(tokenizer, sents, embs)
    print(len(train_ds), len(valid_ds))
    del sents
    del embs
    train_loader = DataLoader(
        train_ds,
        sampler=SortishSampler(train_ds, key=lambda x: len(train_ds.text[x]),
                               bs=batch_size),
        collate_fn=partial(collate_singles, pad=0,
                           opening_id=encoder[0].cls_token_id,
                           closing_id=encoder[0].sep_token_id,
                           truncate_length=encoder[0].max_seq_length),
        batch_size=batch_size,
        num_workers=workers)
    valid_loader = DataLoader(
        valid_ds,
        sampler=SortSampler(valid_ds, key=lambda x: len(valid_ds.text[x])),
        collate_fn=partial(collate_singles, pad=0,
                           opening_id=encoder[0].cls_token_id,
                           closing_id=encoder[0].sep_token_id,
                           truncate_length=encoder[0].max_seq_length),
        batch_size=batch_size,
        num_workers=1)
    pl_module = SentenceEncodingModule(config, EncoderWrapper(encoder),
                                       metrics=(),
                                       layerwise_decay=config.layerwise_decay)
    checkpoints = pl.callbacks.ModelCheckpoint(
        dirpath=str(CACHE_DIR / "model_checkpoints"),
        monitor='val_loss',
        mode="min",
        filename='{step:06d}-{val_loss:.4f}',
        save_top_k=1,
        save_last=False)
    callbacks = [
        checkpoints,
        pl.callbacks.LearningRateMonitor(logging_interval='step'),
    ]
    trainer = pl.Trainer(
        # accelerator='dp' if num_gpus > 1 else None,
        # amp_backend="apex", amp_level='O2',
        precision=16 if config.fp16 else 32,
        gpus=1,
        val_check_interval=0.5 if dataset is SBertDataset.AllNLI else 1.,
        gradient_clip_val=10,
        max_epochs=epochs,
        # max_steps=steps,
        callbacks=callbacks,
        accumulate_grad_batches=config.grad_accu,
        # auto_scale_batch_size='power' if batch_size is None else None,
        logger=[
            # pl.loggers.TensorBoardLogger(str(CACHE_DIR / "tb_logs"), name=""),
            pls.loggers.ScreenLogger(),
            # pl.loggers.WandbLogger(project="news-similarity")
        ],
        log_every_n_steps=100)
    trainer.fit(pl_module, train_dataloader=train_loader,
                val_dataloaders=valid_loader)
    pl_module.load_state_dict(
        torch.load(checkpoints.best_model_path)["state_dict"])
    output_folder = (
        CACHE_DIR /
        f"{encoder[0].transformer.__class__.__name__}_distilled")
    encoder.save(str(output_folder))
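# The keyword-only, type-annotated signature above suggests a CLI entry point;
# a hedged usage sketch (the Typer wrapper is an assumption, not confirmed by
# this section):
#
#     if __name__ == "__main__":
#         import typer
#         typer.run(main)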