示例#1
0
def setup(vocab_path, model_path, contents_path, label_path, model_config,
          dataset=None, num_samples=30000):
    """Load a trained CharCNNLSTM model and a labelled corpus for attacking.

    Args:
        vocab_path: path passed to ``Vocab.load``.
        model_path: path to a saved ``state_dict`` for ``predictor.model``.
        contents_path: path passed to ``read_corpus``.
        label_path: path passed to ``read_labels``.
        model_config: mapping of keyword arguments forwarded to
            ``CharCNNLSTMModel``.
        dataset: split name. When it equals ``"train"``, only
            ``num_samples`` example indices are drawn; otherwise every index
            is used. If omitted, a module-level ``dataset`` global is used
            when one exists. The original body read that free name, which
            raised NameError when no such global was defined.
        num_samples: number of indices drawn for the ``"train"`` split.

    Returns:
        Tuple ``(vocab, predictor, device, model, model_embedding,
        test_contents, test_labels, sampled_idx)``.
    """
    vocab = Vocab.load(vocab_path)
    predictor = CharCNNLSTMModel(vocab, **model_config)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    predictor.model.load_state_dict(
        torch.load(model_path, map_location=torch.device(device))
    )
    model = predictor.model
    model.to(device)
    # Switch layers such as dropout / batch-norm (if any) to inference behaviour.
    model.eval()
    model_embedding = model.model_embeddings
    test_contents = read_corpus(contents_path)
    test_labels = read_labels(label_path)

    if dataset is None:
        # Backward compatibility with callers that relied on a module-level
        # ``dataset`` global; None (-> use every index) when it is absent.
        dataset = globals().get("dataset")

    if dataset == "train":
        # Only sample ``num_samples`` examples to do adversarial training.
        # NOTE(review): random.choices samples *with* replacement, so indices
        # can repeat; random.sample would give distinct ones — confirm intent.
        sampled_idx = random.choices(range(len(test_contents)), k=num_samples)
    else:
        sampled_idx = range(len(test_contents))

    return (
        vocab,
        predictor,
        device,
        model,
        model_embedding,
        test_contents,
        test_labels,
        sampled_idx,
    )
示例#2
0
def adv_train(
    vocab_path,
    train_contents_path,
    train_label_path,
    adv_train_contents_path,
    adv_train_label_path,
    model_path,
    model_output_path,
    **model_config,
):
    """Fine-tune a saved model on clean plus adversarial training examples.

    Args:
        vocab_path: path passed to ``Vocab.load``.
        train_contents_path / train_label_path: clean training corpus and labels.
        adv_train_contents_path / adv_train_label_path: adversarial examples
            and their labels, appended after the clean data.
        model_path: path to the ``state_dict`` used to initialise the model.
        model_output_path: forwarded to ``model.fit`` as its output location.
        **model_config: keyword arguments forwarded to ``CharCNNLSTMModel``.
    """
    vocab = Vocab.load(vocab_path)
    train_contents = read_corpus(train_contents_path)
    train_labels = read_labels(train_label_path)

    adv_train_contents = read_corpus(adv_train_contents_path)
    adv_train_labels = read_labels(adv_train_label_path)

    # Clean and adversarial data are concatenated into one training set;
    # contents and labels are concatenated in the same order so they stay aligned.
    contents = train_contents + adv_train_contents
    labels = train_labels + adv_train_labels

    model = CharCNNLSTMModel(vocab, **model_config)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # map_location (as in the other loaders) lets a GPU-saved checkpoint load
    # on a CPU-only machine instead of failing inside torch.load.
    model.model.load_state_dict(
        torch.load(model_path, map_location=torch.device(device)))
    model.fit(contents, labels, model_output_path)
