Beispiel #1
0
def main():
    """Parse command-line options, then build a distillation dataset or train.

    The parser is assembled from the task, model, optimizer and ``Trainer``
    argument groups. Task-specific defaults are then applied on top:
    ``num_labels=56``, ``arc_hidden_size=600`` and ``rel_hidden_size=600``.

    When ``args.build_dataset`` is truthy, ``build_distill_dataset`` is run.
    Otherwise ``common_train`` is called for task ``'sdp'``, monitoring
    ``'val_las'`` and using ``sdp_loss`` as the loss function.
    """
    arg_parser = ArgumentParser()
    # Each registrar receives the parser and returns it with its arguments added.
    for register_args in (add_task_specific_args,
                          Model.add_model_specific_args,
                          optimization.add_optimizer_specific_args,
                          Trainer.add_argparse_args):
        arg_parser = register_args(arg_parser)

    # Task-specific defaults, applied after all argument groups are registered.
    arg_parser.set_defaults(num_labels=56,
                            arc_hidden_size=600,
                            rel_hidden_size=600)

    args = arg_parser.parse_args()

    if not args.build_dataset:
        common_train(args,
                     metric='val_las',
                     model_class=Model,
                     build_method=build_method,
                     task='sdp',
                     loss_func=sdp_loss)
        return
    build_distill_dataset(args)
Beispiel #2
0
def build_distill_dataset(args):
    """Dump the model's outputs on the training split for later distillation.

    Steps:
    1. Load a model from ``args.resume_from_checkpoint`` and put it in
       eval/frozen mode.
    2. Build the dataset under ``args.data_dir``.
    3. Run the model over the training split, moving it to CUDA when
       available.
    4. Save every batch, with the model's ``logits`` added to it, to
       ``<args.data_dir>/output.npz`` under the key ``data``.

    Args:
        args: Parsed command-line namespace. Reads ``resume_from_checkpoint``,
            ``data_dir``, ``batch_size`` and ``num_workers``. The whole
            namespace is also passed to the model as ``hparams``.
    """
    model = Model.load_from_checkpoint(args.resume_from_checkpoint,
                                       hparams=args,
                                       loss_func=sdp_loss)

    model.eval()
    model.freeze()

    # Only the dataset is needed here; the returned metric is unused.
    dataset, _ = build_dataset(model, args.data_dir)
    train_dataloader = torch.utils.data.DataLoader(
        dataset[datasets.Split.TRAIN],
        batch_size=args.batch_size,
        collate_fn=collate,
        num_workers=args.num_workers)

    output = os.path.join(args.data_dir, 'output.npz')

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model.cuda()

    def to_model_device(batch):
        # Move the input batch onto the model's device (no-op without CUDA).
        return map2device(batch, model.device) if use_cuda else batch

    def to_host(batch):
        # map2device with its default target device; presumably CPU -- confirm.
        return map2device(batch) if use_cuda else batch

    outputs = []
    with torch.no_grad():
        for batch in tqdm(train_dataloader):
            batch = to_model_device(batch)
            # The model returns (loss, logits); only the logits are kept.
            _, logits = model(**batch)
            batch.update(logits=logits)
            outputs.append(to_host(batch))

    # Serialization needs no autograd context, so it runs after the loop.
    numpy.savez(output, data=convert2npy(outputs))

    print("Done")
Beispiel #3
0
def main():
    """Parse command-line options, then build a dataset, tune, or train.

    The parser is assembled from the task, model, optimizer and ``Trainer``
    argument groups. Task-specific defaults are then applied on top:
    ``gradient_clip_val=1.0``, ``num_labels=14``, ``min_epochs=1``,
    ``max_epochs=10``, ``arc_hidden_size=500`` and ``rel_hidden_size=100``.

    Dispatch order:
    - ``args.build_dataset`` truthy: call ``build_distill_dataset``;
    - else ``args.tune`` truthy: call ``tune_train``;
    - otherwise: call ``common_train``.
    """
    arg_parser = ArgumentParser()
    # Each registrar receives the parser and returns it with its arguments added.
    for register_args in (add_task_specific_args,
                          Model.add_model_specific_args,
                          optimization.add_optimizer_specific_args,
                          Trainer.add_argparse_args):
        arg_parser = register_args(arg_parser)

    # Task-specific defaults, applied after all argument groups are registered.
    arg_parser.set_defaults(gradient_clip_val=1.0,
                            num_labels=14,
                            min_epochs=1,
                            max_epochs=10,
                            arc_hidden_size=500,
                            rel_hidden_size=100)

    args = arg_parser.parse_args()
    if args.build_dataset:
        build_distill_dataset(args)
        return

    train = tune_train if args.tune else common_train
    train(args, model_class=Model, task_info=task_info)
