# Shared imports for these examples: OrderedDict, numpy and torch are standard;
# all remaining names (vocab_from_text, read_embeddings, read_docs, read_labels,
# SoftPatternClassifier, Semiring and the semiring constants, Rnn, the baseline
# classifiers, evaluate_accuracy, interpret_documents, etc.) are assumed to be
# imported from the soft_patterns codebase.
from collections import OrderedDict

import numpy as np
import torch


def main():
    n = None
    mlp_hidden_dim = 25
    num_mlp_layers = 2

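    # Build the vocabulary from the evaluation file, then load GloVe vectors;
    # read_embeddings is given dev_vocab, presumably so that only vectors for
    # words that actually occur in the data are kept.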
    validation_data_file = "./soft_patterns/data/test.data"
    dev_vocab = vocab_from_text(validation_data_file)
    print("Dev vocab size:", len(dev_vocab))

    embedding_file = "./soft_patterns/glove.6B.50d.txt"
    vocab, embeddings, word_dim = read_embeddings(embedding_file, dev_vocab)

    seed = 100
    torch.manual_seed(seed)
    np.random.seed(seed)

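    # Pattern specification: "<length>-<count>" pairs joined by "_", so
    # "5-50_4-50_3-50_2-50" requests 50 patterns each of lengths 5, 4, 3 and 2.
    # Parsed into an OrderedDict mapping pattern length -> count, sorted by length.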
    patterns = "5-50_4-50_3-50_2-50"
    pattern_specs = OrderedDict(
        sorted(([int(y) for y in x.split("-")] for x in patterns.split("_")),
               key=lambda t: t[0]))
    num_padding_tokens = max(list(pattern_specs.keys())) - 1

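    # Pad each document with num_padding_tokens (longest pattern minus one) tokens,
    # presumably so patterns can still match near document boundaries.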
    dev_input, _ = read_docs(validation_data_file,
                             vocab,
                             num_padding_tokens=num_padding_tokens)
    validation_label_file = "./soft_patterns/data/test.labels"
    dev_labels = read_labels(validation_label_file)
    dev_data = list(zip(dev_input, dev_labels))

    num_classes = len(set(dev_labels))
    print("num_classes:", num_classes)

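    # A plus/times semiring built by hand: zeros/ones as identities, torch.add as
    # the semiring sum, torch.mul as the semiring product, sigmoid/identity as
    # activations -- presumably equivalent to the ProbSemiring default used in the
    # argparse-based examples below.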
    semiring = Semiring(zeros, ones, torch.add, torch.mul, sigmoid, identity)

    rnn = None

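    # The trailing positional arguments mirror the keyword names used in the
    # argparse-based examples below: bias_scale_param=0.1, gpu=False, rnn=None,
    # pre_computed_patterns=None, no_sl=False, shared_sl=0, no_eps=False,
    # eps_scale=None, self_loop_scale=None.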
    model = SoftPatternClassifier(pattern_specs, mlp_hidden_dim,
                                  num_mlp_layers, num_classes, embeddings,
                                  vocab, semiring, 0.1, False, rnn, None,
                                  False, 0, False, None, None)

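    # map_location keeps every tensor on the CPU, so a checkpoint saved from a GPU
    # run can still be loaded here.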
    input_model = "./soft_patterns/output/model_9.pth"
    state_dict = torch.load(input_model,
                            map_location=lambda storage, loc: storage)

    model.load_state_dict(state_dict)

    test_acc = evaluate_accuracy(model, dev_data, 1, False)

    print("Test accuracy: {:>8,.3f}%".format(100 * test_acc))

    return 0
def main(args):
    print(args)

    if args.seed != -1:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

    pattern_specs = OrderedDict(sorted(([int(y) for y in x.split("-")] for x in args.patterns.split("_")),
                                       key=lambda t: t[0]))

    n = args.num_train_instances
    mlp_hidden_dim = args.mlp_hidden_dim
    num_mlp_layers = args.num_mlp_layers

    dev_vocab = vocab_from_text(args.vd)
    print("Dev vocab size:", len(dev_vocab))

    vocab, embeddings, word_dim = \
        read_embeddings(args.embedding_file, dev_vocab)

    num_padding_tokens = max(list(pattern_specs.keys())) - 1

    dev_input, dev_text = read_docs(args.vd, vocab, num_padding_tokens=num_padding_tokens)
    dev_labels = read_labels(args.vl)
    dev_data = list(zip(dev_input, dev_labels))
    if n is not None:
        dev_data = dev_data[:n]

    num_classes = len(set(dev_labels))
    print("num_classes:", num_classes)

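    # args.maxplus selects the max-plus semiring, args.maxtimes the log-space
    # max-times semiring; otherwise the probability semiring is used.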
    semiring = \
        MaxPlusSemiring if args.maxplus else (
            LogSpaceMaxTimesSemiring if args.maxtimes else ProbSemiring
        )

    if args.use_rnn:
        rnn = Rnn(word_dim,
                  args.hidden_dim,
                  cell_type=LSTM,
                  gpu=args.gpu)
    else:
        rnn = None

    model = SoftPatternClassifier(pattern_specs, mlp_hidden_dim, num_mlp_layers, num_classes, embeddings, vocab,
                                  semiring, args.bias_scale_param, args.gpu, rnn=rnn, pre_computed_patterns=None)

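    # With a GPU, move the model to CUDA and load the checkpoint as saved;
    # without one, remap any CUDA tensors in the checkpoint onto CPU storage.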
    if args.gpu:
        print("Cuda!")
        model.to_cuda(model)
        state_dict = torch.load(args.input_model)
    else:
        state_dict = torch.load(args.input_model, map_location=lambda storage, loc: storage)

    # Load the saved parameters into the model
    model.load_state_dict(state_dict)

    interpret_documents(model, args.batch_size, dev_data, dev_text, args.ofile, args.max_doc_len)

    return 0
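
# Test fixture (a unittest setUp method): builds one instance of each classifier
# variant over the same padded inputs, to be exercised with several batch sizes.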
def setUp(self):
    vocab, embeddings, word_dim = read_embeddings(EMBEDDINGS_FILENAME)
    self.embeddings = embeddings
    max_pattern_length = max(list(PATTERN_SPECS.keys()))
    inputs, _ = read_docs(DATA_FILENAME, vocab, num_padding_tokens=max_pattern_length - 1)
    self.data = [(x, None) for x in inputs]  # don't care about labels
    self.models = [
        SoftPatternClassifier(
            PATTERN_SPECS,
            MLP_HIDDEN_DIM,
            NUM_MLP_LAYERS,
            NUM_CLASSES,
            embeddings,
            vocab,
            SEMIRING,
            GPU
        ),
        PooledCnnClassifier(
            WINDOW_SIZE,
            NUM_CNN_LAYERS,
            CNN_HIDDEN_DIM,
            NUM_MLP_LAYERS,
            MLP_HIDDEN_DIM,
            NUM_CLASSES,
            embeddings,
            gpu=GPU
        ),
        DanClassifier(
            MLP_HIDDEN_DIM,
            NUM_MLP_LAYERS,
            NUM_CLASSES,
            embeddings,
            gpu=GPU
        ),
        AveragingRnnClassifier(
            LSTM_HIDDEN_DIM,
            MLP_HIDDEN_DIM,
            NUM_MLP_LAYERS,
            NUM_CLASSES,
            embeddings,
            gpu=GPU
        )
    ]
    self.batch_sizes = [1, 2, 4, 5, 10, 20]
def main(args):
    print(args)

    n = args.num_train_instances
    mlp_hidden_dim = args.mlp_hidden_dim
    num_mlp_layers = args.num_mlp_layers

    dev_vocab = vocab_from_text(args.vd)
    print("Dev vocab size:", len(dev_vocab))

    vocab, embeddings, word_dim = \
        read_embeddings(args.embedding_file, dev_vocab)

    if args.seed != -1:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

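    # Padding depends on the model's receptive field: one token for the DAN and
    # BiLSTM baselines, window_size - 1 for the CNN, and longest pattern minus one
    # for the soft-pattern model.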
    if args.dan or args.bilstm:
        num_padding_tokens = 1
    elif args.cnn:
        num_padding_tokens = args.window_size - 1
    else:
        pattern_specs = OrderedDict(sorted(([int(y) for y in x.split("-")] for x in args.patterns.split("_")),
                                           key=lambda t: t[0]))
        num_padding_tokens = max(list(pattern_specs.keys())) - 1

    dev_input, dev_text = read_docs(args.vd, vocab, num_padding_tokens=num_padding_tokens)
    dev_labels = read_labels(args.vl)
    dev_data = list(zip(dev_input, dev_labels))
    if n is not None:
        dev_data = dev_data[:n]

    num_classes = len(set(dev_labels))
    print("num_classes:", num_classes)

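    # Model selection: deep averaging network (dan), averaging BiLSTM (bilstm),
    # pooled CNN (cnn), or the default soft-pattern (SoPa) classifier.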
    if args.dan:
        model = DanClassifier(mlp_hidden_dim,
                              num_mlp_layers,
                              num_classes,
                              embeddings,
                              args.gpu)
    elif args.bilstm:
        cell_type = LSTM

        model = AveragingRnnClassifier(args.hidden_dim,
                                       mlp_hidden_dim,
                                       num_mlp_layers,
                                       num_classes,
                                       embeddings,
                                       cell_type=cell_type,
                                       gpu=args.gpu)
    elif args.cnn:
        model = PooledCnnClassifier(args.window_size,
                                    args.num_cnn_layers,
                                    args.cnn_hidden_dim,
                                    num_mlp_layers,
                                    mlp_hidden_dim,
                                    num_classes,
                                    embeddings,
                                    pooling=max_pool_seq,
                                    gpu=args.gpu)
    else:
        semiring = \
            MaxPlusSemiring if args.maxplus else (
                LogSpaceMaxTimesSemiring if args.maxtimes else ProbSemiring
            )

        if args.use_rnn:
            rnn = Rnn(word_dim,
                      args.hidden_dim,
                      cell_type=LSTM,
                      gpu=args.gpu)
        else:
            rnn = None

        model = SoftPatternClassifier(pattern_specs, mlp_hidden_dim, num_mlp_layers, num_classes, embeddings, vocab,
                                      semiring, args.bias_scale_param, args.gpu, rnn, None, args.no_sl, args.shared_sl,
                                      args.no_eps, args.eps_scale, args.self_loop_scale)

    if args.gpu:
        state_dict = torch.load(args.input_model)
    else:
        state_dict = torch.load(args.input_model, map_location=lambda storage, loc: storage)

    model.load_state_dict(state_dict)

    if args.gpu:
        model.to_cuda(model)

    test_acc = evaluate_accuracy(model, dev_data, args.batch_size, args.gpu)

    print("Test accuracy: {:>8,.3f}%".format(100*test_acc))

    return 0