def load_network(code_path, babi_train_raw, babi_test_raw, word2vec,
                 word_vector_size, model_json):
    """Load a trained DMN model described by a JSON config, with caching.

    Results are memoized in the module-level ``loaded_models`` cache keyed by
    ``model_json``; a repeated call with the same path returns the cached
    object without touching the file system.

    :param code_path: directory pushed onto ``sys.path`` before importing the
        network's module.
    :param babi_train_raw: passed to the constructor as ``babi_train_raw``.
    :param babi_test_raw: passed to the constructor as ``babi_test_raw``.
    :param word2vec: passed to the constructor as ``word2vec``.
    :param word_vector_size: must equal the config's ``word_vector_size``.
    :param model_json: path to a JSON file whose object supplies the
        constructor keyword arguments, including ``network`` and
        ``load_state``.
    :returns: the constructed network after ``load_state`` has been called.
    :raises AssertionError: if ``word_vector_size`` disagrees with the config.
    :raises Exception: if the configured network cannot be used for
        prediction or is unknown.
    """
    if model_json in loaded_models:
        print("!> model %s is already loaded" % model_json)
        return loaded_models[model_json]

    print("!> loading model %s..." % model_json)
    # `with` guarantees the config file is closed once it has been parsed.
    with open(model_json) as model_file:
        args_dict = json.load(model_file)

    # Explicit check rather than `assert` so it still runs under `python -O`;
    # AssertionError is kept so existing callers see the same exception type.
    if word_vector_size != args_dict['word_vector_size']:
        raise AssertionError(
            "word_vector_size mismatch: got %r, model %s expects %r"
            % (word_vector_size, model_json, args_dict['word_vector_size']))

    args_dict['babi_train_raw'] = babi_train_raw
    args_dict['babi_test_raw'] = babi_test_raw
    args_dict['word2vec'] = word2vec

    network = args_dict['network']
    # These network types do not implement predict(), so they are rejected
    # before anything is imported (dmn_qa's module would be dmn_qa_draft).
    if network in ('dmn_batch', 'dmn_basic', 'dmn_qa'):
        raise Exception("%s did not implement predict()" % network)
    if network != 'dmn_smooth':
        raise Exception("No such network known: " + network)

    # Import the network from code_path, then push current_dir back in front
    # even when the import or the constructor fails.
    # NOTE(review): if dmn_smooth is already in sys.modules, Python reuses that
    # module instead of loading the copy under code_path.
    sys.path.insert(0, code_path)
    try:
        import dmn_smooth
        dmn = dmn_smooth.DMN_smooth(**args_dict)
    finally:
        sys.path.insert(0, current_dir)

    print("!> loading state %s..." % args_dict['load_state'])
    dmn.load_state(args_dict['load_state'])
    loaded_models[model_json] = dmn
    return dmn
def get_dmn(network, batch_size, args_dict):
    """Instantiate the DMN variant named by ``network``.

    Only ``dmn_batch`` supports minibatch training. For every other name the
    requested ``batch_size`` is clamped to 1, with a notice when it differed.

    :param network: one of 'dmn_batch', 'dmn_basic', 'dmn_smooth', 'dmn_spv'.
    :param batch_size: requested minibatch size.
    :param args_dict: keyword arguments forwarded unchanged to the constructor.
    :returns: a ``(model, effective_batch_size)`` tuple.
    :raises Exception: if ``network`` is not one of the names above.
    """
    # Batched variant: the requested batch size is honoured as-is.
    if network == 'dmn_batch':
        import dmn_batch
        return dmn_batch.DMN_batch(**args_dict), batch_size

    # Every other variant trains one example at a time. The notice is printed
    # before the name is validated, so an unknown name still triggers it.
    if batch_size != 1:
        print("==> no minibatch training, argument batch_size is useless")
        batch_size = 1

    if network == 'dmn_basic':
        import dmn_basic
        model = dmn_basic.DMN_basic(**args_dict)
    elif network == 'dmn_smooth':
        import dmn_smooth
        model = dmn_smooth.DMN_smooth(**args_dict)
    elif network == 'dmn_spv':
        import dmn_spv
        model = dmn_spv.DMN(**args_dict)
    else:
        raise Exception("No such network known: " + network)
    return model, batch_size
babi_train_raw, babi_test_raw = utils.get_babi_raw(args.babi_id, args.babi_test_id)
word2vec = utils.load_glove(args.word_vector_size)

# Constructor kwargs: every parsed command-line option plus the loaded data.
# args is presumably an argparse.Namespace (the original used its
# _get_kwargs()); vars() is the documented public equivalent. Copy it so the
# additions below do not become attributes of `args`.
args_dict = dict(vars(args))
args_dict['babi_train_raw'] = babi_train_raw
args_dict['babi_test_raw'] = babi_test_raw
args_dict['word2vec'] = word2vec

# init class
if args.network == 'dmn_batch':
    import dmn_batch
    dmn = dmn_batch.DMN_batch(**args_dict)
elif args.network == 'dmn_basic':
    import dmn_basic
    # Non-batched variant: clamp the batch size. Only args.batch_size changes;
    # args_dict keeps the value captured above.
    # NOTE(review): confirm DMN_basic ignores args_dict['batch_size'].
    if args.batch_size != 1:
        print("==> no minibatch training, argument batch_size is useless")
        args.batch_size = 1
    dmn = dmn_basic.DMN_basic(**args_dict)
elif args.network == 'dmn_smooth':
    import dmn_smooth
    # Same clamping as dmn_basic; args_dict is not updated here either.
    if args.batch_size != 1:
        print("==> no minibatch training, argument batch_size is useless")
        args.batch_size = 1
    dmn = dmn_smooth.DMN_smooth(**args_dict)
else:
    # Without this branch an unknown --network would leave `dmn` unbound.
    raise Exception("No such network known: " + args.network)
args.network, args.memory_hops, args.dim, ".na" if args.normalize_attention else "", args.babi_id) babi_train_raw, babi_test_raw = utils.get_babi_raw(args.babi_id) word2vec = utils.load_glove(args.word_vector_size) # init class if args.network == 'dmn_batch': import dmn_batch dmn = dmn_batch.DMN_batch(babi_train_raw=babi_train_raw, babi_test_raw=babi_test_raw, word2vec=word2vec, word_vector_size=args.word_vector_size, dim=args.dim, mode=args.mode, answer_module=args.answer_module, input_mask_mode=args.input_mask_mode, memory_hops=args.memory_hops, batch_size=args.batch_size, l2=args.l2, normalize_attention=args.normalize_attention) elif args.network == 'dmn_basic': import dmn_basic if (args.batch_size != 1): raise Exception("no minibatch training, set batch_size=1") dmn = dmn_basic.DMN_basic(babi_train_raw=babi_train_raw, babi_test_raw=babi_test_raw, word2vec=word2vec, word_vector_size=args.word_vector_size, dim=args.dim,