Beispiel #4
0
def main():
    """Parse command-line options, then build a distillation dataset or train.

    The parser is assembled from the task, model, optimizer and ``Trainer``
    argument groups. Task-specific defaults are then applied on top:
    ``num_labels=14``, ``max_epochs=10``, ``arc_hidden_size=500`` and
    ``rel_hidden_size=100``.

    When ``args.build_dataset`` is truthy, ``build_distill_dataset`` is run.
    Otherwise ``common_train`` is called. It monitors
    ``val_<task_info.metric_name>`` and uses ``task_info.task_name`` as the
    task.
    """
    arg_parser = ArgumentParser()
    # Each registrar receives the parser and returns it with its arguments added.
    for register_args in (add_task_specific_args,
                          Model.add_model_specific_args,
                          optimization.add_optimizer_specific_args,
                          Trainer.add_argparse_args):
        arg_parser = register_args(arg_parser)

    # Task-specific defaults, applied after all argument groups are registered.
    arg_parser.set_defaults(num_labels=14,
                            max_epochs=10,
                            arc_hidden_size=500,
                            rel_hidden_size=100)

    args = arg_parser.parse_args()
    if not args.build_dataset:
        monitored_metric = f'val_{task_info.metric_name}'
        common_train(args,
                     metric=monitored_metric,
                     model_class=Model,
                     build_method=build_method,
                     task=task_info.task_name)
        return
    build_distill_dataset(args)
Beispiel #5
0
def main():
    """Parse command-line options, pick a model class, then build, tune, or train.

    The parser is assembled from the common, tune, task, ``ViModel``,
    ``Model``, optimizer and ``Trainer`` argument groups. Task-specific
    defaults are then applied on top: ``gradient_clip_val=1.0``,
    ``min_epochs=1``, ``max_epochs=10``, ``num_labels=56``,
    ``arc_hidden_size=600`` and ``rel_hidden_size=600``.

    Model selection:
    - ``args.use_vi`` truthy: ``ViModel`` with no extra kwargs;
    - otherwise: ``Model`` with ``loss_func=sdp_loss``.

    Dispatch order:
    - ``args.build_dataset`` truthy: call ``build_distill_dataset``;
    - else ``args.tune`` truthy: call ``tune_train``;
    - otherwise: call ``common_train``.
    """
    arg_parser = ArgumentParser()
    # Each registrar receives the parser and returns it with its arguments added.
    for register_args in (add_common_specific_args,
                          add_tune_specific_args,
                          add_task_specific_args,
                          ViModel.add_model_specific_args,
                          Model.add_model_specific_args,
                          optimization.add_optimizer_specific_args,
                          Trainer.add_argparse_args):
        arg_parser = register_args(arg_parser)

    # Task-specific defaults, applied after all argument groups are registered.
    arg_parser.set_defaults(gradient_clip_val=1.0,
                            min_epochs=1,
                            max_epochs=10,
                            num_labels=56,
                            arc_hidden_size=600,
                            rel_hidden_size=600)

    args = arg_parser.parse_args()

    use_vi = args.use_vi
    model_class = ViModel if use_vi else Model
    model_kwargs = {} if use_vi else {'loss_func': sdp_loss}

    if args.build_dataset:
        build_distill_dataset(model_class, args, model_kwargs=model_kwargs)
        return

    train = tune_train if args.tune else common_train
    train(args,
          model_class=model_class,
          task_info=task_info,
          model_kwargs=model_kwargs)