示例#1
0
def main():
    """Answer one bAbI-style query with a pre-trained DMN_basic model.

    The query text is read from the first command-line argument. The story
    corpus is ``data/corpus/babi.txt`` located next to this script. The model
    is built in 'deploy' mode, its saved state is loaded, and the answer word
    is printed, followed by the elapsed wall-clock time.
    """
    start = time.time()
    if len(sys.argv) < 2:
        # Exit with a usage message rather than an opaque IndexError.
        sys.exit('usage: %s QUERY' % sys.argv[0])
    query = sys.argv[1]
    glove = utils.load_glove()

    # Anchor the corpus to this file's directory so the script works
    # regardless of the current working directory.
    script_dir = os.path.dirname(os.path.realpath(__file__))
    corpus_path = os.path.join(script_dir, 'data', 'corpus', 'babi.txt')
    quest = utils.init_babi_deploy(corpus_path, query)

    # NOTE(review): these hyper-parameters presumably have to match the saved
    # state below (its filename mentions mh5 / n40) - confirm before changing.
    dmn = dmn_basic.DMN_basic(babi_train_raw=quest,
                              babi_test_raw=[],
                              word2vec=glove,
                              word_vector_size=50,
                              dim=40,
                              mode='deploy',
                              answer_module='feedforward',
                              input_mask_mode="sentence",
                              memory_hops=5,
                              l2=0,
                              normalize_attention=False,
                              answer_vec='index',
                              debug=False)

    # NOTE(review): unlike corpus_path, this path is relative to the current
    # working directory, not to the script directory - confirm intent.
    dmn.load_state(
        'states/dmn_basic.mh5.n40.bs10.babi1.epoch2.test1.20454.state')

    prediction = dmn.step_deploy()

    # Assumes prediction[0][0] is an array of per-index scores - TODO confirm.
    scores = prediction[0][0]
    # Visit indices from highest to lowest score and print the first one that
    # lies inside the answer vocabulary.
    for ind in scores.argsort()[::-1]:
        if ind < dmn.answer_size:
            print(dmn.ivocab[ind])
            break
    print('Time taken:', time.time() - start)
示例#2
0
def load_network(code_path, babi_train_raw, babi_test_raw, word2vec,
                 word_vector_size, model_json):
    """Build the DMN described by a JSON config and load its saved state.

    Models are memoised in the module-level ``loaded_models`` dict keyed by
    ``model_json``, so a given config is only loaded once.

    :param code_path: directory pushed onto ``sys.path`` so the model module
        can be imported from it.
    :param babi_train_raw: training data forwarded to the model constructor.
    :param babi_test_raw: test data forwarded to the model constructor.
    :param word2vec: word-vector lookup forwarded to the model constructor.
    :param word_vector_size: expected vector size; must equal the config's
        ``word_vector_size`` entry.
    :param model_json: path to a JSON file holding the constructor kwargs,
        including ``network`` and ``load_state`` entries.
    :returns: the constructed model with its state loaded.
    :raises AssertionError: if ``word_vector_size`` differs from the config.
    :raises Exception: if the network does not implement ``predict()`` or is
        unknown.
    """
    if model_json in loaded_models:
        print("!> model %s is already loaded" % model_json)
        return loaded_models[model_json]

    print("!> loading model %s..." % model_json)

    # Close the config file as soon as it is parsed.
    with open(model_json) as model_file:
        args_dict = json.load(model_file)

    assert word_vector_size == args_dict['word_vector_size'], (
        "word_vector_size mismatch: %s != %s"
        % (word_vector_size, args_dict['word_vector_size']))

    args_dict['babi_train_raw'] = babi_train_raw
    args_dict['babi_test_raw'] = babi_test_raw
    args_dict['word2vec'] = word2vec

    network = args_dict['network']
    # These networks are known but cannot be served here because they lack
    # predict(). (Their constructors would be dmn_batch.DMN_batch,
    # dmn_basic.DMN_basic and dmn_qa_draft.DMN_qa.)
    if network in ('dmn_batch', 'dmn_basic', 'dmn_qa'):
        raise Exception("%s did not implement predict()" % network)
    if network != 'dmn_smooth':
        raise Exception("No such network known: " + network)

    # init class: make the model code importable, then put current_dir back
    # at the front of the search path.
    sys.path.insert(0, code_path)
    import dmn_smooth
    dmn = dmn_smooth.DMN_smooth(**args_dict)
    sys.path.insert(0, current_dir)

    print("!> loading state %s..." % args_dict['load_state'])

    dmn.load_state(args_dict['load_state'])

    loaded_models[model_json] = dmn

    return dmn
