args.gpu,
                                  rnn=rnn,
                                  pre_computed_patterns=None)

    if args.gpu:
        print("Cuda!")
        model.to_cuda(model)
        state_dict = torch.load(args.input_model)
    else:
        state_dict = torch.load(args.input_model,
                                map_location=lambda storage, loc: storage)

    # Loading model
    model.load_state_dict(state_dict)

    interpret_documents(model, args.batch_size, dev_data, dev_text, args.ofile,
                        args.max_doc_len)

    return 0


if __name__ == '__main__':
    # Script entry point: combine the shared option groups returned by
    # soft_pattern_arg_parser() and general_arg_parser() with a required
    # --ofile flag, then exit with whatever main() returns.
    shared_parents = [soft_pattern_arg_parser(), general_arg_parser()]
    parser = argparse.ArgumentParser(description=__doc__,
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                                     parents=shared_parents)
    parser.add_argument("--ofile", required=True, help="Output file")

    parsed_args = parser.parse_args()
    sys.exit(main(parsed_args))
示例#2
0
    else:
        rnn = None

    model = SoftPatternClassifier(pattern_specs, mlp_hidden_dim, num_mlp_layers, num_classes, embeddings, vocab,
                                  semiring, args.bias_scale_param, args.gpu, rnn=rnn, pre_computed_patterns=None,
                                  no_sl=args.no_sl, shared_sl=args.shared_sl, no_eps=args.no_eps,
                                  eps_scale=args.eps_scale, self_loop_scale=args.self_loop_scale)

    if args.gpu:
        state_dict = torch.load(args.input_model)
    else:
        state_dict = torch.load(args.input_model, map_location=lambda storage, loc: storage)

    model.load_state_dict(state_dict)

    if args.gpu:
        model.to_cuda(model)

    visualize_patterns(model, dev_data, dev_text, args.k_best, args.max_doc_len, num_padding_tokens)

    return 0


if __name__ == '__main__':
    # Script entry point: combine the shared option groups returned by
    # soft_pattern_arg_parser() and general_arg_parser() with an integer
    # -k/--k_best option (default 5), then exit with whatever main() returns.
    shared_parents = [soft_pattern_arg_parser(), general_arg_parser()]
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parents=shared_parents,
    )
    parser.add_argument("-k", "--k_best",
                        type=int,
                        default=5,
                        help="Number of nearest neighbor phrases")

    parsed_args = parser.parse_args()
    sys.exit(main(parsed_args))