Example #1
    def __init__(self, load_path, threshold=0.5, cuda=False, model_name="treelstm", model=None):
        '''
        :param load_path: the path to the saved model, e.g. "saved_model.pt"
        :param cuda: bool: whether the GPU is used for model computation. Note that you need to install the GPU version of torch.
        :param model_name: the model type to use. Only "treelstm" is supported in this script.
        :param threshold: decides which similarity scores are output.
        :param model: a network variant of Tree-LSTM
        '''
        self.args = parse_args(["--cuda"])  # TODO argparse
        self.model = None
        self.embmodel = None  # the model actually used to compute the encodings
        self.model_name = model_name
        self.db = None
        if cuda and torch.cuda.is_available():
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")
        if model is not None:
            self.model = model
        else:
            if model_name == "treelstm":
                from model import SimilarityTreeLSTM
                self.model = SimilarityTreeLSTM(
                    self.args.vocab_size,
                    self.args.input_dim,
                    self.args.mem_dim,
                    self.args.hidden_dim,
                    self.args.num_classes,
                    device=self.device
                )
        self.load_model(load_path, self.model)
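
A minimal instantiation sketch based on the parameters documented in the docstring above; the enclosing class appears in full as `Application` in Example #7, and the checkpoint path is only a placeholder:

# Sketch: constructing the wrapper around a trained Tree-LSTM checkpoint.
# "saved_model.pt" is a placeholder path, not a file shipped with the code.
app = Application(
    load_path="saved_model.pt",
    threshold=0.5,                      # which similarity scores get reported
    cuda=torch.cuda.is_available(),     # GPU only if the torch build supports it
    model_name="treelstm",              # the only model type supported here
)
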
Example #2
def test_tree_lstm():

    l_sentences = mx.nd.load(os.path.join(CURRENT_DIR, 'l_sentences.nd'))
    r_sentences = mx.nd.load(os.path.join(CURRENT_DIR, 'r_sentences.nd'))
    with open(os.path.join(CURRENT_DIR, 'trees.pkl'), 'rb') as f:
        l_trees, r_trees = pickle.load(f)

    rnn_hidden_size, sim_hidden_size, num_classes = 150, 50, 5
    net = SimilarityTreeLSTM(sim_hidden_size, rnn_hidden_size, 2413, 300,
                             num_classes)
    net.initialize(mx.init.Xavier(magnitude=2.24))
    sent = mx.nd.concat(l_sentences[0], r_sentences[0], dim=0)
    net(sent, len(l_sentences[0]), l_trees[0], r_trees[0])
    net.embed.weight.set_data(mx.nd.random.uniform(shape=(2413, 300)))

    def verify(batch_size):
        print('verifying batch size: ', batch_size)
        fold = Fold()
        num_samples = 100
        inputs = []
        fold_preds = []
        for i in range(num_samples):
            # get next batch
            l_sent = l_sentences[i]
            r_sent = r_sentences[i]
            sent = mx.nd.concat(l_sent, r_sent, dim=0)
            l_len = len(l_sent)
            l_tree = l_trees[i]
            r_tree = r_trees[i]

            inputs.append((sent, l_len, l_tree, r_tree))
            z_fold = net.fold_encode(fold, sent, l_len, l_tree, r_tree)
            fold_preds.append(z_fold)

            if (i + 1) % batch_size == 0 or (i + 1) == num_samples:
                fold_outs = fold([fold_preds])[0]
                outs = mx.nd.concat(*[
                    net(sent, l_len, l_tree, r_tree)
                    for sent, l_len, l_tree, r_tree in inputs
                ],
                                    dim=0)
                if not almost_equal(fold_outs.asnumpy(), outs.asnumpy()):
                    print(fold_preds)
                    print('l_sents: ', l_sent, l_sentences[i - 1])
                    print('r_sents: ', r_sent, r_sentences[i - 1])
                    print('\n'.join(
                        (str(l_tree), str_tree(l_tree), str(r_tree),
                         str_tree(r_tree), str(l_trees[i - 1]),
                         str_tree(l_trees[i - 1]), str(r_trees[i - 1]),
                         str_tree(r_trees[i - 1]), str(fold))))
                    assert_almost_equal(fold_outs.asnumpy(), outs.asnumpy())
                fold_preds = []
                inputs = []
                fold.reset()

    for batch_size in range(1, 6):
        verify(batch_size)
Example #3
    vocab = Vocab(filepaths=token_files, embedpath=opt.word_embed)

    train_iter, dev_iter, test_iter = [
        SICKDataIter(os.path.join(root_dir, segment), vocab, num_classes)
        for segment in segments
    ]
    with open('dataset.pkl', 'wb') as f:
        pickle.dump([train_iter, dev_iter, test_iter, vocab], f)

logging.info('==> SICK vocabulary size : %d ' % vocab.size)
logging.info('==> Size of train data   : %d ' % len(train_iter))
logging.info('==> Size of dev data     : %d ' % len(dev_iter))
logging.info('==> Size of test data    : %d ' % len(test_iter))

net = SimilarityTreeLSTM(sim_hidden_size, rnn_hidden_size, vocab.size,
                         vocab.embed.shape[1], num_classes)
net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=context[0])
l_tree, l_sent, r_tree, r_sent, _ = train_iter.next()
train_iter.reset()
net(l_sent, r_sent, l_tree, r_tree)
net.embed.weight.set_data(vocab.embed.as_in_context(context[0]))

# use pearson correlation and mean-square error for evaluation
metric = mx.metric.create(['pearsonr', 'mse'])


def to_target(x):
    target = np.zeros((1, num_classes))
    ceil = int(math.ceil(x))
    floor = int(math.floor(x))
    if ceil == floor:
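
The truncated `to_target` helper above maps a real-valued gold similarity score onto a distribution over the discrete classes; a sketch of the usual sparse-target formulation for this kind of SICK-style setup (assuming scores lie in [1, num_classes]):

import math
import numpy as np

def to_target_sketch(x, num_classes=5):
    # Spread the gold score x across the two neighboring integer classes,
    # e.g. x = 3.6 with 5 classes -> [0, 0, 0.4, 0.6, 0].
    target = np.zeros((1, num_classes))
    ceil, floor = int(math.ceil(x)), int(math.floor(x))
    if ceil == floor:
        target[0][floor - 1] = 1.0
    else:
        target[0][floor - 1] = ceil - x
        target[0][ceil - 1] = x - floor
    return target
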
Example #4
test_file = os.path.join(args.data, 'sick_test.pth')
if os.path.isfile(test_file):
    test_dataset = torch.load(test_file)
else:
    test_dataset = SICKDataset(test_dir, vocab, rel_vocab, args.num_classes)
    torch.save(test_dataset, test_file)
logger.debug('==> Size of test data     : %d ' % len(test_dataset))

# initialize model, criterion/loss_function, optimizer
model = SimilarityTreeLSTM(
    args.model_type,
    rel_vocab,
    vocab.size(),
    args.input_dim,
    args.mem_dim,
    args.hidden_dim,
    args.num_classes,
    args.sparse,
    args.freeze_embed,)

criterion = nn.KLDivLoss()
# criterion = nn.CrossEntropyLoss()

# +
# For changing embeddings
# -
Example #5
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
args = parse_args()
dataset = torch.load(os.path.join(root, "./data/train_dataset_buildroot_demo.pth"))
total = len(dataset)
np.random.shuffle(dataset)
proportion = 0.8  # fraction of the shuffled dataset used for training
split_idx = int(total * proportion)
train_dataset, test_dataset = dataset[0:split_idx], dataset[split_idx:]
logger.info("==> Size of train data \t : %d" % len(train_dataset))
logger.info("==> Size of test data \t : %d" % len(test_dataset))
args.cuda = args.cuda and torch.cuda.is_available()
device = torch.device("cuda:0" if args.cuda else "cpu")
model = SimilarityTreeLSTM(
    args.vocab_size,
    args.input_dim,
    args.mem_dim,
    args.hidden_dim,
    args.num_classes,
    device
)
criterion = nn.BCELoss()
logger.info("[CUDA] available " + str(args.cuda))
logger.info("args " + str(args))
if args.optim == 'adam':
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()), lr=args.lr, weight_decay=args.wd)
elif args.optim == 'adagrad':
    optimizer = optim.Adagrad(filter(lambda p: p.requires_grad,
                                     model.parameters()), lr=args.lr, weight_decay=args.wd)
elif args.optim == 'sgd':
    optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                 model.parameters()), lr=args.lr, weight_decay=args.wd)
Example #6
def main():
    global args
    args = parse_args()
    # global logger
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter("[%(asctime)s] %(levelname)s:%(name)s:%(message)s")
    # file logger
    fh = logging.FileHandler(os.path.join(args.save, args.expname)+'.log', mode='w')
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    # console logger
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    # argument validation
    args.cuda = args.cuda and torch.cuda.is_available()
    if args.sparse and args.wd != 0:
        logger.error('Sparsity and weight decay are incompatible, pick one!')
        exit()
    logger.debug(args)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.benchmark = True
    if not os.path.exists(args.save):
        os.makedirs(args.save)

    train_dir = os.path.join(args.data, 'train/')
    dev_dir = os.path.join(args.data, 'dev/')
    test_dir = os.path.join(args.data, 'test/')

    # write unique words from all token files
    sick_vocab_file = os.path.join(args.data, 'sick.vocab')
    if not os.path.isfile(sick_vocab_file):
        token_files_b = [os.path.join(split, 'b.toks') for split in [train_dir, dev_dir, test_dir]]
        token_files_a = [os.path.join(split, 'a.toks') for split in [train_dir, dev_dir, test_dir]]
        token_files = token_files_a + token_files_b
        sick_vocab_file = os.path.join(args.data, 'sick.vocab')
        build_vocab(token_files, sick_vocab_file)

    # get vocab object from vocab file previously written
    vocab = Vocab(filename=sick_vocab_file, data=[Constants.PAD_WORD, Constants.UNK_WORD, Constants.BOS_WORD, Constants.EOS_WORD])
    logger.debug('==> SICK vocabulary size : %d ' % vocab.size())

    # load SICK dataset splits
    train_file = os.path.join(args.data, 'sick_train.pth')
    if os.path.isfile(train_file):
        train_dataset = torch.load(train_file)
    else:
        train_dataset = SICKDataset(train_dir, vocab, args.num_classes)
        torch.save(train_dataset, train_file)
    logger.debug('==> Size of train data   : %d ' % len(train_dataset))
    dev_file = os.path.join(args.data, 'sick_dev.pth')
    if os.path.isfile(dev_file):
        dev_dataset = torch.load(dev_file)
    else:
        dev_dataset = SICKDataset(dev_dir, vocab, args.num_classes)
        torch.save(dev_dataset, dev_file)
    logger.debug('==> Size of dev data     : %d ' % len(dev_dataset))
    test_file = os.path.join(args.data, 'sick_test.pth')
    if os.path.isfile(test_file):
        test_dataset = torch.load(test_file)
    else:
        test_dataset = SICKDataset(test_dir, vocab, args.num_classes)
        torch.save(test_dataset, test_file)
    logger.debug('==> Size of test data    : %d ' % len(test_dataset))

    # initialize model, criterion/loss_function, optimizer
    model = SimilarityTreeLSTM(
                vocab.size(),
                args.input_dim,
                args.mem_dim,
                args.hidden_dim,
                args.num_classes,
                args.sparse,
                args.freeze_embed)
    criterion = nn.KLDivLoss()
    if args.cuda:
        model.cuda(), criterion.cuda()
    if args.optim == 'adam':
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wd)
    elif args.optim == 'adagrad':
        optimizer = optim.Adagrad(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wd)
    elif args.optim == 'sgd':
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wd)
    metrics = Metrics(args.num_classes)

    # for words common to dataset vocab and GLOVE, use GLOVE vectors
    # for other words in dataset vocab, use random normal vectors
    emb_file = os.path.join(args.data, 'sick_embed.pth')
    if os.path.isfile(emb_file):
        emb = torch.load(emb_file)
    else:
        # load glove embeddings and vocab
        glove_vocab, glove_emb = load_word_vectors(os.path.join(args.glove, 'glove.840B.300d'))
        logger.debug('==> GLOVE vocabulary size: %d ' % glove_vocab.size())
        emb = torch.Tensor(vocab.size(), glove_emb.size(1)).normal_(-0.05, 0.05)
        # zero out the embeddings for padding and other special words if they are absent in vocab
        for idx, item in enumerate([Constants.PAD_WORD, Constants.UNK_WORD, Constants.BOS_WORD, Constants.EOS_WORD]):
            emb[idx].zero_()
        for word in vocab.labelToIdx.keys():
            if glove_vocab.getIndex(word):
                emb[vocab.getIndex(word)] = glove_emb[glove_vocab.getIndex(word)]
        torch.save(emb, emb_file)
    # plug these into embedding matrix inside model
    if args.cuda:
        emb = emb.cuda()
    model.emb.weight.data.copy_(emb)

    # create trainer object for training and testing
    trainer = Trainer(args, model, criterion, optimizer)

    best = -float('inf')
    for epoch in range(args.epochs):
        train_loss             = trainer.train(train_dataset)
        train_loss, train_pred = trainer.test(train_dataset)
        dev_loss, dev_pred     = trainer.test(dev_dataset)
        test_loss, test_pred   = trainer.test(test_dataset)

        train_pearson = metrics.pearson(train_pred, train_dataset.labels)
        train_mse = metrics.mse(train_pred, train_dataset.labels)
        logger.info('==> Epoch {}, Train \tLoss: {}\tPearson: {}\tMSE: {}'.format(epoch, train_loss, train_pearson, train_mse))
        dev_pearson = metrics.pearson(dev_pred, dev_dataset.labels)
        dev_mse = metrics.mse(dev_pred, dev_dataset.labels)
        logger.info('==> Epoch {}, Dev \tLoss: {}\tPearson: {}\tMSE: {}'.format(epoch, dev_loss, dev_pearson, dev_mse))
        test_pearson = metrics.pearson(test_pred, test_dataset.labels)
        test_mse = metrics.mse(test_pred, test_dataset.labels)
        logger.info('==> Epoch {}, Test \tLoss: {}\tPearson: {}\tMSE: {}'.format(epoch, test_loss, test_pearson, test_mse))

        if best < test_pearson:
            best = test_pearson
            checkpoint = {
                'model': trainer.model.state_dict(), 
                'optim': trainer.optimizer,
                'pearson': test_pearson, 'mse': test_mse,
                'args': args, 'epoch': epoch
                }
            logger.debug('==> New optimum found, checkpointing everything now...')
            torch.save(checkpoint, '%s.pt' % os.path.join(args.save, args.expname))
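
The checkpoint dictionary written at the end of the training loop can be restored later; a minimal sketch, assuming the same `args.save` and `args.expname` values that were used for training:

# Sketch: reloading the best checkpoint saved by the loop above.
# `args` and `model` are the objects built in main(); the path mirrors
# '%s.pt' % os.path.join(args.save, args.expname).
checkpoint = torch.load('%s.pt' % os.path.join(args.save, args.expname),
                        map_location='cpu')
model.load_state_dict(checkpoint['model'])
model.eval()
print('restored epoch %d, test Pearson %.4f, MSE %.4f'
      % (checkpoint['epoch'], checkpoint['pearson'], checkpoint['mse']))
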
Example #7
class Application():
    '''
    This class loads the trained model and uses it to encode an AST into a vector and to calculate the similarity between ASTs.
    '''
    def __init__(self, load_path, threshold=0.5, cuda=False, model_name="treelstm", model=None):
        '''
        :param load_path: the path to the saved model, e.g. "saved_model.pt"
        :param cuda: bool: whether the GPU is used for model computation. Note that you need to install the GPU version of torch.
        :param model_name: the model type to use. Only "treelstm" is supported in this script.
        :param threshold: decides which similarity scores are output.
        :param model: a network variant of Tree-LSTM
        '''
        self.args = parse_args(["--cuda"])  # TODO argparse
        self.model = None
        self.embmodel = None  # the model actually used to compute the encodings
        self.model_name = model_name
        self.db = None
        if cuda and torch.cuda.is_available():
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")
        if model is not None:
            self.model = model
        else:
            if model_name == "treelstm":
                from model import SimilarityTreeLSTM
                self.model = SimilarityTreeLSTM(
                    self.args.vocab_size,
                    self.args.input_dim,
                    self.args.mem_dim,
                    self.args.hidden_dim,
                    self.args.num_classes,
                    device=self.device
                )
        self.load_model(load_path, self.model)

    def load_model(self, path, model):
        '''
        :param path: path to saved model
        :param model: the model loaded
        :return:
        '''
        if not os.path.isfile(path):
            print("model path %s does not exist" % path)
            raise Exception
        checkpoint = torch.load(path, map_location=self.device)
        if "auc" in checkpoint:
            logger.info("checkpoint loaded: auc %f , mse: %f \n args %s" %(checkpoint['auc'], checkpoint['mse'], checkpoint['args']))
        model.load_state_dict(checkpoint['model'])
        model.eval()
        # self.model.to(self.device)
        self.embmodel = model.embmodel
        self.embmodel.to(self.device)


    def encode_ast(self, tree):
        '''
        :param tree: a Tree object instance
        :return: a numpy vector (64- or 150-dimensional)
        '''
        with torch.no_grad():
            state, hidden = self.embmodel(tree, get_tree_flat_nodes(tree).to(self.device))
            return state.detach().squeeze(0).cpu().numpy()

    def similarity_treeencoding_with_correction(self, ltree, rtree, lcallee, rcallee):
        '''
        :param ltree: tree encoding vector
        :param rtree: tree encoding vector
        :param lcallee: (caller, callee) corresponds to ltree
        :param rcallee: (caller, callee) corresponds to rtree
        :return:
        '''

        sim_tree = self.similarity_vec(ltree, rtree)
        # return sim_tree
        # shift lcallee and rcallee by 1 to avoid all-zero vectors
        lcallee = list(map(lambda x: x + 1, lcallee))
        rcallee = list(map(lambda x: x + 1, rcallee))

        cs = cosine_similarity([lcallee], [rcallee])[0][0]  # cosine similarity
        scale = exp(0 - abs(lcallee[-1] - rcallee[-1]))
        return sim_tree * scale * cs

    def similarity_tree_with_correction(self, ltree, rtree, lcallee, rcallee):
        '''
        :param ltree: AST1
        :param rtree: AST2
        :param lcallee: (caller, callee) corresponds to ltree
        :param rcallee: (caller, callee) corresponds to rtree
        :return:
        '''

        sim_tree = self.similarity_tree(ltree, rtree)
        #return sim_tree
        # shift lcallee and rcallee by 1 to avoid all-zero vectors
        lcallee = list(map(lambda x: x + 1, lcallee))
        rcallee = list(map(lambda x: x + 1, rcallee))

        cs = cosine_similarity([lcallee], [rcallee])[0][0]  # cosine similarity of (caller, callee)
        scale = exp(0 - abs(lcallee[-1] - rcallee[-1]))
        return sim_tree * scale * cs

    def similarity_tree(self, ltree, rtree):
        '''
        calculate the similarity of two asts
        :param ltree: the first tree
        :param rtree: the second tree
        :return:
        '''
        with torch.no_grad():
            if self.model_name=='treelstm' or self.model_name=="binarytreelstm":
                res = self.model(ltree, rtree)[0][1].item()
            else:
                res = self.model(ltree, rtree)[1].item()

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            return res

    def similarity_vec(self, lvec, rvec):
        '''
        calculate the similarity of two ast encodings
        :param lvec: numpy.ndarray or torch.Tensor
        :param rvec: numpy.ndarray or torch.Tensor
        :return: similarity score in the range 0~1.
        '''
        if type(lvec) is list:
            lvec = numpy.array(lvec)
            rvec = numpy.array(rvec)
        if type(lvec) is numpy.ndarray:
            lvec = torch.from_numpy(lvec).to(self.device).float()
            rvec = torch.from_numpy(rvec).to(self.device).float()
            lvec = lvec.unsqueeze(0)
            rvec = rvec.unsqueeze(0)
        with torch.no_grad():
            if self.model_name in ['treelstm', 'treelstm_boosted']:
                res = self.model.similarity(lvec, rvec)[0][1].float().item()
            else:
                res = self.model.similarity(lvec, rvec)[1].float().item()
            return res

    def get_conn(self, db):
        # return the database connection
        global DB_CONNECTION
        if DB_CONNECTION is None:
            global datahelper
            DB_CONNECTION = datahelper.load_database(db)
        return DB_CONNECTION

    def encode_ast_in_db(self, db_path, table_name="function"):
        '''
        Encode the asts in db file "db_path" and save the encoding vectors into table 'table_name'
        :param db_path: path to sqlite database
        '''
        db_conn = self.get_conn(db_path)
        cur = db_conn.cursor()
        # create table if not exists
        sql_create_new_table = """CREATE TABLE if not exists %s (
                            function_name varchar (255),
                            elf_path varchar (255),
                            ast_encode TEXT,
                            primary key (function_name, elf_path)
                            );""" % table_name
        try:
            cur.execute(sql_create_new_table)
            db_conn.commit()
        except Exception as e:
            logger.error("sql [%s] failed" % sql_create_new_table)
            logger.error(e)
        finally:
            cur.close()
        to_encode_list = []
        global datahelper
        for func_info, func_ast in old_datahelper.get_functions(db_path):
            to_encode_list.append((func_info, func_ast))
        encode_group = to_encode_list
        logger.info("Encoding for %d ast" % len(encode_group))
        self.encode_and_update(db_path, encode_group, table_name)

    def __del__(self):
        if self.db:
            self.db.close()

    def func_wrapper(self, func, *args):
        # execute the function within a limited period of time
        self.Timeout = 350  # timeout in seconds
        gevent.Timeout(self.Timeout, Exception).start()
        start = datetime.datetime.now()
        try:
            g = gevent.spawn(func, args[0])
            g.join()
            return g.get()
        except Exception:
            # print("Timeout Error.")
            # print("Real run time is %d" % (datetime.datetime.now() - start).seconds)
            r = numpy.zeros((1, 150))
            return r

    def encode_and_update(self, db_path, functions, table_name):
        '''
        encode asts of the functions into vectors
        :param functions: a list contains asts to be encoded
        :param table_name: the new table to save ast encodings
        '''
        db_conn = self.get_conn(db_path)
        p = Pool(processes=10)
        res = []
        count = 1
        for func_info, func_ast in functions:
            res.append((p.apply_async(self.func_wrapper, (self.encode_ast, func_ast)), func_info[0], func_info[1])) # func_info[0] is function name; func_info[1] is elf_path
            count+=1
        #try:
        p.close()
        p.join()
        result = []
        try:
            logger.info("Fetching encode results!")
            for idx, r in tqdm(enumerate(res)):
                result.append((json.dumps(r[0].get().tolist()), r[1], r[2]))
            logger.info("All encode fetched!")
        except Exception as e:
            print("Exception when fetching {}".format(str(e)))
        try:
            logger.info("Writing encoded vectors to database")
            cur = db_conn.cursor()
            sql_update = """ 
            update {} set ast_encode=? where function_name=? AND elf_path=?
            """.format(table_name)
            cur.executemany(sql_update, result)
            cur.close()
            db_conn.commit()
        except Exception as e:
            db_conn.rollback()
            print("Error when executing UPDATE [{}]\n".format(sql_update))
            print(e)
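
A minimal usage sketch for the `Application` class above; the checkpoint path, the Tree objects `tree_a`/`tree_b` and the (caller, callee) counts are placeholders rather than values from the original code:

# Sketch: encode two ASTs and compare them, with and without the call-count correction.
app = Application("saved_model.pt", cuda=False, model_name="treelstm")
vec_a = app.encode_ast(tree_a)                      # numpy encoding of the first AST
vec_b = app.encode_ast(tree_b)                      # numpy encoding of the second AST
score = app.similarity_vec(vec_a, vec_b)            # raw similarity in [0, 1]
corrected = app.similarity_treeencoding_with_correction(
    vec_a, vec_b, lcallee=(3, 2), rcallee=(4, 1))   # rescaled by (caller, callee) counts
if score >= 0.5:                                    # threshold from the constructor default
    print("match: raw %.3f, corrected %.3f" % (score, corrected))
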
Example #8
    random.seed(cfg.random_seed())

    if cfg.use_cuda():
        torch.cuda.manual_seed(cfg.random_seed())
        torch.backends.cudnn.benchmark = True


    cfg.logger.info('==> SICK vocabulary size : %d ' % D.vocab.size())
    cfg.logger.info('==> Size of train data   : %d ' % len(train_dataset))
    cfg.logger.info('==> Size of dev data     : %d ' % len(dev_dataset))
    cfg.logger.info('==> Size of test data    : %d ' % len(test_dataset))

    model = SimilarityTreeLSTM(
        D.vocab.size(),
        cfg.input_dim(),
        cfg.mem_dim(),
        cfg.hidden_dim(),
        cfg.num_classes(),
        cfg.sparse(),
        cfg.freeze_embed())

    criterion = nn.KLDivLoss()

    cfg.logger.info("model:\n" + str(model))

    emb = get_embd(cfg, D.vocab)

    # plug these into embedding matrix inside model
    model.emb.weight.data.copy_(emb)

    model.to(cfg.device()), criterion.to(cfg.device())


def prepare_to_train(data=None, glove=None):
    args = parse_args()
    if data is not None:
        args.data = data
    if glove is not None:
        args.glove = glove

    args.input_dim, args.mem_dim = 300, 150
    args.hidden_dim, args.num_classes = 50, 5
    args.cuda = args.cuda and torch.cuda.is_available()
    if args.sparse and args.wd != 0:
        print('Sparsity and weight decay are incompatible, pick one!')
        exit()
    print(args)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    numpy.random.seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.benchmark = True
    if not os.path.exists(args.save):
        os.makedirs(args.save)

    train_dir = os.path.join(args.data, 'train/')
    dev_dir = os.path.join(args.data, 'dev/')
    test_dir = os.path.join(args.data, 'test/')

    # write unique words from all token files
    sick_vocab_file = os.path.join(args.data, 'sick.vocab')
    if not os.path.isfile(sick_vocab_file):
        token_files_a = [
            os.path.join(split, 'a.toks')
            for split in [train_dir, dev_dir, test_dir]
        ]
        token_files_b = [
            os.path.join(split, 'b.toks')
            for split in [train_dir, dev_dir, test_dir]
        ]
        token_files = token_files_a + token_files_b
        sick_vocab_file = os.path.join(args.data, 'sick.vocab')
        build_vocab(token_files, sick_vocab_file)

    # get vocab object from vocab file previously written
    vocab = Vocab(filename=sick_vocab_file,
                  data=[
                      Constants.PAD_WORD, Constants.UNK_WORD,
                      Constants.BOS_WORD, Constants.EOS_WORD
                  ])
    print('==> SICK vocabulary size : %d ' % vocab.size())

    # load SICK dataset splits
    train_file = os.path.join(args.data, 'sick_train.pth')
    if os.path.isfile(train_file):
        train_dataset = torch.load(train_file)
    else:
        train_dataset = SICKDataset(train_dir, vocab, args.num_classes)
        torch.save(train_dataset, train_file)
    print('==> Size of train data   : %d ' % len(train_dataset))
    dev_file = os.path.join(args.data, 'sick_dev.pth')
    if os.path.isfile(dev_file):
        dev_dataset = torch.load(dev_file)
    else:
        dev_dataset = SICKDataset(dev_dir, vocab, args.num_classes)
        torch.save(dev_dataset, dev_file)
    print('==> Size of dev data     : %d ' % len(dev_dataset))
    test_file = os.path.join(args.data, 'sick_test.pth')
    if os.path.isfile(test_file):
        test_dataset = torch.load(test_file)
    else:
        test_dataset = SICKDataset(test_dir, vocab, args.num_classes)
        torch.save(test_dataset, test_file)
    print('==> Size of test data    : %d ' % len(test_dataset))

    # initialize model, criterion/loss_function, optimizer
    model = SimilarityTreeLSTM(args.cuda, vocab.size(), args.input_dim,
                               args.mem_dim, args.hidden_dim, args.num_classes,
                               args.sparse)
    criterion = nn.KLDivLoss()
    if args.cuda:
        model.cuda(), criterion.cuda()
    if args.optim == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.wd)
    elif args.optim == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(),
                                  lr=args.lr,
                                  weight_decay=args.wd)
    elif args.optim == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.wd)
    metrics = Metrics(args.num_classes)

    # for words common to dataset vocab and GLOVE, use GLOVE vectors
    # for other words in dataset vocab, use random normal vectors
    emb_file = os.path.join(args.data, 'sick_embed.pth')
    if os.path.isfile(emb_file):
        emb = torch.load(emb_file)
    else:
        # load glove embeddings and vocab
        glove_vocab, glove_emb = load_word_vectors(
            os.path.join(args.glove, 'glove.840B.300d'))
        print('==> GLOVE vocabulary size: %d ' % glove_vocab.size())
        emb = torch.Tensor(vocab.size(),
                           glove_emb.size(1)).normal_(-0.05, 0.05)
        # zero out the embeddings for padding and other special words if they are absent in vocab
        for idx, item in enumerate([
                Constants.PAD_WORD, Constants.UNK_WORD, Constants.BOS_WORD,
                Constants.EOS_WORD
        ]):
            emb[idx].zero_()
        for word in vocab.labelToIdx.keys():
            if glove_vocab.get_index(word):
                emb[vocab.get_index(word)] = glove_emb[glove_vocab.get_index(
                    word)]
        torch.save(emb, emb_file)
    # plug these into embedding matrix inside model
    if args.cuda:
        emb = emb.cuda()
    model.childsumtreelstm.emb.state_dict()['weight'].copy_(emb)

    # create trainer object for training and testing
    #trainer = Trainer(args, model, criterion, optimizer)

    best = -float('inf')

    return (args, best, train_dataset, dev_dataset, test_dataset, metrics,
            optimizer, criterion, model)