def main():
    """Deploy entry point: answer a single bAbI-style query with a trained DMN.

    Reads the question from ``sys.argv[1]``, loads GloVe word vectors, builds
    a deploy-mode corpus around the query, restores a pre-trained
    ``dmn_basic`` state, and prints the highest-scoring answer word followed
    by the elapsed wall-clock time.
    """
    start = time.time()
    query = sys.argv[1]  # the question to answer; assumes a CLI argument was given
    glove = utils.load_glove()
    # os.path.join is variadic -- one call replaces the original triple nesting.
    corpus_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), 'data', 'corpus', 'babi.txt')
    quest = utils.init_babi_deploy(corpus_path, query)
    dmn = dmn_basic.DMN_basic(
        babi_train_raw=quest,
        babi_test_raw=[],
        word2vec=glove,
        word_vector_size=50,
        dim=40,
        mode='deploy',
        answer_module='feedforward',
        input_mask_mode="sentence",
        memory_hops=5,
        l2=0,
        normalize_attention=False,
        answer_vec='index',
        debug=False)
    dmn.load_state(
        'states/dmn_basic.mh5.n40.bs10.babi1.epoch2.test1.20454.state')
    prediction = dmn.step_deploy()
    prediction = prediction[0][0]
    # Walk scores from highest to lowest; the first index that maps into the
    # answer vocabulary is the predicted answer word.
    for ind in prediction.argsort()[::-1]:
        if ind < dmn.answer_size:
            print(dmn.ivocab[ind])
            break
    print('Time taken:', time.time() - start)
def load_network(code_path, babi_train_raw, babi_test_raw, word2vec, word_vector_size, model_json): if model_json in loaded_models: print "!> model %s is already loaded" % model_json return loaded_models[model_json] print "!> loading model %s..." % model_json model_file = open(model_json) args_dict = json.load(model_file) assert word_vector_size == args_dict['word_vector_size'] args_dict['babi_train_raw'] = babi_train_raw args_dict['babi_test_raw'] = babi_test_raw args_dict['word2vec'] = word2vec # init class if args_dict['network'] == 'dmn_batch': raise Exception("dmn_batch did not implement predict()") sys.path.insert(0, code_path) import dmn_batch dmn = dmn_batch.DMN_batch(**args_dict) sys.path.insert(0, current_dir) elif args_dict['network'] == 'dmn_basic': raise Exception("dmn_batch did not implement predict()") sys.path.insert(0, code_path) import dmn_basic dmn = dmn_basic.DMN_basic(**args_dict) sys.path.insert(0, current_dir) elif args_dict['network'] == 'dmn_smooth': sys.path.insert(0, code_path) import dmn_smooth dmn = dmn_smooth.DMN_smooth(**args_dict) sys.path.insert(0, current_dir) elif args_dict['network'] == 'dmn_qa': raise Exception("dmn_batch did not implement predict()") sys.path.insert(0, code_path) import dmn_qa_draft dmn = dmn_qa_draft.DMN_qa(**args_dict) sys.path.insert(0, current_dir) else: raise Exception("No such network known: " + args_dict['network']) print "!> loading state %s..." % args_dict['load_state'] dmn.load_state(args_dict['load_state']) loaded_models[model_json] = dmn return dmn
def get_dmn(network, batch_size, args_dict): # init class if network == 'dmn_batch': import dmn_batch dmn = dmn_batch.DMN_batch(**args_dict) else: if (batch_size != 1): print "==> no minibatch training, argument batch_size is useless" batch_size = 1 if network == 'dmn_basic': import dmn_basic dmn = dmn_basic.DMN_basic(**args_dict) elif network == 'dmn_smooth': import dmn_smooth dmn = dmn_smooth.DMN_smooth(**args_dict) elif network == 'dmn_spv': import dmn_spv dmn = dmn_spv.DMN(**args_dict) else: raise Exception("No such network known: " + network) return dmn, batch_size
args_dict['babi_train_raw'] = babi_train_raw args_dict['babi_test_raw'] = babi_test_raw args_dict['word2vec'] = word2vec # init class if args.network == 'dmn_batch': import dmn_batch dmn = dmn_batch.DMN_batch(**args_dict) elif args.network == 'dmn_basic': import dmn_basic if (args.batch_size != 1): print "==> no minibatch training, argument batch_size is useless" args.batch_size = 1 dmn = dmn_basic.DMN_basic(**args_dict) elif args.network == 'dmn_smooth': import dmn_smooth if (args.batch_size != 1): print "==> no minibatch training, argument batch_size is useless" args.batch_size = 1 dmn = dmn_smooth.DMN_smooth(**args_dict) elif args.network == 'dmn_qa': import dmn_qa_draft if (args.batch_size != 1): print "==> no minibatch training, argument batch_size is useless" args.batch_size = 1 dmn = dmn_qa_draft.DMN_qa(**args_dict)
args_dict['test_raw'] = test_raw args_dict['word2vec'] = word2vec # init class if args.network == 'dmn_batch': import dmn_batch dmn = dmn_batch.DMN_batch(**args_dict) # The basic module is implemented for document similarity elif args.network == 'dmn_basic': import dmn_basic if (args.batch_size != 1): print "==> no minibatch training, argument batch_size is useless" args.batch_size = 1 dmn = dmn_basic.DMN_basic( **args_dict ) # Initialize the dmn basic with all the arguments available. This also initializes theano functions and parameters. elif args.network == 'dmn_smooth': import dmn_smooth if (args.batch_size != 1): print "==> no minibatch training, argument batch_size is useless" args.batch_size = 1 dmn = dmn_smooth.DMN_smooth(**args_dict) elif args.network == 'dmn_qa': import dmn_qa_draft if (args.batch_size != 1): print "==> no minibatch training, argument batch_size is useless" args.batch_size = 1 dmn = dmn_qa_draft.DMN_qa(**args_dict)
input_mask_mode=args.input_mask_mode, memory_hops=args.memory_hops, batch_size=args.batch_size, l2=args.l2, normalize_attention=args.normalize_attention) elif args.network == 'dmn_basic': import dmn_basic if (args.batch_size != 1): raise Exception("no minibatch training, set batch_size=1") dmn = dmn_basic.DMN_basic(babi_train_raw=babi_train_raw, babi_test_raw=babi_test_raw, word2vec=word2vec, word_vector_size=args.word_vector_size, dim=args.dim, mode=args.mode, answer_module=args.answer_module, input_mask_mode=args.input_mask_mode, memory_hops=args.memory_hops, l2=args.l2, normalize_attention=args.normalize_attention) elif args.network == 'dmn_qa': import dmn_qa if (args.batch_size != 1): raise Exception("no minibatch training, set batch_size=1") dmn = dmn_qa.DMN_qa(babi_train_raw=babi_train_raw, babi_test_raw=babi_test_raw, word2vec=word2vec, word_vector_size=args.word_vector_size, dim=args.dim,