def _match_dataset_sizes(first_ds, second_ds):
    """Tile the shorter of two datasets so it is at least as long as the other.

    Used for parallel batching, where two datasets are iterated side by side.
    The shorter dataset is repeated whole, ceil(len(longer) / len(shorter))
    times, via ``ConcatTensorDataset``; the longer one is returned unchanged.

    Args:
        first_ds: first dataset (e.g. the labeled one).
        second_ds: second dataset (e.g. the unlabeled one).

    Returns:
        ``(first, second)`` in the same order as the arguments, with whichever
        was shorter replaced by its tiled version.

    Raises:
        ValueError: if one dataset is empty while the other is not; tiling an
            empty dataset can never reach the target length.
    """
    if len(first_ds) > len(second_ds):
        larger_ds, smaller_ds = first_ds, second_ds
    else:
        larger_ds, smaller_ds = second_ds, first_ds

    if len(smaller_ds) < len(larger_ds):
        if len(smaller_ds) == 0:
            raise ValueError(
                "cannot match dataset sizes for parallel batching: one dataset is empty"
            )
        repeats = -(-len(larger_ds) // len(smaller_ds))  # ceiling division
        tiled_ds = ConcatTensorDataset(smaller_ds, [smaller_ds] * (repeats - 1))
    else:
        tiled_ds = smaller_ds

    # Decide by identity which input was tiled; item tuple length differs
    # between student/teacher and labeled/unlabeled datasets, so it is not a
    # reliable signal.
    if smaller_ds is first_ds:
        return tiled_ds, second_ds
    return first_ds, tiled_ds


def do_kd_pseudo_training(args):
    """Train a NeuralTagger student distilled from a transformer teacher.

    The student is trained on a labeled training file plus an unlabeled file.
    A ``TransformerTokenClassifier`` teacher, loaded from
    ``args.teacher_model_path``, supplies the distillation signal through
    ``TeacherStudentDistill``. The final student model is saved to
    ``args.output_dir``.

    Args:
        args: parsed command-line namespace. Among others it reads
            ``data_dir``, ``train_filename``, ``unlabeled_filename``,
            ``dev_filename``, ``test_filename``, ``model_type``, ``b`` (batch
            size), ``e`` (epochs), ``lr``, ``parallel_batching``,
            ``teacher_model_path``, ``teacher_model_type`` and the ``kd_*``
            distillation settings.
    """
    prepare_output_path(args.output_dir, args.overwrite_output_dir)
    device, n_gpus = setup_backend(args.no_cuda)
    # Set seed
    # NOTE(review): do_kd_training discards set_seed's return value; if
    # set_seed returns None this overwrites args.seed with None — confirm.
    args.seed = set_seed(args.seed, n_gpus)

    # prepare data
    processor = TokenClsProcessor(
        args.data_dir, tag_col=args.tag_col, ignore_token=args.ignore_token
    )
    train_labeled_ex = processor.get_train_examples(filename=args.train_filename)
    train_unlabeled_ex = processor.get_train_examples(filename=args.unlabeled_filename)
    dev_ex = processor.get_dev_examples(filename=args.dev_filename)
    test_ex = processor.get_test_examples(filename=args.test_filename)
    # Build the vocabulary over every split that is present; dev/test may be
    # None (they are checked below), so plain list concatenation would fail.
    all_examples = [
        example
        for split in (train_labeled_ex, train_unlabeled_ex, dev_ex, test_ex)
        if split is not None
        for example in split
    ]
    vocab = processor.get_vocabulary(all_examples)
    vocab_size = len(vocab) + 1
    num_labels = len(processor.get_labels()) + 1

    # create an embedder
    embedder_cls = MODEL_TYPE[args.model_type]
    if args.config_file is not None:
        embedder_model = embedder_cls.from_config(vocab_size, num_labels, args.config_file)
    else:
        embedder_model = embedder_cls(vocab_size, num_labels)

    # load external word embeddings if present
    if args.embedding_file is not None:
        emb_dict = load_embedding_file(args.embedding_file, dim=embedder_model.word_embedding_dim)
        emb_mat = get_embedding_matrix(emb_dict, vocab)
        emb_mat = torch.tensor(emb_mat, dtype=torch.float)
        embedder_model.load_embeddings(emb_mat)

    classifier = NeuralTagger(
        embedder_model,
        word_vocab=vocab,
        labels=processor.get_labels(),
        use_crf=args.use_crf,
        device=device,
        n_gpus=n_gpus,
    )

    train_batch_size = args.b * max(1, n_gpus)
    train_labeled_dataset = classifier.convert_to_tensors(
        train_labeled_ex,
        max_seq_length=args.max_sentence_length,
        max_word_length=args.max_word_length,
    )
    train_unlabeled_dataset = classifier.convert_to_tensors(
        train_unlabeled_ex,
        max_seq_length=args.max_sentence_length,
        max_word_length=args.max_word_length,
        include_labels=False,
    )

    if args.parallel_batching:
        # match sizes of labeled/unlabeled train data for parallel batching
        train_labeled_dataset, train_unlabeled_dataset = _match_dataset_sizes(
            train_labeled_dataset, train_unlabeled_dataset
        )
    else:
        train_dataset = CombinedTensorDataset([train_labeled_dataset, train_unlabeled_dataset])

    # load saved teacher args if exist
    teacher_args_path = os.path.join(args.teacher_model_path, "training_args.bin")
    if os.path.exists(teacher_args_path):
        # NOTE: torch.load unpickles the file; only use trusted teacher paths.
        t_args = torch.load(teacher_args_path)
        t_device, t_n_gpus = setup_backend(t_args.no_cuda)
        teacher = TransformerTokenClassifier.load_model(
            model_path=args.teacher_model_path,
            model_type=args.teacher_model_type,
            config_name=t_args.config_name,
            tokenizer_name=t_args.tokenizer_name,
            do_lower_case=t_args.do_lower_case,
            output_path=t_args.output_dir,
            device=t_device,
            n_gpus=t_n_gpus,
        )
    else:
        teacher = TransformerTokenClassifier.load_model(
            model_path=args.teacher_model_path, model_type=args.teacher_model_type
        )
        teacher.to(device, n_gpus)

    teacher_labeled_dataset = teacher.convert_to_tensors(train_labeled_ex, args.teacher_max_seq_len)
    teacher_unlabeled_dataset = teacher.convert_to_tensors(
        train_unlabeled_ex, args.teacher_max_seq_len, False
    )

    if args.parallel_batching:
        # match sizes of labeled/unlabeled teacher train data for parallel batching
        teacher_labeled_dataset, teacher_unlabeled_dataset = _match_dataset_sizes(
            teacher_labeled_dataset, teacher_unlabeled_dataset
        )

        train_all_dataset = ParallelDataset(
            train_labeled_dataset,
            teacher_labeled_dataset,
            train_unlabeled_dataset,
            teacher_unlabeled_dataset,
        )

        train_all_sampler = RandomSampler(train_all_dataset)
        # this way must use same batch size for both labeled/unlabeled sets
        train_dl = DataLoader(
            train_all_dataset, sampler=train_all_sampler, batch_size=train_batch_size
        )
    else:
        teacher_dataset = CombinedTensorDataset(
            [teacher_labeled_dataset, teacher_unlabeled_dataset]
        )

        train_dataset = ParallelDataset(train_dataset, teacher_dataset)
        train_sampler = RandomSampler(train_dataset)
        train_dl = DataLoader(train_dataset, sampler=train_sampler, batch_size=train_batch_size)

    # dev/test loaders and the optimizer are optional; bind them up front so
    # classifier.train() never sees an unbound local.
    # NOTE(review): confirm NeuralTagger.train accepts None for these.
    dev_dl = None
    if dev_ex is not None:
        dev_dataset = classifier.convert_to_tensors(
            dev_ex, max_seq_length=args.max_sentence_length, max_word_length=args.max_word_length
        )
        dev_sampler = SequentialSampler(dev_dataset)
        dev_dl = DataLoader(dev_dataset, sampler=dev_sampler, batch_size=args.b)

    test_dl = None
    if test_ex is not None:
        test_dataset = classifier.convert_to_tensors(
            test_ex, max_seq_length=args.max_sentence_length, max_word_length=args.max_word_length
        )
        test_sampler = SequentialSampler(test_dataset)
        test_dl = DataLoader(test_dataset, sampler=test_sampler, batch_size=args.b)

    opt = None
    if args.lr is not None:
        opt = classifier.get_optimizer(lr=args.lr)

    distiller = TeacherStudentDistill(
        teacher, args.kd_temp, args.kd_dist_w, args.kd_student_w, args.kd_loss_fn
    )

    classifier.train(
        train_dl,
        dev_dl,
        test_dl,
        epochs=args.e,
        batch_size=args.b,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_path=args.output_dir,
        optimizer=opt,
        best_result_file=args.best_result_file,
        distiller=distiller,
        word_dropout=args.word_dropout,
    )

    classifier.save_model(args.output_dir)
Beispiel #2
0
def do_kd_training(args):
    """Train a NeuralTagger student distilled from a transformer teacher.

    The student is trained on the labeled train split. A
    ``TransformerTokenClassifier`` teacher, loaded from
    ``args.teacher_model_path``, provides the distillation signal through
    ``TeacherStudentDistill``. The final student model is saved to
    ``args.output_dir``.

    Args:
        args: parsed command-line namespace. Among others it reads
            ``data_dir``, ``model_type``, ``b`` (batch size), ``e`` (epochs),
            ``lr``, ``teacher_model_path``, ``teacher_model_type`` and the
            ``kd_*`` distillation settings.
    """
    prepare_output_path(args.output_dir, args.overwrite_output_dir)
    device, n_gpus = setup_backend(args.no_cuda)
    # Set seed
    set_seed(args.seed, n_gpus)
    # prepare data
    processor = TokenClsProcessor(args.data_dir, tag_col=args.tag_col)
    train_ex = processor.get_train_examples()
    dev_ex = processor.get_dev_examples()
    test_ex = processor.get_test_examples()
    vocab = processor.get_vocabulary()
    vocab_size = len(vocab) + 1
    num_labels = len(processor.get_labels()) + 1
    # create an embedder
    embedder_cls = MODEL_TYPE[args.model_type]
    if args.config_file is not None:
        embedder_model = embedder_cls.from_config(vocab_size, num_labels,
                                                  args.config_file)
    else:
        embedder_model = embedder_cls(vocab_size, num_labels)

    # load external word embeddings if present
    if args.embedding_file is not None:
        emb_dict = load_embedding_file(args.embedding_file)
        emb_mat = get_embedding_matrix(emb_dict, vocab)
        emb_mat = torch.tensor(emb_mat, dtype=torch.float)
        embedder_model.load_embeddings(emb_mat)

    classifier = NeuralTagger(embedder_model,
                              word_vocab=vocab,
                              labels=processor.get_labels(),
                              use_crf=args.use_crf,
                              device=device,
                              n_gpus=n_gpus)

    train_batch_size = args.b * max(1, n_gpus)
    train_dataset = classifier.convert_to_tensors(
        train_ex,
        max_seq_length=args.max_sentence_length,
        max_word_length=args.max_word_length)

    teacher = TransformerTokenClassifier.load_model(
        model_path=args.teacher_model_path, model_type=args.teacher_model_type)
    teacher.to(device, n_gpus)
    teacher_dataset = teacher.convert_to_tensors(train_ex,
                                                 args.max_sentence_length,
                                                 False)

    # pair each student example with the teacher's view of the same example
    train_dataset = ParallelDataset(train_dataset, teacher_dataset)

    train_sampler = RandomSampler(train_dataset)
    train_dl = DataLoader(train_dataset,
                          sampler=train_sampler,
                          batch_size=train_batch_size)

    # dev/test loaders and the optimizer are optional; bind them up front so
    # classifier.train() never sees an unbound local.
    # NOTE(review): confirm NeuralTagger.train accepts None for these.
    dev_dl = None
    if dev_ex is not None:
        dev_dataset = classifier.convert_to_tensors(
            dev_ex,
            max_seq_length=args.max_sentence_length,
            max_word_length=args.max_word_length)
        dev_sampler = SequentialSampler(dev_dataset)
        dev_dl = DataLoader(dev_dataset,
                            sampler=dev_sampler,
                            batch_size=args.b)

    test_dl = None
    if test_ex is not None:
        test_dataset = classifier.convert_to_tensors(
            test_ex,
            max_seq_length=args.max_sentence_length,
            max_word_length=args.max_word_length)
        test_sampler = SequentialSampler(test_dataset)
        test_dl = DataLoader(test_dataset,
                             sampler=test_sampler,
                             batch_size=args.b)

    opt = None
    if args.lr is not None:
        opt = classifier.get_optimizer(lr=args.lr)

    distiller = TeacherStudentDistill(teacher, args.kd_temp, args.kd_dist_w,
                                      args.kd_student_w, args.kd_loss_fn)
    classifier.train(train_dl,
                     dev_dl,
                     test_dl,
                     epochs=args.e,
                     batch_size=args.b,
                     logging_steps=args.logging_steps,
                     save_steps=args.save_steps,
                     save_path=args.output_dir,
                     optimizer=opt,
                     distiller=distiller)
    classifier.save_model(args.output_dir)
Beispiel #3
0
                                            num_classes=chunk_labels)

    # build model with input parameters
    model = SequenceChunker(use_cudnn=args.use_cudnn)
    model.build(word_vocab_size,
                pos_labels,
                chunk_labels,
                char_vocab_size=char_vocab_size,
                max_word_len=args.max_word_length,
                feature_size=args.feature_size,
                classifier=args.classifier)

    # initialize word embedding if external model selected
    if args.embedding_model is not None:
        embedding_model, _ = load_word_embeddings(args.embedding_model)
        embedding_mat = get_embedding_matrix(embedding_model,
                                             dataset.word_vocab)
        model.load_embedding_weights(embedding_mat)

    # train the model
    if args.char_features is True:
        train_features = [words_train, char_train]
        test_features = [words_test, char_test]
    else:
        train_features = words_train
        test_features = words_test
    train_labels = [pos_train, chunk_train]
    test_labels = [pos_test, chunk_test]
    chunk_f1_cb = ConllCallback(test_features,
                                chunk_test,
                                dataset.chunk_vocab.vocab,
                                batch_size=64)