def test_tree_lstm():
    l_sentences = mx.nd.load(os.path.join(CURRENT_DIR, 'l_sentences.nd'))
    r_sentences = mx.nd.load(os.path.join(CURRENT_DIR, 'r_sentences.nd'))
    with open(os.path.join(CURRENT_DIR, 'trees.pkl'), 'rb') as f:
        l_trees, r_trees = pickle.load(f)

    rnn_hidden_size, sim_hidden_size, num_classes = 150, 50, 5
    net = SimilarityTreeLSTM(sim_hidden_size, rnn_hidden_size, 2413, 300, num_classes)
    net.initialize(mx.init.Xavier(magnitude=2.24))
    # run one forward pass to trigger deferred parameter initialization,
    # then overwrite the embedding weights
    sent = mx.nd.concat(l_sentences[0], r_sentences[0], dim=0)
    net(sent, len(l_sentences[0]), l_trees[0], r_trees[0])
    net.embed.weight.set_data(mx.nd.random.uniform(shape=(2413, 300)))

    def verify(batch_size):
        print('verifying batch size: ', batch_size)
        fold = Fold()
        num_samples = 100
        inputs = []
        fold_preds = []
        for i in range(num_samples):
            # get next batch
            l_sent = l_sentences[i]
            r_sent = r_sentences[i]
            sent = mx.nd.concat(l_sent, r_sent, dim=0)
            l_len = len(l_sent)
            l_tree = l_trees[i]
            r_tree = r_trees[i]
            inputs.append((sent, l_len, l_tree, r_tree))
            z_fold = net.fold_encode(fold, sent, l_len, l_tree, r_tree)
            fold_preds.append(z_fold)
            if (i + 1) % batch_size == 0 or (i + 1) == num_samples:
                # fold-batched outputs must match the per-sample eager outputs
                fold_outs = fold([fold_preds])[0]
                outs = mx.nd.concat(*[
                    net(sent, l_len, l_tree, r_tree)
                    for sent, l_len, l_tree, r_tree in inputs
                ], dim=0)
                if not almost_equal(fold_outs.asnumpy(), outs.asnumpy()):
                    print(fold_preds)
                    print('l_sents: ', l_sent, l_sentences[i - 1])
                    print('r_sents: ', r_sent, r_sentences[i - 1])
                    print('\n'.join(
                        (str(l_tree), str_tree(l_tree),
                         str(r_tree), str_tree(r_tree),
                         str(l_trees[i - 1]), str_tree(l_trees[i - 1]),
                         str(r_trees[i - 1]), str_tree(r_trees[i - 1]),
                         str(fold))))
                assert_almost_equal(fold_outs.asnumpy(), outs.asnumpy())
                fold_preds = []
                inputs = []
                fold.reset()

    for batch_size in range(1, 6):
        verify(batch_size)
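# Hypothetical entry point for running this check directly (assumes the
# serialized data files loaded above sit next to this script):
if __name__ == '__main__':
    test_tree_lstm()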
vocab = Vocab(filepaths=token_files, embedpath=opt.word_embed)
train_iter, dev_iter, test_iter = [
    SICKDataIter(os.path.join(root_dir, segment), vocab, num_classes)
    for segment in segments
]
with open('dataset.pkl', 'wb') as f:
    pickle.dump([train_iter, dev_iter, test_iter, vocab], f)

logging.info('==> SICK vocabulary size : %d ' % vocab.size)
logging.info('==> Size of train data   : %d ' % len(train_iter))
logging.info('==> Size of dev data     : %d ' % len(dev_iter))
logging.info('==> Size of test data    : %d ' % len(test_iter))

net = SimilarityTreeLSTM(sim_hidden_size, rnn_hidden_size, vocab.size,
                         vocab.embed.shape[1], num_classes)
net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=context[0])
# run one forward pass to trigger deferred initialization,
# then load the pretrained embeddings
l_tree, l_sent, r_tree, r_sent, _ = train_iter.next()
train_iter.reset()
net(l_sent, r_sent, l_tree, r_tree)
net.embed.weight.set_data(vocab.embed.as_in_context(context[0]))

# use pearson correlation and mean-square error for evaluation
metric = mx.metric.create(['pearsonr', 'mse'])


def to_target(x):
    # spread a real-valued relatedness score x over the two nearest integer
    # classes; the body below the original truncation point is restored from
    # the standard sparse-target formulation (Tai et al.) used by this example
    target = np.zeros((1, num_classes))
    ceil = int(math.ceil(x))
    floor = int(math.floor(x))
    if ceil == floor:
        target[0][floor - 1] = 1
    else:
        target[0][floor - 1] = ceil - x
        target[0][ceil - 1] = x - floor
    return mx.nd.array(target)
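# Worked example of the sparse target (a sketch, assuming num_classes = 5):
# a score of 3.7 is split between classes 3 and 4 in proportion to its
# distance from each integer.
assert np.allclose(to_target(3.7).asnumpy(), [[0., 0., 0.3, 0.7, 0.]])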
test_file = os.path.join(args.data, 'sick_test.pth')
if os.path.isfile(test_file):
    test_dataset = torch.load(test_file)
else:
    test_dataset = SICKDataset(test_dir, vocab, rel_vocab, args.num_classes)
    torch.save(test_dataset, test_file)
logger.debug('==> Size of test data : %d ' % len(test_dataset))

# initialize model, criterion/loss_function, optimizer
model = SimilarityTreeLSTM(
    args.model_type,
    rel_vocab,
    vocab.size(),
    args.input_dim,
    args.mem_dim,
    args.hidden_dim,
    args.num_classes,
    args.sparse,
    args.freeze_embed)
criterion = nn.KLDivLoss()
# criterion = nn.CrossEntropyLoss()
# For changing embeddings
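# Sanity check of what nn.KLDivLoss consumes (a sketch; shapes illustrative):
# the input must be log-probabilities and the target a probability
# distribution, which is why the commented-out CrossEntropyLoss is not a
# drop-in replacement here.
import torch
import torch.nn.functional as F
log_p = F.log_softmax(torch.randn(1, args.num_classes), dim=1)
q = torch.full((1, args.num_classes), 1.0 / args.num_classes)
loss = criterion(log_p, q)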
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# args = parse_args()
dataset = torch.load(os.path.join(root, "./data/train_dataset_buildroot_demo.pth"))
total = len(dataset)
np.random.shuffle(dataset)
proportion = 0.8
split_idx = int(total * proportion)
train_dataset, test_dataset = dataset[0:split_idx], dataset[split_idx:]
logger.info("==> Size of train data \t : %d" % len(train_dataset))
logger.info("==> Size of test data \t : %d" % len(test_dataset))

args.cuda = args.cuda and torch.cuda.is_available()
device = torch.device("cuda:0" if args.cuda else "cpu")
model = SimilarityTreeLSTM(
    args.vocab_size,
    args.input_dim,
    args.mem_dim,
    args.hidden_dim,
    args.num_classes,
    device
)
criterion = nn.BCELoss()
logger.info("[CUDA] available " + str(args.cuda))
logger.info("args " + str(args))
if args.optim == 'adam':
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                           lr=args.lr, weight_decay=args.wd)
elif args.optim == 'adagrad':
    optimizer = optim.Adagrad(filter(lambda p: p.requires_grad, model.parameters()),
                              lr=args.lr, weight_decay=args.wd)
elif args.optim == 'sgd':
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                          lr=args.lr, weight_decay=args.wd)
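# An equivalent, more compact optimizer construction (a sketch covering the
# same three choices as the if/elif chain above):
optim_classes = {'adam': optim.Adam, 'adagrad': optim.Adagrad, 'sgd': optim.SGD}
optimizer = optim_classes[args.optim](
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=args.lr, weight_decay=args.wd)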
def main():
    global args
    args = parse_args()

    # global logger
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter("[%(asctime)s] %(levelname)s:%(name)s:%(message)s")
    # file logger
    fh = logging.FileHandler(os.path.join(args.save, args.expname) + '.log', mode='w')
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    # console logger
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    # argument validation
    args.cuda = args.cuda and torch.cuda.is_available()
    if args.sparse and args.wd != 0:
        logger.error('Sparsity and weight decay are incompatible, pick one!')
        exit()
    logger.debug(args)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.benchmark = True
    if not os.path.exists(args.save):
        os.makedirs(args.save)

    train_dir = os.path.join(args.data, 'train/')
    dev_dir = os.path.join(args.data, 'dev/')
    test_dir = os.path.join(args.data, 'test/')

    # write unique words from all token files
    sick_vocab_file = os.path.join(args.data, 'sick.vocab')
    if not os.path.isfile(sick_vocab_file):
        token_files_b = [os.path.join(split, 'b.toks') for split in [train_dir, dev_dir, test_dir]]
        token_files_a = [os.path.join(split, 'a.toks') for split in [train_dir, dev_dir, test_dir]]
        token_files = token_files_a + token_files_b
        build_vocab(token_files, sick_vocab_file)

    # get vocab object from vocab file previously written
    vocab = Vocab(filename=sick_vocab_file,
                  data=[Constants.PAD_WORD, Constants.UNK_WORD,
                        Constants.BOS_WORD, Constants.EOS_WORD])
    logger.debug('==> SICK vocabulary size : %d ' % vocab.size())

    # load SICK dataset splits
    train_file = os.path.join(args.data, 'sick_train.pth')
    if os.path.isfile(train_file):
        train_dataset = torch.load(train_file)
    else:
        train_dataset = SICKDataset(train_dir, vocab, args.num_classes)
        torch.save(train_dataset, train_file)
    logger.debug('==> Size of train data : %d ' % len(train_dataset))
    dev_file = os.path.join(args.data, 'sick_dev.pth')
    if os.path.isfile(dev_file):
        dev_dataset = torch.load(dev_file)
    else:
        dev_dataset = SICKDataset(dev_dir, vocab, args.num_classes)
        torch.save(dev_dataset, dev_file)
    logger.debug('==> Size of dev data   : %d ' % len(dev_dataset))
    test_file = os.path.join(args.data, 'sick_test.pth')
    if os.path.isfile(test_file):
        test_dataset = torch.load(test_file)
    else:
        test_dataset = SICKDataset(test_dir, vocab, args.num_classes)
        torch.save(test_dataset, test_file)
    logger.debug('==> Size of test data  : %d ' % len(test_dataset))

    # initialize model, criterion/loss_function, optimizer
    model = SimilarityTreeLSTM(
        vocab.size(),
        args.input_dim,
        args.mem_dim,
        args.hidden_dim,
        args.num_classes,
        args.sparse,
        args.freeze_embed)
    criterion = nn.KLDivLoss()
    if args.cuda:
        model.cuda(), criterion.cuda()
    if args.optim == 'adam':
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                               lr=args.lr, weight_decay=args.wd)
    elif args.optim == 'adagrad':
        optimizer = optim.Adagrad(filter(lambda p: p.requires_grad, model.parameters()),
                                  lr=args.lr, weight_decay=args.wd)
    elif args.optim == 'sgd':
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                              lr=args.lr, weight_decay=args.wd)
    metrics = Metrics(args.num_classes)

    # for words common to dataset vocab and GLOVE, use GLOVE vectors
    # for other words in dataset vocab, use random normal vectors
    emb_file = os.path.join(args.data, 'sick_embed.pth')
    if os.path.isfile(emb_file):
        emb = torch.load(emb_file)
    else:
        # load glove embeddings and vocab
        glove_vocab, glove_emb = load_word_vectors(os.path.join(args.glove, 'glove.840B.300d'))
        logger.debug('==> GLOVE vocabulary size: %d ' % glove_vocab.size())
        emb = torch.Tensor(vocab.size(), glove_emb.size(1)).normal_(-0.05, 0.05)
        # zero out the embeddings for padding and other special words if they are absent in vocab
        for idx, item in enumerate([Constants.PAD_WORD, Constants.UNK_WORD,
                                    Constants.BOS_WORD, Constants.EOS_WORD]):
            emb[idx].zero_()
        for word in vocab.labelToIdx.keys():
            if glove_vocab.getIndex(word):
                emb[vocab.getIndex(word)] = glove_emb[glove_vocab.getIndex(word)]
        torch.save(emb, emb_file)
    # plug these into embedding matrix inside model
    if args.cuda:
        emb = emb.cuda()
    model.emb.weight.data.copy_(emb)

    # create trainer object for training and testing
    trainer = Trainer(args, model, criterion, optimizer)

    best = -float('inf')
    for epoch in range(args.epochs):
        train_loss = trainer.train(train_dataset)
        train_loss, train_pred = trainer.test(train_dataset)
        dev_loss, dev_pred = trainer.test(dev_dataset)
        test_loss, test_pred = trainer.test(test_dataset)

        train_pearson = metrics.pearson(train_pred, train_dataset.labels)
        train_mse = metrics.mse(train_pred, train_dataset.labels)
        logger.info('==> Epoch {}, Train \tLoss: {}\tPearson: {}\tMSE: {}'.format(
            epoch, train_loss, train_pearson, train_mse))
        dev_pearson = metrics.pearson(dev_pred, dev_dataset.labels)
        dev_mse = metrics.mse(dev_pred, dev_dataset.labels)
        logger.info('==> Epoch {}, Dev \tLoss: {}\tPearson: {}\tMSE: {}'.format(
            epoch, dev_loss, dev_pearson, dev_mse))
        test_pearson = metrics.pearson(test_pred, test_dataset.labels)
        test_mse = metrics.mse(test_pred, test_dataset.labels)
        logger.info('==> Epoch {}, Test \tLoss: {}\tPearson: {}\tMSE: {}'.format(
            epoch, test_loss, test_pearson, test_mse))

        if best < test_pearson:
            best = test_pearson
            checkpoint = {
                'model': trainer.model.state_dict(),
                'optim': trainer.optimizer,
                'pearson': test_pearson,
                'mse': test_mse,
                'args': args,
                'epoch': epoch
            }
            logger.debug('==> New optimum found, checkpointing everything now...')
            torch.save(checkpoint, '%s.pt' % os.path.join(args.save, args.expname))
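# A minimal sketch of restoring the best checkpoint written above; the path
# and dict keys mirror the checkpoint saved in the training loop.
checkpoint = torch.load('%s.pt' % os.path.join(args.save, args.expname))
model.load_state_dict(checkpoint['model'])
print('best epoch %d: pearson %f, mse %f' % (
    checkpoint['epoch'], checkpoint['pearson'], checkpoint['mse']))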
class Application():
    '''
    This class loads a trained model and uses it to encode an AST into a
    vector and to calculate the similarity between ASTs.
    '''

    def __init__(self, load_path, threshold=0.5, cuda=False, model_name="treelstm", model=None):
        '''
        :param load_path: the path to the saved model, e.g. "saved_model.pt"
        :param cuda: bool: whether the GPU is used for model computation.
            Note that this requires the GPU build of torch.
        :param model_name: the model type; only "treelstm" is supported in this script.
        :param threshold: decides which similarity scores are output.
        :param model: a network variant of Tree-LSTM
        '''
        self.args = parse_args(["--cuda"])  # TODO argparse
        self.model = None
        self.embmodel = None  # the model actually used to compute encodings
        self.model_name = model_name
        self.db = None
        if cuda and torch.cuda.is_available():
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")
        if model is not None:
            self.model = model
        else:
            if model_name == "treelstm":
                from model import SimilarityTreeLSTM
                self.model = SimilarityTreeLSTM(
                    self.args.vocab_size,
                    self.args.input_dim,
                    self.args.mem_dim,
                    self.args.hidden_dim,
                    self.args.num_classes,
                    device=self.device
                )
        self.load_model(load_path, self.model)

    def load_model(self, path, model):
        '''
        :param path: path to the saved model
        :param model: the model to load the weights into
        :return:
        '''
        if not os.path.isfile(path):
            print("model path %s does not exist" % path)
            raise Exception
        checkpoint = torch.load(path, map_location=self.device)
        if "auc" in checkpoint:
            logger.info("checkpoint loaded: auc %f , mse: %f \n args %s" %
                        (checkpoint['auc'], checkpoint['mse'], checkpoint['args']))
        model.load_state_dict(checkpoint['model'])
        model.eval()
        # self.model.to(self.device)
        self.embmodel = model.embmodel
        self.embmodel.to(self.device)

    def encode_ast(self, tree):
        '''
        :param tree: a Tree object instance
        :return: a numpy vector (64- or 150-d)
        '''
        with torch.no_grad():
            state, hidden = self.embmodel(tree, get_tree_flat_nodes(tree).to(self.device))
        return state.detach().squeeze(0).cpu().numpy()

    def similarity_treeencoding_with_correction(self, ltree, rtree, lcallee, rcallee):
        '''
        :param ltree: tree encoding vector
        :param rtree: tree encoding vector
        :param lcallee: (caller, callee) corresponding to ltree
        :param rcallee: (caller, callee) corresponding to rtree
        :return:
        '''
        sim_tree = self.similarity_vec(ltree, rtree)
        # return sim_tree
        # shift lcallee and rcallee by 1 to guard against zero vectors
        lcallee = list(map(lambda x: x + 1, lcallee))
        rcallee = list(map(lambda x: x + 1, rcallee))
        cs = cosine_similarity([lcallee], [rcallee])[0][0]  # cosine similarity
        scale = exp(0 - abs(lcallee[-1] - rcallee[-1]))
        return sim_tree * scale * cs

    def similarity_tree_with_correction(self, ltree, rtree, lcallee, rcallee):
        '''
        :param ltree: AST1
        :param rtree: AST2
        :param lcallee: (caller, callee) corresponding to ltree
        :param rcallee: (caller, callee) corresponding to rtree
        :return:
        '''
        sim_tree = self.similarity_tree(ltree, rtree)
        # return sim_tree
        # shift lcallee and rcallee by 1 to guard against zero vectors
        lcallee = list(map(lambda x: x + 1, lcallee))
        rcallee = list(map(lambda x: x + 1, rcallee))
        cs = cosine_similarity([lcallee], [rcallee])[0][0]  # (caller, callee)
        scale = exp(0 - abs(lcallee[-1] - rcallee[-1]))
        return sim_tree * scale * cs

    def similarity_tree(self, ltree, rtree):
        '''
        Calculate the similarity of two ASTs.
        :param ltree: first tree
        :param rtree: second tree
        :return:
        '''
        with torch.no_grad():
            if self.model_name == 'treelstm' or self.model_name == "binarytreelstm":
                res = self.model(ltree, rtree)[0][1].item()
            else:
                res = self.model(ltree, rtree)[1].item()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return res

    def similarity_vec(self, lvec, rvec):
        '''
        Calculate the similarity of two AST encodings.
        :param lvec: numpy.ndarray or torch.Tensor
        :param rvec: numpy.ndarray or torch.Tensor
        :return: similarity score in [0, 1]
        '''
        if type(lvec) is list:
            lvec = numpy.array(lvec)
            rvec = numpy.array(rvec)
        if type(lvec) is numpy.ndarray:
            lvec = torch.from_numpy(lvec).to(self.device).float()
            rvec = torch.from_numpy(rvec).to(self.device).float()
        lvec = lvec.unsqueeze(0)
        rvec = rvec.unsqueeze(0)
        with torch.no_grad():
            if self.model_name in ['treelstm', 'treelstm_boosted']:
                res = self.model.similarity(lvec, rvec)[0][1].float().item()
            else:
                res = self.model.similarity(lvec, rvec)[1].float().item()
        return res

    def get_conn(self, db):
        # return the database connection
        global DB_CONNECTION
        if DB_CONNECTION is None:
            global datahelper
            DB_CONNECTION = datahelper.load_database(db)
        return DB_CONNECTION

    def encode_ast_in_db(self, db_path, table_name="function"):
        '''
        Encode the ASTs in the db file "db_path" and save the encoding
        vectors into table 'table_name'.
        :param db_path: path to sqlite database
        '''
        db_conn = self.get_conn(db_path)
        cur = db_conn.cursor()
        # create table if not exists
        sql_create_new_table = """CREATE TABLE if not exists %s (
            function_name varchar (255),
            elf_path varchar (255),
            ast_encode TEXT,
            primary key (function_name, elf_path)
        );""" % table_name
        try:
            cur.execute(sql_create_new_table)
            db_conn.commit()
        except Exception as e:
            logger.error("sql [%s] failed" % sql_create_new_table)
            logger.error(e)
        finally:
            cur.close()
        to_encode_list = []
        global datahelper
        for func_info, func_ast in old_datahelper.get_functions(db_path):
            to_encode_list.append((func_info, func_ast))
        encode_group = to_encode_list
        logger.info("Encoding %d ASTs" % len(encode_group))
        self.encode_and_update(db_path, encode_group, table_name)

    def __del__(self):
        if self.db:
            self.db.close()

    def func_wrapper(self, func, *args):
        # execute the function within a limited time period (350 seconds)
        self.Timeout = 350
        gevent.Timeout(self.Timeout, Exception).start()
        start = datetime.datetime.now()
        try:
            g = gevent.spawn(func, args[0])
            g.join()
            return g.get()
        except Exception:
            # timed out or failed; return a zero encoding of the expected shape
            end = datetime.datetime.now()
            # print("Timeout Error. Real run time is %d" % (end - start).seconds)
            return numpy.zeros((1, 150))

    def encode_and_update(self, db_path, functions, table_name):
        '''
        Encode the ASTs of the functions into vectors.
        :param functions: a list containing the ASTs to be encoded
        :param table_name: the table in which to save the AST encodings
        '''
        db_conn = self.get_conn(db_path)
        p = Pool(processes=10)
        res = []
        for func_info, func_ast in functions:
            # func_info[0] is the function name; func_info[1] is the elf_path
            res.append((p.apply_async(self.func_wrapper, (self.encode_ast, func_ast)),
                        func_info[0], func_info[1]))
        p.close()
        p.join()
        result = []
        try:
            logger.info("Fetching encode results!")
            for idx, r in tqdm(enumerate(res)):
                result.append((json.dumps(r[0].get().tolist()), r[1], r[2]))
            logger.info("All encodings fetched!")
        except Exception as e:
            print("Exception when fetching {}".format(str(e)))
        try:
            logger.info("Writing encoded vectors to database")
            cur = db_conn.cursor()
            sql_update = """
            update {} set ast_encode=? where function_name=? AND elf_path=?
            """.format(table_name)
            cur.executemany(sql_update, result)
            cur.close()
            db_conn.commit()
        except Exception as e:
            db_conn.rollback()
            print("Error when UPDATE [{}]\n".format(sql_update))
            print(e)
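# A minimal usage sketch of the class above ("saved_model.pt" and the Tree
# instances t1, t2 are hypothetical; trees come from elsewhere in the pipeline):
app = Application("saved_model.pt", cuda=False)
v1 = app.encode_ast(t1)             # numpy vector (64- or 150-d)
v2 = app.encode_ast(t2)
print(app.similarity_vec(v1, v2))   # score in [0, 1]
print(app.similarity_tree(t1, t2))  # direct tree-to-tree score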
random.seed(cfg.random_seed())
if cfg.use_cuda():
    torch.cuda.manual_seed(cfg.random_seed())
    torch.backends.cudnn.benchmark = True

cfg.logger.info('==> SICK vocabulary size : %d ' % D.vocab.size())
cfg.logger.info('==> Size of train data : %d ' % len(train_dataset))
cfg.logger.info('==> Size of dev data   : %d ' % len(dev_dataset))
cfg.logger.info('==> Size of test data  : %d ' % len(test_dataset))

model = SimilarityTreeLSTM(
    D.vocab.size(),
    cfg.input_dim(),
    cfg.mem_dim(),
    cfg.hidden_dim(),
    cfg.num_classes(),
    cfg.sparse(),
    cfg.freeze_embed())
criterion = nn.KLDivLoss()
cfg.logger.info("model:\n" + str(model))

emb = get_embd(cfg, D.vocab)
# plug these into embedding matrix inside model
model.emb.weight.data.copy_(emb)
model.to(cfg.device()), criterion.to(cfg.device())
def prepare_to_train(data=None, glove=None):
    args = parse_args()
    if data is not None:
        args.data = data
    if glove is not None:
        args.glove = glove
    args.input_dim, args.mem_dim = 300, 150
    args.hidden_dim, args.num_classes = 50, 5
    args.cuda = args.cuda and torch.cuda.is_available()
    if args.sparse and args.wd != 0:
        print('Sparsity and weight decay are incompatible, pick one!')
        exit()
    print(args)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    numpy.random.seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.benchmark = True
    if not os.path.exists(args.save):
        os.makedirs(args.save)

    train_dir = os.path.join(args.data, 'train/')
    dev_dir = os.path.join(args.data, 'dev/')
    test_dir = os.path.join(args.data, 'test/')

    # write unique words from all token files
    sick_vocab_file = os.path.join(args.data, 'sick.vocab')
    if not os.path.isfile(sick_vocab_file):
        token_files_a = [
            os.path.join(split, 'a.toks')
            for split in [train_dir, dev_dir, test_dir]
        ]
        token_files_b = [
            os.path.join(split, 'b.toks')
            for split in [train_dir, dev_dir, test_dir]
        ]
        token_files = token_files_a + token_files_b
        build_vocab(token_files, sick_vocab_file)

    # get vocab object from vocab file previously written
    vocab = Vocab(filename=sick_vocab_file,
                  data=[
                      Constants.PAD_WORD, Constants.UNK_WORD,
                      Constants.BOS_WORD, Constants.EOS_WORD
                  ])
    print('==> SICK vocabulary size : %d ' % vocab.size())

    # load SICK dataset splits
    train_file = os.path.join(args.data, 'sick_train.pth')
    if os.path.isfile(train_file):
        train_dataset = torch.load(train_file)
    else:
        train_dataset = SICKDataset(train_dir, vocab, args.num_classes)
        torch.save(train_dataset, train_file)
    print('==> Size of train data : %d ' % len(train_dataset))
    dev_file = os.path.join(args.data, 'sick_dev.pth')
    if os.path.isfile(dev_file):
        dev_dataset = torch.load(dev_file)
    else:
        dev_dataset = SICKDataset(dev_dir, vocab, args.num_classes)
        torch.save(dev_dataset, dev_file)
    print('==> Size of dev data   : %d ' % len(dev_dataset))
    test_file = os.path.join(args.data, 'sick_test.pth')
    if os.path.isfile(test_file):
        test_dataset = torch.load(test_file)
    else:
        test_dataset = SICKDataset(test_dir, vocab, args.num_classes)
        torch.save(test_dataset, test_file)
    print('==> Size of test data  : %d ' % len(test_dataset))

    # initialize model, criterion/loss_function, optimizer
    model = SimilarityTreeLSTM(args.cuda, vocab.size(), args.input_dim,
                               args.mem_dim, args.hidden_dim,
                               args.num_classes, args.sparse)
    criterion = nn.KLDivLoss()
    if args.cuda:
        model.cuda(), criterion.cuda()
    if args.optim == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr,
                               weight_decay=args.wd)
    elif args.optim == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=args.lr,
                                  weight_decay=args.wd)
    elif args.optim == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr,
                              weight_decay=args.wd)
    metrics = Metrics(args.num_classes)

    # for words common to dataset vocab and GLOVE, use GLOVE vectors
    # for other words in dataset vocab, use random normal vectors
    emb_file = os.path.join(args.data, 'sick_embed.pth')
    if os.path.isfile(emb_file):
        emb = torch.load(emb_file)
    else:
        # load glove embeddings and vocab
        glove_vocab, glove_emb = load_word_vectors(
            os.path.join(args.glove, 'glove.840B.300d'))
        print('==> GLOVE vocabulary size: %d ' % glove_vocab.size())
        emb = torch.Tensor(vocab.size(), glove_emb.size(1)).normal_(-0.05, 0.05)
        # zero out the embeddings for padding and other special words if they are absent in vocab
        for idx, item in enumerate([
                Constants.PAD_WORD, Constants.UNK_WORD,
                Constants.BOS_WORD, Constants.EOS_WORD
        ]):
            emb[idx].zero_()
        for word in vocab.labelToIdx.keys():
            if glove_vocab.get_index(word):
                emb[vocab.get_index(word)] = glove_emb[glove_vocab.get_index(word)]
        torch.save(emb, emb_file)
    # plug these into embedding matrix inside model
    if args.cuda:
        emb = emb.cuda()
    model.childsumtreelstm.emb.state_dict()['weight'].copy_(emb)

    # create trainer object for training and testing
    # trainer = Trainer(args, model, criterion, optimizer)
    best = -float('inf')
    return (args, best, train_dataset, dev_dataset, test_dataset, metrics,
            optimizer, criterion, model)
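# Note: writing through state_dict()['weight'] as above works because the
# returned tensors share storage with the module's parameters; the more
# common idiom, matching the other training scripts in this section, is:
# model.childsumtreelstm.emb.weight.data.copy_(emb)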