def main():
    """Parse command-line options and launch the selected workflow.

    Every argument group (common, tune, task, model, optimizer, Trainer) is
    registered on a single parser, project defaults are applied, and then one
    of three actions runs based on the parsed flags:

    * ``args.build_dataset`` -> ``build_distill_dataset(args)``
    * ``args.tune``          -> ``tune_train(...)``
    * otherwise              -> ``common_train(...)``
    """
    # Each registrar takes the parser and returns it; order matches the
    # sequence in which the option groups are expected to be added
    # (task-level groups first, then model / optimizer / Trainer groups).
    registrars = (
        add_common_specific_args,
        add_tune_specific_args,
        add_task_specific_args,
        Model.add_model_specific_args,
        optimization.add_optimizer_specific_args,
        Trainer.add_argparse_args,
    )
    parser = ArgumentParser()
    for register in registrars:
        parser = register(parser)

    # Defaults applied when the corresponding options are not given.
    parser.set_defaults(
        gradient_clip_val=1.0,
        min_epochs=1,
        max_epochs=10,
        num_labels=27,
    )
    args = parser.parse_args()

    if args.build_dataset:
        build_distill_dataset(args)
        return

    # Tuning and plain training share the same call signature.
    run_training = tune_train if args.tune else common_train
    run_training(args, model_class=Model, task_info=task_info)
def build_distill_dataset(args):
    """Run a checkpointed model over the training split and save its logits.

    The model is restored from ``args.resume_from_checkpoint``, switched to
    eval mode and frozen. Every training batch is passed through the model,
    the resulting ``logits`` are stored in the batch under the ``"logits"``
    key, and all batches are written to
    ``<args.data_dir>/<task_info.task_name>/output.npz`` under the ``data`` key.

    Args:
        args: Parsed options. Reads ``resume_from_checkpoint``, ``data_dir``,
            ``batch_size`` and ``num_workers``; also forwarded to the model
            as ``hparams``.
    """
    model = Model.load_from_checkpoint(args.resume_from_checkpoint, hparams=args)
    model.eval()
    model.freeze()

    # The second returned value (the metric) is not needed here.
    dataset, _metric = build_dataset(model, args.data_dir, task_info.task_name)
    loader = torch.utils.data.DataLoader(
        dataset[datasets.Split.TRAIN],
        batch_size=args.batch_size,
        collate_fn=collate,
        num_workers=args.num_workers,
    )
    output = os.path.join(args.data_dir, task_info.task_name, 'output.npz')

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model.cuda()

    def to_model_device(batch):
        # Without CUDA the batch is used as-is.
        if not use_cuda:
            return batch
        return map2device(batch, model.device)

    def to_default_device(batch):
        # map2device with its default target (presumably CPU - confirm
        # against the helper's signature); a no-op when CUDA is unavailable.
        if not use_cuda:
            return batch
        return map2device(batch)

    collected = []
    with torch.no_grad():
        for batch in tqdm(loader):
            batch = to_model_device(batch)
            batch.update(logits=model.forward(**batch).logits)
            # Move results back before accumulating so collected batches do
            # not stay on the model's device.
            collected.append(to_default_device(batch))

    numpy.savez(output, data=convert2npy(collected))
    print("Done")
def main():
    """Parse command-line options and either build the distillation set or train.

    Task, model, optimizer and Trainer argument groups are registered on one
    parser, with ``num_labels=2`` and ``max_epochs=10`` as defaults.
    ``args.data_dir`` is converted to an absolute path before use. When
    ``args.build_dataset`` is set, ``build_distill_dataset(args)`` runs;
    otherwise ``common_train`` is called with the metric key
    ``val_<task_info.metric_name>``.
    """
    cli = ArgumentParser()

    # Task-level options first, then model / optimizer / Trainer options.
    cli = add_task_specific_args(cli)
    cli = Model.add_model_specific_args(cli)
    cli = optimization.add_optimizer_specific_args(cli)
    cli = Trainer.add_argparse_args(cli)

    # Task-specific defaults.
    cli.set_defaults(num_labels=2, max_epochs=10)

    args = cli.parse_args()
    args.data_dir = os.path.abspath(args.data_dir)

    if args.build_dataset:
        build_distill_dataset(args)
        return

    metric_key = f'val_{task_info.metric_name}'
    common_train(
        args,
        metric=metric_key,
        model_class=Model,
        build_method=build_method,
        task=task_info.task_name,
    )