def test_mark_teach_minbatch(test_conf, test_corp):
    """Check step-wise access to a MarkTeacherMinBatch built from two rows.

    The batch is built from the sentences ``i am`` and ``i`` together with
    their mark-annotated teacher sequences. The assertions pin what
    ``data_batch_at`` and ``teach_batch_at`` return at steps 0 and 1.
    """
    sentences = [["i", "am"], ["i"]]
    annotated = [
        ["<sj>", "i", "</sj>", "<v>", "am", "</v>"],
        ["<sj>", "i", "</sj>"],
    ]
    data_rows = [test_corp.tokens_to_ids(tokens) for tokens in sentences]
    teach_rows = [test_corp.tokens_to_ids(tokens) for tokens in annotated]

    batch = MarkTeacherMinBatch(test_conf, test_corp, data_rows, teach_rows)

    def as_tokens(ids):
        # Materialize the ids as a list before mapping them back to tokens.
        return test_corp.ids_to_tokens(list(ids))

    # Data side: step 0 holds the first token of each row; at step 1 the
    # shorter row yields "<pad>".
    assert as_tokens(batch.data_batch_at(0)) == ["i", "i"]
    assert as_tokens(batch.data_batch_at(1)) == ["am", "<pad>"]

    # Teacher side: each entry is compared against the vector produced by
    # mark.convert_types_to_vec for the expected mark type.
    first_step = batch.teach_batch_at(0)
    assert (first_step[0] == mark.convert_types_to_vec(["<sj>"])).all()
    first_step = batch.teach_batch_at(0)
    assert (first_step[1] == mark.convert_types_to_vec(["<sj>"])).all()

    second_step = batch.teach_batch_at(1)
    assert (second_step[0] == mark.convert_types_to_vec(["<v>"])).all()
    # NOTE(review): -1 presumably marks the position past the end of the
    # shorter row - confirm against MarkTeacherMinBatch.
    second_step = batch.teach_batch_at(1)
    assert second_step[1] == -1
def test_v2_train(test_corp):
    """Train the "v2" model on one mini-batch and check the resulting loss.

    The model, its config and its optimizer come from ``build_model``.
    Mini-batches of size 2 are drawn from ``conf.corpus`` with
    ``MarkTeacherMinBatch.randomized_from_corpus``, and only the first
    training batch is passed to ``dummy_data_train``.
    """
    max_loss = 5.0
    batch_size = 2

    conf, encdec, opt = build_model("v2", test_corp)

    # randomized_from_corpus returns four values; only the training batches
    # are used here, but the full unpack is kept so the arity is still checked.
    _train_idxs, _test_idxs, train_batches, _test_batches = (
        MarkTeacherMinBatch.randomized_from_corpus(conf, conf.corpus, batch_size)
    )

    loss = dummy_data_train(conf, encdec, opt, train_batches[0])
    assert loss < max_loss