示例#3
0
def get_dmn(network, batch_size, args_dict):
    """Instantiate the DMN variant named by ``network``.

    :param network: one of 'dmn_batch', 'dmn_basic', 'dmn_smooth', 'dmn_spv'.
    :param batch_size: requested minibatch size. Only 'dmn_batch' keeps it;
        for the other networks a warning is printed and 1 is used instead.
    :param args_dict: keyword arguments forwarded to the network constructor.
    :returns: tuple ``(dmn, batch_size)`` with the effective batch size.
    :raises Exception: if ``network`` is not recognised.
    """
    # Model modules are imported lazily so only the selected one is loaded.
    if network == 'dmn_batch':
        import dmn_batch
        return dmn_batch.DMN_batch(**args_dict), batch_size

    # Reject unknown names before emitting the batch_size warning, which would
    # otherwise be misleading.
    if network not in ('dmn_basic', 'dmn_smooth', 'dmn_spv'):
        raise Exception("No such network known: " + network)

    if batch_size != 1:
        print("==> no minibatch training, argument batch_size is useless")
        batch_size = 1

    if network == 'dmn_basic':
        import dmn_basic
        dmn = dmn_basic.DMN_basic(**args_dict)
    elif network == 'dmn_smooth':
        import dmn_smooth
        dmn = dmn_smooth.DMN_smooth(**args_dict)
    else:  # 'dmn_spv'
        import dmn_spv
        dmn = dmn_spv.DMN(**args_dict)
    return dmn, batch_size
示例#4
0
# Hand the loaded data to the model constructor along with the parsed options.
args_dict['babi_train_raw'] = babi_train_raw
args_dict['babi_test_raw'] = babi_test_raw
args_dict['word2vec'] = word2vec


# init class
# Model modules are imported lazily so only the selected network is loaded.
if args.network == 'dmn_batch':
    import dmn_batch
    dmn = dmn_batch.DMN_batch(**args_dict)

else:
    if args.network == 'dmn_basic':
        import dmn_basic
        network_class = dmn_basic.DMN_basic
    elif args.network == 'dmn_smooth':
        import dmn_smooth
        network_class = dmn_smooth.DMN_smooth
    elif args.network == 'dmn_qa':
        import dmn_qa_draft
        network_class = dmn_qa_draft.DMN_qa
    else:
        # Fail fast instead of leaving `dmn` undefined for later code.
        raise Exception("No such network known: " + args.network)

    # The non-batch networks do no minibatch training, so force batch_size=1.
    if args.batch_size != 1:
        print("==> no minibatch training, argument batch_size is useless")
        args.batch_size = 1
    dmn = network_class(**args_dict)
示例#5
0
# Hand the loaded data to the model constructor along with the parsed options.
args_dict['test_raw'] = test_raw
args_dict['word2vec'] = word2vec

# init class
# Model modules are imported lazily so only the selected network is loaded.
if args.network == 'dmn_batch':
    import dmn_batch
    dmn = dmn_batch.DMN_batch(**args_dict)

else:
    # The basic module is implemented for document similarity
    if args.network == 'dmn_basic':
        import dmn_basic
        network_class = dmn_basic.DMN_basic
    elif args.network == 'dmn_smooth':
        import dmn_smooth
        network_class = dmn_smooth.DMN_smooth
    elif args.network == 'dmn_qa':
        import dmn_qa_draft
        network_class = dmn_qa_draft.DMN_qa
    else:
        # Fail fast instead of leaving `dmn` undefined for later code.
        raise Exception("No such network known: " + args.network)

    # The non-batch networks do no minibatch training, so force batch_size=1.
    if args.batch_size != 1:
        print("==> no minibatch training, argument batch_size is useless")
        args.batch_size = 1
    # Initialize the network with all the available arguments; for DMN_basic
    # this also initializes the theano functions and parameters.
    dmn = network_class(**args_dict)
示例#6
0
                              input_mask_mode=args.input_mask_mode,
                              memory_hops=args.memory_hops,
                              batch_size=args.batch_size,
                              l2=args.l2,
                              normalize_attention=args.normalize_attention)
elif args.network == 'dmn_basic':
    import dmn_basic
    if (args.batch_size != 1):
        raise Exception("no minibatch training, set batch_size=1")

    dmn = dmn_basic.DMN_basic(babi_train_raw=babi_train_raw,
                              babi_test_raw=babi_test_raw,
                              word2vec=word2vec,
                              word_vector_size=args.word_vector_size,
                              dim=args.dim,
                              mode=args.mode,
                              answer_module=args.answer_module,
                              input_mask_mode=args.input_mask_mode,
                              memory_hops=args.memory_hops,
                              l2=args.l2,
                              normalize_attention=args.normalize_attention)

elif args.network == 'dmn_qa':
    import dmn_qa
    if (args.batch_size != 1):
        raise Exception("no minibatch training, set batch_size=1")
    dmn = dmn_qa.DMN_qa(babi_train_raw=babi_train_raw,
                        babi_test_raw=babi_test_raw,
                        word2vec=word2vec,
                        word_vector_size=args.word_vector_size,
                        dim=args.dim,