示例#3
0
def demo(model_path,
         vocab_path,
         test_contents_path,
         test_label_path,
         num_examples=10,
         **model_config):
    """Print predicted vs. ground-truth categories for random test examples.

    Args:
        model_path: path to a saved ``state_dict`` for ``predictor.model``.
        vocab_path: path passed to ``Vocab.load``.
        test_contents_path: corpus file, read both via ``read_corpus`` (model
            input) and ``read_raw_corpus`` (for display).
        test_label_path: path passed to ``read_labels``.
        num_examples: number of examples drawn (with replacement, via
            ``random.choices``).
        **model_config: forwarded to ``CharCNNLSTMModel``; its
            ``max_word_length`` entry is also passed to ``Dataset``.
    """
    vocab = Vocab.load(vocab_path)

    predictor = CharCNNLSTMModel(vocab, **model_config)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    predictor.model.load_state_dict(
        torch.load(model_path, map_location=torch.device(device)))
    # The model is called directly below, so put it in evaluation mode first;
    # otherwise dropout/batch-norm layers (if any) make predictions random.
    predictor.model.eval()

    test_contents = read_corpus(test_contents_path)
    test_labels = read_labels(test_label_path)
    test_raw_contents = read_raw_corpus(test_contents_path)

    # Sample (tokens, label, raw text) triples together so they stay aligned.
    test_data = list(zip(test_contents, test_labels, test_raw_contents))
    demo_data = random.choices(test_data, k=num_examples)

    demo_contents = [content for content, _, _ in demo_data]
    demo_labels = [label for _, label, _ in demo_data]
    demo_raw_contents = [raw for _, _, raw in demo_data]

    demo_dataset = Dataset(demo_contents, demo_labels, vocab,
                           model_config.get("max_word_length"), "cpu")

    # Presumably fetches a single batch containing every sampled example.
    # NOTE(review): the zip below assumes Dataset keeps the input order; if it
    # sorts by length, raw contents will not line up with labels — confirm.
    demo_contents, demo_labels, demo_contents_lengths = demo_dataset[len(
        demo_dataset)]

    with torch.no_grad():
        pred = predictor.model(demo_contents, demo_contents_lengths)
        predicted_labels = torch.argmax(pred, dim=1)

    for content, gt, pr in zip(demo_raw_contents, demo_labels,
                               predicted_labels):
        print("Content:", content)
        print("Predicted category:", index_mapping[int(pr)])
        print("Ground truth category:", index_mapping[int(gt)])
        print("\n")
def infer(model_path, vocab_path, test_contents_path, test_label_path,
          **model_config):
    """Evaluate a saved model on a labelled test set and print its accuracy.

    Args:
        model_path: path to a saved ``state_dict`` for ``predictor.model``.
        vocab_path: path passed to ``Vocab.load``.
        test_contents_path: path passed to ``read_corpus``.
        test_label_path: path passed to ``read_labels``.
        **model_config: forwarded to ``CharCNNLSTMModel``; its
            ``max_word_length`` entry is also passed to ``Dataset``.

    Raises:
        ValueError: if no test labels were read. Previously this case failed
            with ZeroDivisionError.
    """
    vocab = Vocab.load(vocab_path)
    test_contents = read_corpus(test_contents_path)
    test_labels = read_labels(test_label_path)
    test_dataset = Dataset(test_contents, test_labels, vocab,
                           model_config.get("max_word_length"), "cpu")
    predictor = CharCNNLSTMModel(vocab, **model_config)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    predictor.model.load_state_dict(
        torch.load(model_path, map_location=torch.device(device)))
    # Evaluation mode so dropout/batch-norm layers (if any) are deterministic.
    predictor.model.eval()

    batch_size = 20
    accuracies = []
    # One iteration per batch_size-sized chunk of the labels. The start offset
    # is unused: each call asks test_dataset for a batch of batch_size, and
    # which examples it returns is decided by Dataset.__getitem__.
    # NOTE(review): an earlier comment claimed the last <= 19 examples are
    # dropped, and averaging per-batch accuracies weights a short final batch
    # equally — confirm against Dataset.
    for _batch_start in range(0, len(test_labels), batch_size):
        batch_contents, batch_labels, batch_content_lengths = test_dataset[
            batch_size]
        _, accuracy = predictor.predict(batch_contents, batch_labels,
                                        batch_content_lengths)
        accuracies.append(accuracy)

    if not accuracies:
        raise ValueError("no test examples read from %r" % test_label_path)
    print("test accuracy:", sum(accuracies) / len(accuracies))
示例#5
0
def train(vocab_path, train_contents_path, train_label_path, **model_config):
    """Train a new CharCNNLSTMModel from scratch on a labelled corpus.

    Args:
        vocab_path: path passed to ``Vocab.load``.
        train_contents_path: path passed to ``read_corpus``.
        train_label_path: path passed to ``read_labels``.
        **model_config: keyword arguments forwarded to ``CharCNNLSTMModel``.
    """
    vocab = Vocab.load(vocab_path)
    # Tuple elements are evaluated left to right: corpus first, then labels.
    contents, labels = (
        read_corpus(train_contents_path),
        read_labels(train_label_path),
    )
    CharCNNLSTMModel(vocab, **model_config).fit(contents, labels)