Example #1
def main():
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument(
        "bert_model", type=str, help="Variant of pre-trained model.")
    parser.add_argument(
        "layer", type=int,
        help="Layer from of layer from which the representation is taken.")
    parser.add_argument(
        "data_lng1", type=str,
        help="Sentences with language for training.")
    parser.add_argument(
        "data_lng2", type=str,
        help="Sentences with language for training.")
    parser.add_argument(
        "save_model", type=str, help="Path to the saved model.")
    parser.add_argument(
        "--mean-pool", default=False, action="store_true",
        help="If true, use mean-pooling instead of [CLS] vector.")
    parser.add_argument("--num-threads", type=int, default=4)
    args = parser.parse_args()

    torch.set_num_threads(args.num_threads)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer, model = load_bert(args.bert_model, device)[:2]

    print(f"Loading representation for {args.data_lng1}", file=sys.stderr)
    lng1_repr = repr_for_text_file(
        args.data_lng1, model, tokenizer, args.layer, args.mean_pool)
    print(f"Loading representation for {args.data_lng2}", file=sys.stderr)
    lng2_repr = repr_for_text_file(
        args.data_lng2, model, tokenizer, args.layer, args.mean_pool)
    print("BERT representations loaded.", file=sys.stderr)

    print("Fitting the projection.", file=sys.stderr)
    model = LinearRegression()
    model.fit(lng1_repr, lng2_repr)
    print("Done, saving model.", file=sys.stderr)

    joblib.dump(model, args.save_model)
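
For context, a minimal self-contained sketch of the fit/save/load round trip this script performs. Toy random arrays stand in for the BERT representations; the projection itself really is a plain scikit-learn LinearRegression saved with joblib, as above:

import joblib
import numpy as np
from sklearn.linear_model import LinearRegression

# Toy stand-ins for the two languages' sentence representations (hypothetical shapes).
lng1_repr = np.random.rand(16, 768)
lng2_repr = np.random.rand(16, 768)

projection = LinearRegression()
projection.fit(lng1_repr, lng2_repr)            # fit the lng1 -> lng2 mapping
joblib.dump(projection, "projection.joblib")    # what args.save_model receives above

restored = joblib.load("projection.joblib")
print(restored.predict(lng1_repr).shape)        # (16, 768): lng1 vectors in lng2 space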
Example #2
def main():
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument("bert_model",
                        type=str,
                        help="Variant of pre-trained model.")
    parser.add_argument(
        "layer",
        type=int,
        help="Layer from of layer from which the representation is taken.")
    parser.add_argument("src", type=str, help="Sentences in source language.")
    parser.add_argument("mt", type=str, help="Machine-translated sentences.")
    parser.add_argument(
        "--mean-pool",
        default=False,
        action="store_true",
        help="If true, use mean-pooling instead of [CLS] vecotr.")
    parser.add_argument("--center-lng",
                        default=False,
                        action="store_true",
                        help="If true, center representations first.")
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--src-proj",
                        default=None,
                        type=str,
                        help="Sklearn projection of the source language.")
    parser.add_argument("--mt-proj",
                        default=None,
                        type=str,
                        help="Sklearn projection of the target language.")
    parser.add_argument("--num-threads", type=int, default=4)
    args = parser.parse_args()

    if args.center_lng and (args.src_proj is not None
                            or args.mt_proj is not None):
        print(
            "You can either project or center "
            "the representations, not both.",
            file=sys.stderr)
        exit(1)

    torch.set_num_threads(args.num_threads)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer, model = load_bert(args.bert_model, device)[:2]

    src_repr = repr_for_txt_file(args.src,
                                 tokenizer,
                                 model,
                                 device,
                                 args.layer,
                                 center_lng=args.center_lng,
                                 mean_pool=args.mean_pool)
    mt_repr = repr_for_txt_file(args.mt,
                                tokenizer,
                                model,
                                device,
                                args.layer,
                                center_lng=args.center_lng,
                                mean_pool=args.mean_pool)

    if args.src_proj is not None:
        src_repr = apply_sklearn_proj(src_repr, args.src_proj)
    if args.mt_proj is not None:
        mt_repr = apply_sklearn_proj(mt_repr, args.mt_proj)

    src_norm = (src_repr * src_repr).sum(1).sqrt()
    mt_norm = (mt_repr * mt_repr).sum(1).sqrt()

    cosine = (src_repr * mt_repr).sum(1) / src_norm / mt_norm

    for num in cosine.cpu().detach().numpy():
        print(num)
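
apply_sklearn_proj is a project helper that is not shown here; the following is only a rough sketch of what it presumably does, assuming the projection is a fitted scikit-learn estimator stored with joblib and the representations are a 2-D torch tensor:

import joblib
import torch

def apply_sklearn_proj(representations, proj_path):
    """Hypothetical sketch: apply a saved sklearn projection to torch vectors."""
    proj = joblib.load(proj_path)
    projected = proj.predict(representations.cpu().numpy())
    return torch.from_numpy(projected).to(representations.dtype)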
Example #3
def main():
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument("bert_model",
                        type=str,
                        help="Variant of pre-trained model.")
    parser.add_argument(
        "layer",
        type=int,
        help="Layer from of layer from which the representation is taken.")
    parser.add_argument("languages",
                        type=str,
                        help="File with a list of languages.")
    parser.add_argument("train_data_txt", type=str, help="Training sentences.")
    parser.add_argument("train_data_lng",
                        type=str,
                        help="Language codes for training sentences.")
    parser.add_argument("val_data_txt", type=str, help="Validation sentences.")
    parser.add_argument("val_data_lng",
                        type=str,
                        help="Language codes for validation sentences.")
    parser.add_argument("test_data_txt", type=str, help="Test sentences.")
    parser.add_argument("test_data_lng",
                        type=str,
                        help="Language codes for test sentences.")
    parser.add_argument("--hidden",
                        default=None,
                        type=int,
                        help="Size of the hidden classification layer.")
    parser.add_argument("--num-threads", type=int, default=4)
    parser.add_argument("--save-model",
                        type=str,
                        help="Path where to save the best model.")
    parser.add_argument("--save-centroids",
                        type=str,
                        help="Path to save language centroids.")
    parser.add_argument("--test-output",
                        type=str,
                        default=None,
                        help="Output for example classification.")
    parser.add_argument("--skip-tokenization",
                        default=False,
                        action="store_true",
                        help="Only split on spaces, skip wordpieces.")
    parser.add_argument(
        "--mean-pool",
        default=False,
        action="store_true",
        help="If true, use mean-pooling instead of [CLS] vecotr.")
    parser.add_argument(
        "--center-lng",
        default=False,
        action="store_true",
        help="Center languages to be around coordinate origin.")
    args = parser.parse_args()

    with open(args.languages) as f_lang:
        languages = [line.strip() for line in f_lang]
    lng2idx = {lng: i for i, lng in enumerate(languages)}

    torch.set_num_threads(args.num_threads)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer, model, model_dim, _ = load_bert(args.bert_model, device)

    if args.layer < -1:
        print("Layer index cannot be lower than -1.")
        exit(1)

    num_layers = None
    if hasattr(model.config, "num_hidden_layers"):
        num_layers = model.config.num_hidden_layers
    if hasattr(model.config, "n_layers"):
        num_layers = model.config.n_layers
    if num_layers is not None and args.layer >= num_layers:
        print(f"Model only has {num_layers} layers, {args.layer} is too much.")
        exit(1)

    train_batches = load_and_batch_data(args.train_data_txt,
                                        args.train_data_lng,
                                        tokenizer,
                                        lng2idx,
                                        batch_size=32,
                                        epochs=1000)
    print("Train data iterator initialized.")

    centroids = None
    if args.center_lng:
        print("Estimating language centroids.")
        with torch.no_grad():
            texts, labels = [], []
            for _, (txt, lab) in zip(range(100), train_batches):
                texts.append(txt)
                labels.append(lab)
            centroids = get_centroids(device,
                                      model,
                                      texts,
                                      languages,
                                      labels,
                                      args.layer,
                                      tokenizer,
                                      mean_pool=args.mean_pool)
        centroids = centroids.to(device)

        if args.save_centroids:
            torch.save(centroids.cpu(), args.save_centroids)

    print("Loading validation data.")
    val_batches_raw = list(
        load_and_batch_data(args.val_data_txt,
                            args.val_data_lng,
                            tokenizer,
                            lng2idx,
                            batch_size=32,
                            epochs=1))
    print("Validation data loaded in memory, pre-computing BERT.")
    val_batches = []
    with torch.no_grad():
        for tokens, lng in val_batches_raw:
            bert_features = get_repr_from_layer(model, tokens.to(device),
                                                args.layer,
                                                tokenizer.pad_token_id,
                                                args.mean_pool).cpu()
            val_batches.append((bert_features, lng))

    print("Loading test data.")
    test_batches_raw = list(
        load_and_batch_data(args.test_data_txt,
                            args.test_data_lng,
                            tokenizer,
                            lng2idx,
                            batch_size=32,
                            epochs=1))
    print("Test data loaded in memory, pre-computing BERT.")
    test_batches = []
    with torch.no_grad():
        for tokens, lng in test_batches_raw:
            bert_features = get_repr_from_layer(model, tokens.to(device),
                                                args.layer,
                                                tokenizer.pad_token_id,
                                                args.mean_pool).cpu()
            test_batches.append((bert_features, lng))
    print()

    test_accuracies = []
    all_test_outputs = []
    trained_models = []

    for exp_no in range(5):
        print(f"Starting experiment no {exp_no + 1}")
        print(f"------------------------------------")
        if args.hidden is None:
            classifier = nn.Linear(model_dim, len(languages))
        else:
            classifier = nn.Sequential(nn.Linear(model_dim, args.hidden),
                                       nn.ReLU(), nn.Dropout(0.1),
                                       nn.Linear(args.hidden, len(languages)))
        classifier = classifier.to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(classifier.parameters(), lr=1e-3)

        def evaluate(data_batches):
            classifier.eval()
            with torch.no_grad():
                running_val_loss = 0.
                running_val_acc = 0.
                val_count = 0
                outputs = []

                for bert_features, lng in data_batches:
                    bert_features, lng = (bert_features.to(device),
                                          lng.to(device))
                    batch_size = bert_features.size(0)

                    if centroids is not None:
                        bert_features = bert_features - centroids[lng]
                    prediction = classifier(bert_features)
                    batch_loss = criterion(prediction, lng)

                    predicted_lng = prediction.max(-1)[1]
                    batch_accuracy = torch.sum((predicted_lng == lng).float())

                    running_val_loss += (batch_size *
                                         batch_loss.cpu().numpy().tolist())
                    running_val_acc += batch_accuracy.cpu().numpy().tolist()
                    val_count += batch_size

                    outputs.extend(predicted_lng.cpu().numpy().tolist())

                val_loss = running_val_loss / val_count
                accuracy = running_val_acc / val_count

            return val_loss, accuracy, outputs

        best_accuracy = 0.0
        no_improvement = 0
        learning_rate_decreased = 0
        learning_rate = 1e-3

        for i, (sentences, lng) in enumerate(train_batches):
            try:
                classifier.train()
                optimizer.zero_grad()
                sentences, lng = sentences.to(device), lng.to(device)
                bert_features = get_repr_from_layer(model,
                                                    sentences,
                                                    args.layer,
                                                    tokenizer.pad_token_id,
                                                    mean_pool=args.mean_pool)

                if centroids is not None:
                    with torch.no_grad():
                        bert_features = bert_features - centroids[lng]

                prediction = classifier(bert_features)

                loss = criterion(prediction, lng)

                loss.backward()
                optimizer.step()

                if i % 10 == 9:
                    print(f"loss: {loss.cpu().detach().numpy().tolist():5g}")

                if i % 50 == 49:
                    print()
                    val_loss, accuracy, _ = evaluate(val_batches)

                    print("Validation: "
                          f"loss: {val_loss:5g}, "
                          f"accuracy: {accuracy:5g}")

                    if accuracy > best_accuracy:
                        best_accuracy = accuracy
                        no_improvement = 0
                    else:
                        no_improvement += 1

                    if no_improvement >= 5:
                        if learning_rate_decreased >= 5:
                            print(
                                "Learning rate decreased five times, ending.")
                            break

                        learning_rate /= 2
                        print(f"Decreasing learning rate to {learning_rate}.")
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = learning_rate
                        learning_rate_decreased += 1
                        no_improvement = 0

                    print()
            except KeyboardInterrupt:
                break

        model.eval()
        test_loss, test_accuracy, test_outputs = evaluate(test_batches)
        print()
        print("Testing:")
        print(f"test loss: {test_loss:5g}, "
              f"test accuracy: {test_accuracy:5g}")

        test_accuracies.append(test_accuracy)

        this_test_outputs = []
        for lng_prediction in test_outputs:
            this_test_outputs.append(languages[lng_prediction])
        all_test_outputs.append(this_test_outputs)
        trained_models.append(classifier.cpu())

    print()
    print("===============================================")
    print("All experiments done.")
    print("===============================================")
    print(f"Mean test accuracy {np.mean(test_accuracies)}")
    print(f"Mean test stdev    {np.std(test_accuracies)}")

    best_exp_id = np.argmax(test_accuracies)

    print(f"Best test accuracy {max(test_accuracies)}")

    if args.save_model:
        torch.save(trained_models[best_exp_id], args.save_model)

    if args.test_output is not None:
        with open(args.test_output, 'w') as f_out:
            for prediction in all_test_outputs[best_exp_id]:
                print(prediction, file=f_out)
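
get_repr_from_layer is likewise project-specific; the sketch below only illustrates the idea it implements (taking either the [CLS] vector or a padding-aware mean pool from one hidden layer), assuming a Hugging Face-style model that returns hidden states:

import torch

def repr_from_layer_sketch(model, token_ids, layer, pad_token_id, mean_pool=False):
    """Hypothetical sketch of extracting a sentence vector from one layer."""
    hidden = model(token_ids, output_hidden_states=True).hidden_states[layer]
    if not mean_pool:
        return hidden[:, 0]                              # [CLS] position
    mask = (token_ids != pad_token_id).unsqueeze(-1).float()
    return (hidden * mask).sum(1) / mask.sum(1)          # mean over non-padding tokens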
Example #4
def main():
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument("bert_model",
                        type=str,
                        help="Variant of pre-trained model.")
    parser.add_argument(
        "layer",
        type=int,
        help="Layer from of layer from which the representation is taken.")
    parser.add_argument("train_src",
                        type=str,
                        help="Sentences in source language for training.")
    parser.add_argument("train_mt",
                        type=str,
                        help="Machine-translated sentences for training.")
    parser.add_argument("train_hter",
                        type=str,
                        help="Machine-translated sentences for training.")
    parser.add_argument("test_src",
                        type=str,
                        help="Sentences in source language for testing.")
    parser.add_argument("test_mt",
                        type=str,
                        help="Machine-translated sentences for testing.")
    parser.add_argument(
        "--exclude-src",
        default=False,
        action="store_true",
        help="Exclude source representatiion from the classifier.")
    parser.add_argument(
        "--exclude-mt",
        default=False,
        action="store_true",
        help="Exclude target representatiion from the classifier.")
    parser.add_argument(
        "--mean-pool",
        default=False,
        action="store_true",
        help="If true, use mean-pooling instead of [CLS] vecotr.")
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--num-threads", type=int, default=4)
    args = parser.parse_args()

    if args.exclude_src and args.exclude_mt:
        print("You cannot exclude both source and MT!", file=sys.stderr)
        exit(1)

    torch.set_num_threads(args.num_threads)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer, model = load_bert(args.bert_model, device)[:2]

    if not args.exclude_src:
        train_src_repr = repr_for_txt_file(args.train_src,
                                           tokenizer,
                                           model,
                                           device,
                                           args.layer,
                                           center_lng=False,
                                           mean_pool=args.mean_pool).numpy()
    if not args.exclude_mt:
        train_mt_repr = repr_for_txt_file(args.train_mt,
                                          tokenizer,
                                          model,
                                          device,
                                          args.layer,
                                          center_lng=False,
                                          mean_pool=args.mean_pool).numpy()
    if args.exclude_src:
        train_inputs = train_mt_repr
    elif args.exclude_mt:
        train_inputs = train_src_repr
    else:
        train_inputs = np.concatenate((train_src_repr, train_mt_repr), axis=1)

    with open(args.train_hter) as f_tgt:
        train_targets = np.array([float(line.rstrip()) for line in f_tgt])

    print("Training regression ... ", file=sys.stderr, end="", flush=True)
    regressor = MLPRegressor((256,), early_stopping=True)
    regressor.fit(train_inputs, train_targets)
    print("Done.", file=sys.stderr)

    if not args.exclude_src:
        test_src_repr = repr_for_txt_file(args.test_src,
                                          tokenizer,
                                          model,
                                          device,
                                          args.layer,
                                          center_lng=False,
                                          mean_pool=args.mean_pool).numpy()
    if not args.exclude_mt:
        test_mt_repr = repr_for_txt_file(args.test_mt,
                                         tokenizer,
                                         model,
                                         device,
                                         args.layer,
                                         center_lng=False,
                                         mean_pool=args.mean_pool).numpy()

    if args.exclude_src:
        test_inputs = test_mt_repr
    elif args.exclude_mt:
        test_inputs = test_src_repr
    else:
        test_inputs = np.concatenate((test_src_repr, test_mt_repr), axis=1)

    predictions = regressor.predict(test_inputs)

    for num in predictions:
        print(num)
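
The script prints the predicted HTER scores only; if gold scores for the test set are available, a quick sanity check is the Pearson correlation between predictions and gold values. A tiny sketch with placeholder numbers:

import numpy as np
from scipy.stats import pearsonr

# Placeholder values; in practice use regressor.predict(...) and the gold HTER file.
predictions = np.array([0.12, 0.40, 0.05, 0.33])
gold = np.array([0.10, 0.45, 0.00, 0.30])

correlation, p_value = pearsonr(predictions, gold)
print(f"Pearson r: {correlation:.3f} (p = {p_value:.3g})")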
Example #5
def main():
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument("bert_model",
                        type=str,
                        help="Variant of pre-trained model.")
    parser.add_argument(
        "layer",
        type=int,
        help="Layer from of layer from which the representation is taken.")
    parser.add_argument("data",
                        type=str,
                        nargs="+",
                        help="Sentences with language for training.")
    parser.add_argument("--distance",
                        choices=["cosine", "euklid"],
                        default="cosine")
    parser.add_argument("--skip-tokenization",
                        default=False,
                        action="store_true",
                        help="Only split on spaces, skip wordpieces.")
    parser.add_argument(
        "--mean-pool",
        default=False,
        action="store_true",
        help="If true, use mean-pooling instead of [CLS] vector.")
    parser.add_argument(
        "--center-lng",
        default=False,
        action="store_true",
        help="Center languages to be around coordinate origin.")
    parser.add_argument(
        "--projections",
        default=None,
        nargs="+",
        help="List of sklearn projections for particular languages.")
    parser.add_argument("--em-iterations",
                        default=None,
                        type=int,
                        help="Iterations of projection self-learning.")
    parser.add_argument("--num-threads", type=int, default=4)
    args = parser.parse_args()

    if args.center_lng and args.projections is not None:
        print("You cannot do projections and centering at once.",
              file=sys.stderr)
        exit(1)
    if (args.projections is not None
            and len(args.projections) != len(args.data)):
        print("You must have a projection for each data file.",
              file=sys.stderr)
        exit(1)
    if (args.projections is not None and args.em_iterations is not None):
        print("You either have pre-trained projections or self-train them.",
              file=sys.stderr)
        exit(1)

    projections = None
    if args.projections is not None:
        projections = []
        for proj_str in args.projections:
            if proj_str == "None":
                projections.append(None)
            else:
                projections.append(joblib.load(proj_str))

    distance_fn = None
    if args.distance == "cosine":
        distance_fn = cosine_distances
    elif args.distance == "euklid":
        distance_fn = euklid_distances
    else:
        raise ValueError("Unknown distance function.")

    torch.set_num_threads(args.num_threads)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer, model = load_bert(args.bert_model, device)[:2]

    representations = []

    with torch.no_grad():
        for i, text_file in enumerate(args.data):
            print(f"Processing {text_file}")
            vectors = [
                get_repr_from_layer(model,
                                    sentence_tensor,
                                    args.layer,
                                    tokenizer.pad_token_id,
                                    mean_pool=args.mean_pool)
                for sentence_tensor in batch_generator(
                    text_data_generator(text_file, tokenizer), 64, tokenizer)
            ]

            lng_repr = torch.cat(vectors, dim=0)
            if args.center_lng:
                lng_repr = lng_repr - lng_repr.mean(0, keepdim=True)

            if projections is not None and projections[i] is not None:
                proj = projections[i]
                lng_repr = torch.from_numpy(proj.predict(lng_repr.numpy()))

            representations.append(lng_repr)

        mutual_projections = None
        if args.em_iterations is not None:
            print(f"EM training ...")
            new_mutual_projections = {}
            for i in range(args.em_iterations):
                print(f" ... iteration {i + 1}")
                for lng1, repr1 in zip(args.data, representations):
                    for lng2, repr2 in zip(args.data, representations):
                        if mutual_projections is not None:
                            proj = mutual_projections[(lng1, lng2)]
                            repr1 = torch.from_numpy(
                                proj.predict(repr1.numpy()))

                        distances = distance_fn(repr1, repr2)
                        retrieved = repr2[distances.min(dim=1)[1]]
                        proj = LinearRegression()
                        proj.fit(repr1.numpy(), retrieved.numpy())
                        new_mutual_projections[(lng1, lng2)] = proj
                mutual_projections = new_mutual_projections

        data_len = representations[0].shape[0]
        assert all(r.shape[0] == data_len for r in representations)
        print()
        for k in [1, 5, 10, 20, 50, 100]:
            print(f"Recall at {k}, random baseline {k / data_len:.5f}")
            print("--", end="\t")
            for lng in args.data:
                print(lng[-6:-4], end="\t")
            print()

            recalls_to_avg = []

            for lng1, repr1 in zip(args.data, representations):
                print(lng1[-6:-4], end="\t")
                for lng2, repr2 in zip(args.data, representations):

                    if mutual_projections is not None:
                        proj = mutual_projections[(lng1, lng2)]
                        repr1 = torch.from_numpy(proj.predict(repr1.numpy()))

                    distances = distance_fn(repr1, repr2)

                    recall = recall_at_k_from_distances(distances, k)
                    print(f"{recall.numpy():.5f}", end="\t")

                    if lng1 != lng2:
                        recalls_to_avg.append(recall.numpy())
                print()
            print(f"On average: {np.mean(recalls_to_avg):.5f}")
            print()
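
recall_at_k_from_distances is not shown; a plausible sketch, assuming the distance matrix pairs sentence i in one language with sentence i in the other (i.e. the gold match is the diagonal):

import torch

def recall_at_k_sketch(distances, k):
    """Hypothetical sketch: share of queries whose gold match (same index)
    is among the k nearest neighbours in the distance matrix."""
    n = distances.shape[0]
    nearest = distances.topk(k, dim=1, largest=False)[1]   # indices of k smallest
    hits = (nearest == torch.arange(n).unsqueeze(1)).any(dim=1)
    return hits.float().mean()

# Toy 3x3 example where every diagonal entry is the smallest in its row.
dists = torch.tensor([[0.1, 0.9, 0.8],
                      [0.7, 0.2, 0.9],
                      [0.6, 0.5, 0.4]])
print(recall_at_k_sketch(dists, 1))   # tensor(1.)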
Example #6
def main():
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument("bert_model",
                        type=str,
                        help="Variant of pre-trained model.")
    parser.add_argument(
        "layer",
        type=int,
        help="Layer from of layer from which the representation is taken.")
    parser.add_argument("language_list",
                        type=str,
                        help="TSV file with available languages.")
    parser.add_argument("data", type=str, help="Directory with txt files.")
    parser.add_argument("target",
                        type=str,
                        help="npz file with saved centroids.")
    parser.add_argument("--num-threads", type=int, default=4)
    parser.add_argument(
        "--mean-pool",
        default=False,
        action="store_true",
        help="If true, use mean-pooling instead of [CLS] vecotr.")
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--batch-count", type=int, default=200)
    args = parser.parse_args()

    torch.set_num_threads(args.num_threads)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer, model = load_bert(args.bert_model, device)[:2]

    language_names = []
    centroids = []

    with open(args.language_list) as lng_f:
        for line in lng_f:
            name, code = line.strip().split("\t")
            data_file = os.path.join(args.data, f"{code}.txt")

            data = text_data_generator(data_file, tokenizer)
            batches = batch_generator(data, args.batch_size, tokenizer)
            print(f"Data iterator initialized: {data_file}")

            with torch.no_grad():
                representations = []
                for _, txt in zip(range(args.batch_count), batches):
                    batch_repr = get_repr_from_layer(
                        model,
                        txt.to(device),
                        args.layer,
                        tokenizer.pad_token_id,
                        mean_pool=args.mean_pool).cpu().numpy()
                    if not np.any(np.isnan(batch_repr)):
                        representations.append(batch_repr)

                if representations:
                    language_names.append(name)
                    centroid = np.concatenate(representations, axis=0).mean(0)
                    centroids.append(centroid)

    print("Centroids computed.")

    np.savez(args.target, languages=language_names, centroids=centroids)
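
A possible use of the saved npz file (the keys languages and centroids come from the np.savez call above) is nearest-centroid language identification. A short sketch, assuming the file exists at centroids.npz and a placeholder sentence vector from the same layer and pooling:

import numpy as np

data = np.load("centroids.npz")                         # the path passed as `target` above
languages, centroids = data["languages"], data["centroids"]

sentence_vector = np.random.rand(centroids.shape[1])    # placeholder representation
distances = np.linalg.norm(centroids - sentence_vector, axis=1)
print(languages[distances.argmin()])                    # closest language centroid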
Example #7
    def __init__(self, latent_size):
        super().__init__()
        self.text_encoder = load_bert(latent_size)
Example #8
def main():
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument("bert_model",
                        type=str,
                        help="Variant of pre-trained model.")
    parser.add_argument(
        "layer",
        type=int,
        help="Layer from of layer from which the representation is taken.")
    parser.add_argument("src", type=str, help="Sentences in source language.")
    parser.add_argument("mt", type=str, help="Machine-translated sentences.")
    parser.add_argument("--center-lng",
                        default=False,
                        action="store_true",
                        help="If true, center representations first.")
    parser.add_argument("--src-proj",
                        default=None,
                        type=str,
                        help="Sklearn projection of the source language.")
    parser.add_argument("--mt-proj",
                        default=None,
                        type=str,
                        help="Sklearn projection of the target language.")
    parser.add_argument("--num-threads", type=int, default=4)
    args = parser.parse_args()

    if args.center_lng and (args.src_proj is not None
                            or args.mt_proj is not None):
        print(
            "You can either project or center "
            "the representations, not both.",
            file=sys.stderr)
        exit(1)

    torch.set_num_threads(args.num_threads)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer, model = load_bert(args.bert_model, device)[:2]

    print(f"Loading src: {args.src}", file=sys.stderr)
    with open(args.src) as f_src:
        with torch.no_grad():
            src_repr = [
                vectors_for_sentence(tokenizer, model, line.rstrip(),
                                     args.layer)[0].numpy() for line in f_src
            ]

    print(f"Loading mt: {args.mt}", file=sys.stderr)
    with open(args.mt) as f_mt:
        with torch.no_grad():
            mt_repr = [
                vectors_for_sentence(tokenizer, model, line.rstrip(),
                                     args.layer)[0].numpy() for line in f_mt
            ]

    if args.center_lng:
        src_center = np.mean(np.concatenate(src_repr), 0)
        mt_center = np.mean(np.concatenate(mt_repr), 0)

        src_repr = [r - src_center for r in src_repr]
        mt_repr = [r - mt_center for r in mt_repr]

    if args.src_proj is not None:
        src_repr = apply_sklearn_proj(src_repr, args.src_proj)
    if args.mt_proj is not None:
        mt_repr = apply_sklearn_proj(mt_repr, args.mt_proj)

    for src, mt in zip(src_repr, mt_repr):
        similarity = (np.dot(src, mt.T) /
                      np.expand_dims(np.linalg.norm(src, axis=1), 1) /
                      np.expand_dims(np.linalg.norm(mt, axis=1), 0))

        recall = similarity.max(1).sum() / similarity.shape[0]
        precision = similarity.max(0).sum() / similarity.shape[1]

        if recall + precision > 0:
            f_score = 2 * recall * precision / (recall + precision)
        else:
            f_score = 0

        print(f_score)
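
To make the F-score computation above concrete, here is the same arithmetic on a toy 2x3 token-similarity matrix (two source tokens, three MT tokens):

import numpy as np

similarity = np.array([[0.9, 0.2, 0.1],
                       [0.3, 0.8, 0.4]])

recall = similarity.max(1).sum() / similarity.shape[0]      # best match per source token
precision = similarity.max(0).sum() / similarity.shape[1]   # best match per MT token
f_score = 2 * recall * precision / (recall + precision)
print(recall, precision, f_score)                           # 0.85 0.7 ~0.768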
Example #9
def main():
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument("bert_model",
                        type=str,
                        help="Variant of pre-trained model.")
    parser.add_argument(
        "layer",
        type=int,
        help="Layer from of layer from which the representation is taken.")
    parser.add_argument("src", type=str, help="Sentences in source language.")
    parser.add_argument("tgt",
                        type=str,
                        help="Sentences in the target language.")
    parser.add_argument("--center-lng",
                        default=False,
                        action="store_true",
                        help="If true, center representations first.")
    parser.add_argument("--src-proj",
                        default=None,
                        type=str,
                        help="Sklearn projection of the source language.")
    parser.add_argument("--tgt-proj",
                        default=None,
                        type=str,
                        help="Sklearn projection of the target language.")
    parser.add_argument(
        "--reordering-penalty",
        default=1e-5,
        type=float,
        help="Penalty for long-distance alignment added to cost.")
    parser.add_argument("--verbose",
                        default=False,
                        action="store_true",
                        help="If true, print the actual alignment.")
    parser.add_argument("--iterations",
                        type=int,
                        default=0,
                        help="Number of EM iterations.")
    parser.add_argument("--train-data",
                        type=str,
                        nargs=2,
                        default=None,
                        help="Training data for EM training.")
    parser.add_argument("--save-projection",
                        type=str,
                        default=None,
                        help="Location to save the word projection.")
    parser.add_argument("--num-threads", type=int, default=4)
    args = parser.parse_args()

    if args.center_lng and (args.src_proj is not None
                            or args.tgt_proj is not None):
        print(
            "You can either project or center "
            "the representations, not both.",
            file=sys.stderr)
        exit(1)

    torch.set_num_threads(args.num_threads)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer, model = load_bert(args.bert_model, device)[:2]

    proj = None
    if args.iterations > 0:
        if args.train_data is None:
            print("You need to specify train data for EM training.",
                  file=sys.stderr)
            exit(1)

        print("Loading training data.", file=sys.stderr)
        train_src_repr, train_tgt_repr = load_data(
            args.train_data[0], args.train_data[1], model, tokenizer,
            args.layer, args.center_lng, args.src_proj, args.tgt_proj)

        for iteration in range(args.iterations):
            print(f"Iteration {iteration + 1}", file=sys.stderr)
            proj = em_step(train_src_repr, train_tgt_repr,
                           args.reordering_penalty, proj)
            print("Done.", file=sys.stderr)

        if args.save_projection:
            joblib.dump(proj, args.save_projection)

    print("Loading test data.", file=sys.stderr)
    src_repr, tgt_repr = load_data(args.src, args.tgt, model, tokenizer,
                                   args.layer, args.center_lng, args.src_proj,
                                   args.tgt_proj)

    for (src_mat, src_tok), (tgt_mat, tgt_tok) in zip(src_repr, tgt_repr):
        alignment = align(src_mat, tgt_mat, args.reordering_penalty, proj)

        if args.verbose:
            for i, token in enumerate(src_tok):
                aligned_indices = [tix for six, tix in alignment if six == i]
                aligned_formatted = [
                    f"{tgt_tok[j]} ({j})" for j in aligned_indices
                ]
                print(f"{i:2d}: {token} -- {', '.join(aligned_formatted)}")
            print()
        else:
            print(" ".join(f"{src_id}-{tgt_id}"
                           for src_id, tgt_id in alignment))
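
The non-verbose output is one line per sentence pair in the usual Pharaoh alignment format (src_index-tgt_index pairs separated by spaces, as produced by the join above). A minimal sketch of reading such a line back into index pairs:

line = "0-0 1-2 2-1"   # example output line
alignment = [tuple(int(i) for i in pair.split("-")) for pair in line.split()]
print(alignment)       # [(0, 0), (1, 2), (2, 1)]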