def build_ner_distill_dataset(args):
    """Run the trained NER teacher model over its training split and dump
    logits for knowledge distillation.

    Loads a model from ``args.resume_from_checkpoint``, runs inference over
    the NER training set, and saves every batch (inputs plus teacher
    ``logits``) to ``<ner_data_dir>/<task_name>/output.npz``.  CRF transition
    tensors are saved under ``extra`` when the classifier exposes them;
    otherwise only the batch data is written.

    Args:
        args: parsed command-line namespace; must provide
            ``resume_from_checkpoint``, ``ner_data_dir``, ``batch_size``
            and ``num_workers``.
    """
    model = Model.load_from_checkpoint(args.resume_from_checkpoint, hparams=args)
    model.eval()
    model.freeze()
    dataset, metric = ner.build_dataset(model, args.ner_data_dir, ner.task_info.task_name)
    train_dataloader = torch.utils.data.DataLoader(
        dataset[datasets.Split.TRAIN],
        batch_size=args.batch_size,
        collate_fn=collate,
        num_workers=args.num_workers,
    )
    output = os.path.join(args.ner_data_dir, ner.task_info.task_name, 'output.npz')

    if torch.cuda.is_available():
        model.cuda()

        # Move batches to the GPU for inference, back to CPU before storing.
        def map2cuda(x):
            return map2device(x, model.device)

        def map2cpu(x):
            return map2device(x)
    else:
        # No GPU: batches already live on the CPU, both mappings are identity.
        def map2cuda(x):
            return x

        def map2cpu(x):
            return x

    with torch.no_grad():
        batchs = []
        for batch in tqdm(train_dataloader):
            batch = map2cuda(batch)
            logits = model.forward(task='ner', **batch).logits
            batch.update(logits=logits)
            batchs.append(map2cpu(batch))

    try:
        numpy.savez(
            output,
            data=convert2npy(batchs),
            extra=convert2npy({
                'transitions': model.ner_classifier.crf.transitions,
                'start_transitions': model.ner_classifier.crf.start_transitions,
                'end_transitions': model.ner_classifier.crf.end_transitions
            })
        )
    except Exception as e:
        # Best-effort: e.g. models without a CRF head have no transition
        # tensors.  Log why the extras were skipped instead of silently
        # swallowing the error, then fall back to saving the data alone.
        print(f"Skipping CRF extras ({e!r}); saving batch data only")
        numpy.savez(output, data=convert2npy(batchs))
    print("Done")
def build_distill_dataset(args):
    """Run the trained teacher model over its training split and dump logits
    for knowledge distillation.

    Loads a model from ``args.resume_from_checkpoint`` (with ``sdp_loss``),
    runs inference over the training set, and saves every batch (inputs plus
    teacher ``logits``) to ``<data_dir>/output.npz``.

    Args:
        args: parsed command-line namespace; must provide
            ``resume_from_checkpoint``, ``data_dir``, ``batch_size`` and
            ``num_workers``.
    """
    model = Model.load_from_checkpoint(
        args.resume_from_checkpoint, hparams=args, loss_func=sdp_loss
    )
    model.eval()
    model.freeze()
    dataset, metric = build_dataset(model, args.data_dir)
    train_dataloader = torch.utils.data.DataLoader(
        dataset[datasets.Split.TRAIN],
        batch_size=args.batch_size,
        collate_fn=collate,
        num_workers=args.num_workers,
    )
    output = os.path.join(args.data_dir, 'output.npz')

    if torch.cuda.is_available():
        model.cuda()

        # Move batches to the GPU for inference, back to CPU before storing.
        def map2cuda(x):
            return map2device(x, model.device)

        def map2cpu(x):
            return map2device(x)
    else:
        # No GPU: batches already live on the CPU, both mappings are identity.
        def map2cuda(x):
            return x

        def map2cpu(x):
            return x

    with torch.no_grad():
        batchs = []
        for batch in tqdm(train_dataloader):
            batch = map2cuda(batch)
            # Only the logits are needed for distillation; discard the loss.
            _, logits = model(**batch)
            batch.update(logits=logits)
            batchs.append(map2cpu(batch))

    numpy.savez(output, data=convert2npy(batchs))
    print("Done")