Пример #1
0
    def __init__(self, args):
        """Assemble the embedding layer and encoder chosen by ``args``.

        Args:
            args: options namespace. ``args.embedding`` and ``args.encoder``
                are keys into the ``str2embedding`` / ``str2encoder``
                registries. ``args.tokenizer.vocab`` supplies the vocabulary
                whose length is passed to the embedding constructor.
        """
        super().__init__()
        # Look up and build the embedding first, then the encoder. This
        # matches the original registration order of the two sub-modules.
        embedding_factory = str2embedding[args.embedding]
        vocab_size = len(args.tokenizer.vocab)
        self.embedding = embedding_factory(args, vocab_size)

        encoder_factory = str2encoder[args.encoder]
        self.encoder = encoder_factory(args)

    def forward(self, src, seg):
        """Embed ``src`` and pass the result through the encoder.

        Args:
            src: input fed to ``self.embedding`` (presumably token ids;
                confirm against the embedding implementation).
            seg: segment input, given to both the embedding and the encoder.

        Returns:
            Whatever ``self.encoder`` returns for the embedded input.
        """
        hidden = self.embedding(src, seg)
        return self.encoder(hidden, seg)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    model_opts(parser)

    parser.add_argument("--load_model_path", default=None, type=str,
                        help="Path of the input model.")
    parser.add_argument("--vocab_path", default=None, type=str,
                        help="Path of the vocabulary file.")
    parser.add_argument("--cand_vocab_path", default=None, type=str,
                        help="Path of the candidate vocabulary file.")
    parser.add_argument("--test_path", type=str, required=True,
                        help="Path of the target word an its context.")
    parser.add_argument("--config_path", default="models/bert/base_config.json", type=str,
                        help="Path of the config file.")

    parser.add_argument("--tokenizer", choices=["bert", "char", "space"], default="bert",
                        help="Specify the tokenizer."
                             "Original Google BERT uses bert tokenizer on Chinese corpus."
Пример #2
0
def _fold_model_path(load_model_path, fold_id):
    """Return the checkpoint path of one cross-validation fold.

    The fold tag goes in front of the file extension. For example,
    ``models/classifier.bin`` with ``fold_id=2`` becomes
    ``models/classifier-fold_2.bin``. Only the last path component's
    extension is considered, so dots in directory names are preserved. A
    path without an extension simply gets ``-fold_<id>`` appended.
    """
    import os

    root, ext = os.path.splitext(load_model_path)
    return "{}-fold_{}{}".format(root, fold_id, ext)


def main():
    """Average the class probabilities predicted by K fold models.

    The function:
    - parses command-line options and tokenizes the test set once;
    - for each fold, loads the checkpoint derived from ``--load_model_path``
      and collects softmax probabilities for every test instance;
    - saves the per-instance probabilities averaged over folds with
      ``np.save`` to ``--test_features_path`` (described by its help text
      as test features for stacking).

    Invalid or missing options are reported through ``parser.error`` (exit
    status 2) before any model is loaded.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--load_model_path",
                        default=None,
                        type=str,
                        help="Path of the classifier model.")
    parser.add_argument("--vocab_path",
                        default=None,
                        type=str,
                        help="Path of the vocabulary file.")
    parser.add_argument("--spm_model_path",
                        default=None,
                        type=str,
                        help="Path of the sentence piece model.")
    parser.add_argument("--test_path", type=str, help="Path of the testset.")
    parser.add_argument("--test_features_path",
                        default=None,
                        type=str,
                        help="Path of the test features for stacking.")
    parser.add_argument("--config_path",
                        default="models/bert/base_config.json",
                        type=str,
                        help="Path of the config file.")

    # Model options.
    model_opts(parser)
    parser.add_argument("--pooling",
                        choices=["mean", "max", "first", "last"],
                        default="first",
                        help="Pooling type.")

    # Inference options.
    parser.add_argument("--batch_size",
                        type=int,
                        default=64,
                        help="Batch size.")
    parser.add_argument("--seq_length",
                        type=int,
                        default=128,
                        help="Sequence length.")
    parser.add_argument("--labels_num",
                        type=int,
                        required=True,
                        help="Number of prediction labels.")

    # Tokenizer options. The adjacent literals are concatenated, so each
    # sentence ends with a space to keep the rendered help readable.
    parser.add_argument(
        "--tokenizer",
        choices=["bert", "char", "space"],
        default="bert",
        help="Specify the tokenizer. "
        "Original Google BERT uses bert tokenizer on Chinese corpus. "
        "Char tokenizer segments sentences into characters. "
        "Space tokenizer segments sentences into words according to space.")

    # Output options.
    parser.add_argument("--output_logits",
                        action="store_true",
                        help="Write logits to output file.")
    parser.add_argument("--output_prob",
                        action="store_true",
                        help="Write probabilities to output file.")

    # Cross validation options.
    parser.add_argument("--folds_num",
                        type=int,
                        default=5,
                        help="The number of folds for cross validation.")

    args = parser.parse_args()

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    # Fail fast. Otherwise a missing path surfaces as ``None.split``, as a
    # failed dataset read, or in ``np.save(None, ...)`` only after every fold
    # has been run. A non-positive fold count would save NaN features.
    # Validation happens after load_hyperparam so config-supplied values count.
    if args.load_model_path is None:
        parser.error("--load_model_path is required.")
    if args.test_path is None:
        parser.error("--test_path is required.")
    if args.test_features_path is None:
        parser.error("--test_features_path is required.")
    if args.folds_num < 1:
        parser.error("--folds_num must be at least 1.")

    # Build tokenizer.
    args.tokenizer = str2tokenizer[args.tokenizer](args)

    # NOTE(review): presumably read by read_dataset/Classifier; soft-target
    # handling is switched off for inference.
    args.soft_targets, args.soft_alpha = False, False

    dataset = read_dataset(args, args.test_path)

    src = torch.LongTensor([sample[0] for sample in dataset])
    seg = torch.LongTensor([sample[1] for sample in dataset])

    batch_size = args.batch_size
    instances_num = src.size()[0]

    print("The number of prediction instances: ", instances_num)

    # Device selection does not depend on the fold, so do it once.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        print("{} GPUs are available. Let's use them.".format(gpu_count))

    test_features = [[] for _ in range(args.folds_num)]
    for fold_id in range(args.folds_num):
        # Build classification model and load this fold's parameters.
        model = Classifier(args)
        model = load_model(model,
                           _fold_model_path(args.load_model_path, fold_id))

        model = model.to(device)
        # For simplicity, we use DataParallel wrapper to use multiple GPUs.
        if gpu_count > 1:
            model = torch.nn.DataParallel(model)

        model.eval()
        for src_batch, seg_batch in batch_loader(batch_size, src, seg):
            src_batch = src_batch.to(device)
            seg_batch = seg_batch.to(device)
            with torch.no_grad():
                _, logits = model(src_batch, None, seg_batch)
                prob = torch.softmax(logits, dim=1)
            test_features[fold_id].extend(prob.cpu().numpy().tolist())

    # Axis 0 indexes folds; averaging it leaves one probability row per
    # test instance.
    test_features = np.mean(np.array(test_features), axis=0)
    np.save(args.test_features_path, test_features)