コード例 #1
0
def build_model(X_val=None, k=20):
    """Build the preprocessing + SeqNet training pipeline.

    Args:
        X_val: Optional validation data. It is bound, together with the
            preprocessor, into ``train_split`` via ``functools.partial``.
        k: Top-k cut-off used by the ``inference`` non-linearity and by the
            recall / MRR batch scorers. The scorer names recorded in the
            history are ``recall@{k}`` and ``mrr@{k}``.

    Returns:
        A pipeline (``make_pipeline``) of the preprocessor followed by the
        ``SeqNet`` model.
    """
    preprocessor = build_preprocessor(min_freq=1)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = SeqNet(
        module=Model,
        # Placeholder; presumably replaced at fit time by DynamicVariablesSetter
        # once the vocabulary is known -- TODO confirm.
        module__vocab_size=100,  # Dummy dimension
        module__emb_dim=100,
        # module__hidden_dim=100,
        optimizer=torch.optim.Adam,
        optimizer__lr=0.002,
        criterion=BPRLoss,
        max_epochs=10,
        batch_size=512,
        iterator_train=NegativeSamplingIterator,
        # iterator_train=SequenceIterator,
        iterator_train__neg_samples=100,
        # NOTE(review): exponent 0 presumably means uniform negative sampling
        # (no frequency weighting) -- confirm against NegativeSamplingIterator.
        iterator_train__ns_exponent=0.,
        iterator_train__shuffle=True,
        # Sort training examples by sequence length so batches group
        # similarly-sized sequences.
        iterator_train__sort=True,
        iterator_train__sort_key=lambda x: len(x.text),
        iterator_valid=SequenceIterator,
        # iterator_valid__neg_samples=100,
        # iterator_valid__ns_exponent=0.,
        iterator_valid__shuffle=False,
        iterator_valid__sort=False,
        train_split=partial(train_split, prep=preprocessor, X_val=X_val),
        device=device,
        predict_nonlinearity=partial(inference, k=k, device=device),
        callbacks=[
            skorch.callbacks.Initializer("*_fc*", fn=xavier_init),
            skorch.callbacks.GradientNormClipping(1.),  # Original paper
            DynamicVariablesSetter(),
            skorch.callbacks.EpochScoring(
                partial(ppx, entry="valid_loss"),
                name="perplexity",
                use_caching=False,
                # Perplexity is derived from the validation loss: smaller is
                # better, so skorch must track the minimum as "best".
                lower_is_better=True,
            ),
            # Metric names follow the actual cut-off instead of a fixed "@20".
            skorch.callbacks.BatchScoring(partial(scoring, k=k, func=recall),
                                          name=f"recall@{k}",
                                          on_train=False,
                                          lower_is_better=False,
                                          use_caching=True),
            skorch.callbacks.BatchScoring(partial(scoring, k=k, func=rr),
                                          name=f"mrr@{k}",
                                          on_train=False,
                                          lower_is_better=False,
                                          use_caching=True),
            skorch.callbacks.ProgressBar('count'),
        ],
    )

    full = make_pipeline(
        preprocessor,
        model,
    )
    return full
コード例 #2
0
def test_data(flat_data, flat_oov, batch_size=32):
    """Batches built from known and out-of-vocabulary data share one layout.

    The preprocessor is fitted on ``flat_data``. The first batch from that
    data must hold ``batch_size`` rows and 4 columns. The first batch from
    ``flat_oov`` must have the same shape and contain only index 0.
    """
    prep = build_preprocessor().fit(flat_data)

    known = prep.transform(flat_data)
    known_batch = next(iter(BucketIterator(known, batch_size=batch_size)))
    known_text = known_batch.text
    assert known_text.shape[0] == batch_size
    assert known_text.shape[1] == 4

    unseen = prep.transform(flat_oov)

    unseen_batch = next(iter(BucketIterator(unseen, batch_size=batch_size)))
    unseen_text = unseen_batch.text
    assert unseen_text.shape == known_text.shape
    # NOTE(review): index 0 is presumably the unknown-token id produced by
    # build_preprocessor -- confirm.
    assert (unseen_text == 0).all()