def main(): parser = argparse.ArgumentParser(description="Generate datapoints from AST") parser.add_argument("--ast_fp", "-a", help="Filepath with the ASTs to be parsed") parser.add_argument( "--n_ctx", "-c", type=int, default=1000, help="Number of contexts for each dp" ) parser.add_argument( "--out_fp", "-o", default="/tmp/dps.txt", help="Filepath with the output dps" ) args = parser.parse_args() if os.path.exists(args.out_fp): os.remove(args.out_fp) logging.info("Writing dps to: {}".format(args.out_fp)) num_dps = 0 with open(args.ast_fp, "r") as f, open(args.out_fp, "w") as fout: for line in file_tqdm(f): dp = json.loads(line.strip()) aug_tokens, aug_leaf_ids = get_dps(dp, args.n_ctx) for (tokens, ext), leaf in zip(aug_tokens, aug_leaf_ids): if len(tokens) > 1: json.dump([tokens, ext], fp=fout) fout.write("\n") num_dps += 1 logging.info("Wrote {} datapoints to {}".format(num_dps, args.out_fp))
def main(): parser = argparse.ArgumentParser( description="Parse AST IDs for evaluation") parser.add_argument("--ast", help="Filepath with new ASTs") parser.add_argument("--out", help="Outfile for ids.txt") parser.add_argument("--tokenizer", help="Specify Tokenizer") args = parser.parse_args() if os.path.exists(args.out): os.remove(args.out) tokenizer = Tokenizer.from_file(args.tokenizer) num_dps = 0 with open(args.ast) as fin, open(args.out, "w") as fout, open( "output/debug_ids.txt", "w") as fout2: for i, line in enumerate(file_tqdm(fin)): dp = json.loads(line.strip()) asts = split(dp, 1000, tokenizer) fout2.write("{}: {}\n".format(i, len(asts))) for ast in asts: if len(ast) > 1: ids = {"leaf_ids": ast} json.dump(ids, fp=fout) fout.write("\n") num_dps += 1 logging.info("Wrote {} datapoints to {}".format(num_dps, args.out))
def main(): parser = argparse.ArgumentParser(description="Generate datapoints from AST") parser.add_argument("--ast_fp", "-a", help="Filepath with the ASTs to be parsed") parser.add_argument( "--out_fp", "-o", default="/tmp/dps.txt", help="Filepath for the output dps" ) parser.add_argument( "--n_ctx", "-c", type=int, default=1000, help="Number of contexts for each dp" ) args = parser.parse_args() if os.path.exists(args.out_fp): os.remove(args.out_fp) logging.info("Number of context: {}".format(args.n_ctx)) num_dps = 0 logging.info("Loading asts from: {}".format(args.ast_fp)) with open(args.ast_fp, "r") as f, open(args.out_fp, "w") as fout: for line in file_tqdm(f): dp = json.loads(line.strip()) asts = separate_dps(dp, args.n_ctx) for ast, extended in asts: if len(ast) > 1: json.dump([get_dfs(ast), extended], fp=fout) fout.write("\n") num_dps += 1 logging.info("Wrote {} datapoints to {}".format(num_dps, args.out_fp))
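# A hypothetical invocation of the generator above (the script name and paths
# are illustrative, not taken from the repo):
#
#   python generate_data.py --ast_fp output/new_trees.json \
#       --out_fp output/dps.txt --n_ctx 1000
#
# Each output line is a JSON list [dfs_node_values, extended], one datapoint
# per (possibly sliced) AST.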
def build_subword_prob_cli(args):
    if os.path.exists(args.output):
        logger.warning(f"{args.output} already exists!")

    logger.info("loading...")
    with open(args.word_freq) as fin:
        word_count_iter = (json.loads(line) for line in file_tqdm(fin))
        subword_counter = build_subword_counter(
            word_count_iter,
            min_count=args.subword_min_count,
            min_len=args.subword_min_len,
            max_len=args.subword_max_len,
            word_boundary=args.word_boundary,
            uniq_factor=args.subword_uniq_factor,
        )

    logger.info("processing...")
    if args.subword_prob_take_root:
        logger.warning("`args.subword_prob_take_root = True` ignored at this step.")
    subword_prob = build_subword_prob(
        subword_counter,
        normalize_prob=normalize_prob,
        min_prob=args.subword_prob_min_prob,
        # take_root=args.subword_prob_take_root,
    )

    logger.info("saving...")
    with open(args.output, "w") as fout:
        for subword, prob in tqdm(subword_prob.most_common()):
            print(json.dumps((subword, prob)), file=fout)
def external(file_path, n_vocab):
    outfile = "output/vocab.pkl"
    logging.info("Reading from: {}".format(file_path))
    vocab = Counter()
    with open(file_path, "r") as f:
        for line in file_tqdm(f):
            vocab.update(get_value(json.loads(line.strip()), "ast"))
    vocab_to_keep = [i[0] for i in vocab.most_common(n_vocab)]
    top_total = sum(i[1] for i in vocab.most_common(n_vocab))
    total = sum(vocab.values())

    logging.info("Total # of vocab: {}".format(len(vocab)))
    logging.info(
        "Using {} top vocab covers: {:.2f}% of the entire dataset".format(
            n_vocab, 100 * top_total / total))
    logging.info("Top 10 most common vocab:")
    for v, i in vocab.most_common(10):
        print(v, i)

    # add unk and pad tokens
    vocab_to_keep.append(UNK)
    vocab_to_keep.append(PAD)
    logging.info("Added {} and {}".format(UNK, PAD))

    # dump vocab to file
    with open(outfile, "wb") as fout:
        pickle.dump(vocab_to_keep, fout)
    logging.info("Wrote {} vocab to: {}".format(len(vocab_to_keep), outfile))
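# Toy illustration of the coverage figure logged above (values invented for
# the example): with counts {"a": 5, "b": 3, "c": 2} and n_vocab = 2, the
# top-2 total is 8 out of 10, so the log would read "covers: 80.00%".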
def external(file_path, suffix, n_ctx):
    outfile = "output/{}_ids.txt".format(suffix)
    if os.path.exists(outfile):
        os.remove(outfile)
    logging.info("Type of id to get: {}".format("all"))

    logging.info("Loading dps from: {}".format(file_path))
    with open(file_path, "r") as f, open(outfile, "w") as fout:
        for line in file_tqdm(f):
            dp = json.loads(line.strip())
            asts = separate_dps(dp, n_ctx)
            for ast, _ in asts:
                ids = {}
                if len(ast) > 1:
                    # this variant is hard-coded to "all", so every id type
                    # is collected (the original guards were always true)
                    ids.update(get_leaf_ids(ast))
                    ids.update(get_value_ids(ast))
                    ids.update(get_type_ids(ast))
                    json.dump(ids, fp=fout)
                    fout.write("\n")

    logging.info("Wrote to: {}".format(outfile))
def external(file_path, suffix):
    outfile = "output/{}_new_trees.json".format(suffix)
    if os.path.exists(outfile):
        os.remove(outfile)
    logging.info("Loading asts from: {}".format(file_path))
    with open(file_path, "r") as f, open(outfile, "w") as fout:
        for line in file_tqdm(f):
            dp = json.loads(line.strip())
            print(json.dumps(convert(dp)), file=fout)
    logging.info("Wrote dps to: {}".format(outfile))
def main(): parser = argparse.ArgumentParser( description="Generate datapoints from source code") parser.add_argument("--files_fp", "-f", help="Filepath with the filenames to be parsed") parser.add_argument("--out_fp", "-o", default="/tmp/dps.txt", help="Filepath with the output dps") parser.add_argument("--base_dir", "-b", help="Base dir to append for the fps") parser.add_argument("--n_ctx", "-c", type=int, default=1000, help="Number of contexts for each dp") parser.add_argument( "id_type", choices=["leaf", "value", "token", "all"], default="", help="Which ids to generate. Default = get the tokens", ) args = parser.parse_args() if os.path.exists(args.out_fp): os.remove(args.out_fp) logging.info("Number of context: {}".format(args.n_ctx)) num_dps = 0 logging.info("Loading files from: {}".format(args.base_dir)) with open(args.files_fp, "r", errors='ignore') as f, open(args.out_fp, "w") as fout: for line in file_tqdm(f): fp = os.path.join(args.base_dir, line.strip()) try: aug_tokens, aug_types = my_tokenize( open(fp).read(), args.n_ctx) for (tokens, ext), (types_, _) in zip(aug_tokens, aug_types): if len(tokens) > 1: if args.id_type == "leaf": json.dump(get_leaf_ids(types_), fp=fout) elif args.id_type == "value": json.dump(get_value_ids(types_), fp=fout) elif args.id_type == "all": ids = get_leaf_ids(types_) ids.update(get_value_ids(types_)) json.dump(ids, fp=fout) else: json.dump([tokens, ext], fp=fout) fout.write("\n") num_dps += 1 except: continue logging.info("Wrote {} datapoints to {}".format(num_dps, args.out_fp))
def external(fp, suffix, tokenizer):
    tokenizer = Tokenizer.from_file(tokenizer)
    outfile = "output/{}_ids.txt".format(suffix)
    num_dps = 0
    with open(fp) as fin, open(outfile, "w") as fout:
        for i, line in enumerate(file_tqdm(fin)):
            dp = json.loads(line.strip())
            asts = split(dp, 1000, tokenizer)
            for ast in asts:
                if len(ast) > 1:
                    ids = {"leaf_ids": ast}
                    json.dump(ids, fp=fout)
                    fout.write("\n")
                    num_dps += 1
    logging.info("Wrote {} datapoints to {}".format(num_dps, outfile))
def main(): parser = argparse.ArgumentParser(description="Create vocab for py150 dataset") parser.add_argument("--n_vocab", "-n", type=int, default=100000) parser.add_argument("--input_fp", "-i") parser.add_argument("--out_fp", "-o", default="/tmp/vocab.pkl") parser.add_argument( "--input_type", "-t", choices=["ast", "leaf", "source_code"], help="Where to get the input from (all AST nodes, leaf nodes, or source code", ) args = parser.parse_args() logging.info("Reading from: {}".format(args.input_fp)) logging.info("Input type: {}".format(args.input_type)) vocab = Counter() # with open('../data/test_trees.json', "r") as f: # for line in file_tqdm(f): # vocab.update(get_value(json.loads(line.strip()), args.input_type)) with open(args.input_fp, "r") as f: for line in file_tqdm(f): vocab.update(get_value(json.loads(line.strip()), args.input_type)) vocab_to_keep = [i[0] for i in vocab.most_common(args.n_vocab)] top_total = sum(i[1] for i in vocab.most_common(args.n_vocab)) total = sum(vocab.values()) logging.info("Total # of vocab: {}".format(len(vocab))) logging.info( "Using {} top vocab covers: {:.2f}% of the entire dataset".format( args.n_vocab, 100 * top_total / total ) ) logging.info("Top 10 most common vocab:") for v, i in vocab.most_common(10): print(v, i) # add unk and pad tokens vocab_to_keep.append(UNK) vocab_to_keep.append(PAD) logging.info("Added {} and {}".format(UNK, PAD)) # dump vocab to file with open(args.out_fp, "w") as fout: json.dump(vocab_to_keep, fout) logging.info("Wrote {} vocab to: {}".format(len(vocab_to_keep), args.out_fp))
def main(): parser = argparse.ArgumentParser( description="Generate ids (leaf, values, types) from AST") parser.add_argument("--ast_fp", "-a", help="Filepath with the new ASTs to be parsed") parser.add_argument("--out_fp", "-o", default="/tmp/ids.txt", help="Filepath for the output ids") parser.add_argument("--n_ctx", "-c", type=int, default=1000, help="Number of contexts for each dp") parser.add_argument( "id_type", choices=["leaf", "value", "type", "all"], default="leaf", help="Which ids to generate. Default = leaf", ) args = parser.parse_args() if os.path.exists(args.out_fp): os.remove(args.out_fp) logging.info("Type of id to get: {}".format(args.id_type)) logging.info("Loading dps from: {}".format(args.ast_fp)) with open(args.ast_fp, "r") as f, open(args.out_fp, "w") as fout: for line in file_tqdm(f): dp = json.loads(line.strip()) asts = separate_dps(dp, args.n_ctx) for ast, _ in asts: ids = {} if len(ast) > 1: if args.id_type in {"leaf", "all"}: ids.update(get_leaf_ids(ast)) if args.id_type in {"value", "all"}: ids.update(get_value_ids(ast)) if args.id_type in {"type", "all"}: ids.update(get_type_ids(ast)) json.dump(ids, fp=fout) fout.write("\n") logging.info("Wrote to: {}".format(args.out_fp))
def main(): parser = argparse.ArgumentParser( description="Create vocab for code2seq model for py150 dataset") parser.add_argument("--n_vocab", "-n", type=int, default=100000) parser.add_argument("--input_fp", "-i") parser.add_argument("--out_fp", "-o", default="/tmp/vocab.pkl") parser.add_argument( "--vocab_type", "-v", choices=["token", "subtoken", "output"], help="What type of vocab to get", ) args = parser.parse_args() logging.info("Reading from: {}".format(args.input_fp)) logging.info("Vocab type: {}".format(args.vocab_type)) vocab = Counter() with open(args.input_fp, "r") as f: for line in file_tqdm(f): vocab.update(get_value(json.loads(line.strip()), args.vocab_type)) vocab_to_keep = [i[0] for i in vocab.most_common(args.n_vocab)] top_total = sum(i[1] for i in vocab.most_common(args.n_vocab)) total = sum(vocab.values()) logging.info("Total # of vocab: {}".format(len(vocab))) logging.info( "Using {} top vocab covers: {:.2f}% of the entire dataset".format( args.n_vocab, 100 * top_total / total)) logging.info("Top 10 most common vocab:") for v, i in vocab.most_common(10): print(v, i) # add unk and pad tokens vocab_to_keep.append(UNK) vocab_to_keep.append(PAD) vocab_to_keep.append(PLACEHOLDER) logging.info("Added {} and {} and {}".format(UNK, PAD, PLACEHOLDER)) # dump vocab to file with open(args.out_fp, "wb") as fout: pickle.dump(vocab_to_keep, fout) logging.info("Wrote {} vocab to: {}".format(len(vocab_to_keep), args.out_fp))
def __init__(self, base_dir, fp, ids_fp, max_vocab=100000, mode="train"):
    super().__init__()
    if mode not in {"train", "test"}:
        raise Exception("Mode must be either train or test")
    self.mode = mode
    self.fp = fp
    self.max_vocab = max_vocab

    # get all the relevant filepaths
    self.filepaths = {
        "vocab": os.path.join(base_dir, "vocab.pkl"),
        "metrics": os.path.join(base_dir, "{}_metrics.txt".format(mode)),
        "conv": os.path.join(base_dir, "{}_converted.txt".format(mode)),
    }
    self._add_extra_filepaths(base_dir)
    logging.info("Writing metrics to: {}".format(self.filepaths["metrics"]))

    # filter dataset
    filtered_fp = self._filter_dataset()

    # set up vocab
    self.vocab = self._create_vocab()

    # convert
    if not os.path.exists(self.filepaths["conv"]):
        with open(filtered_fp, "r") as fin, open(
            self.filepaths["conv"], "w"
        ) as fout:
            for line in utils.file_tqdm(fin):
                line = json.loads(line.strip())
                print(json.dumps(self.vocab.convert(line)), file=fout)
        logging.info(
            "Converted dataset to idx and saved to: {}".format(
                self.filepaths["conv"]))

    # return dataset
    self.dataset = self._create_dataset(self.filepaths["conv"], ids_fp)
    logging.info("Loaded dataset from {}".format(self.filepaths["conv"]))
def external(file_path, suffix, context_size, overlap):
    outfile = "output/{}_dps.txt".format(suffix)
    if os.path.exists(outfile):
        os.remove(outfile)
    logging.info("Number of contexts: {}".format(context_size))

    num_dps = 0
    logging.info("Loading asts from: {}".format(file_path))
    with open(file_path, "r") as f, open(outfile, "w") as fout:
        for line in file_tqdm(f):
            dp = json.loads(line.strip())
            asts = rq6_separate_dps(dp, context_size, overlap)
            for ast, extended in asts:
                if len(ast) > 1:
                    json.dump([get_dfs(ast), extended], fp=fout)
                    fout.write("\n")
                    num_dps += 1

    logging.info("Wrote {} datapoints to {}".format(num_dps, outfile))
def preprocess(fp, suffix, tokenizer):
    tokenizer = Tokenizer.from_file(tokenizer)
    dps_outfile = "output/{}_dps.txt".format(suffix)
    ids_outfile = "output/{}_ids.txt".format(suffix)
    num = 0
    with open(fp) as fin, open(dps_outfile, "w") as fout_dps, open(
        ids_outfile, "w"
    ) as fout_ids:
        for line in file_tqdm(fin):
            dp = json.loads(line.strip())
            asts, ids = split(dp, 1000, tokenizer)
            # use a dedicated index so the datapoint lines up with `ids`
            # (the original reused `i` from the outer loop)
            for j, (ast, extended) in enumerate(asts):
                if len(ast) > 1:
                    json.dump([ast, extended], fp=fout_dps)
                    json.dump(ids[j], fp=fout_ids)
                    fout_dps.write("\n")
                    fout_ids.write("\n")
                    num += 1
    logging.info("Wrote {} datapoints to {} and {}".format(
        num, ids_outfile, dps_outfile))
def main(): parser = argparse.ArgumentParser( description="Generate datapoints from AST") parser.add_argument("--ast_fp", "-a", help="Filepath with the ASTs to be parsed") parser.add_argument("--out_fp", "-o", default="/tmp/dps.txt", help="Filepath with the output dps") parser.add_argument("--n_ctx", "-c", type=int, default=1000, help="Number of contexts for each dp") parser.add_argument( "--max_path_len", "-p", type=int, default=13, help="Max length of rootpath route", ) args = parser.parse_args() if os.path.exists(args.out_fp): os.remove(args.out_fp) logging.info("Writing dps to: {}".format(args.out_fp)) num_dps = 0 with open(args.ast_fp, "r") as f, open(args.out_fp, "w") as fout: for line in file_tqdm(f): dp = json.loads(line.strip()) for dp in get_dps(dp, args.n_ctx, args.max_path_len): if len(dp[0]) > 1: json.dump(dp, fout) fout.write("\n") num_dps += 1 logging.info("Wrote {} datapoints to {}".format(num_dps, args.out_fp))
def build_subword_vocab_cli(args):
    if os.path.exists(args.output):
        logger.warning(f"{args.output} already exists!")

    logger.info("loading...")
    with open(args.word_freq) as fin:
        word_count_iter = (json.loads(line) for line in file_tqdm(fin))
        subword_counter = build_subword_counter(
            word_count_iter,
            max_size=args.subword_vocab_max_size,
            min_count=args.subword_min_count,
            min_len=args.subword_min_len,
            max_len=args.subword_max_len,
            word_boundary=args.word_boundary,
            uniq_factor=args.subword_uniq_factor,
        )

    logger.info("processing...")
    subword_vocab = subword_counter

    logger.info("saving...")
    with open(args.output, "w") as fout:
        for subword, count in tqdm(subword_vocab.most_common()):
            print(json.dumps((subword, count)), file=fout)
def load_embedding(filename: str, show_progress=False) -> Tuple[List[str], np.ndarray]:
    """
    :param filename: a .txt file or a .pkl/.pickle file
    :return: tuple (words, embeddings)
    """
    import os

    if show_progress:
        from utils import file_tqdm
    else:
        from utils import dummy_tqdm as file_tqdm

    _, ext = os.path.splitext(filename)
    if ext in (".txt", ".w2v"):
        vocab, emb = [], []
        with open(filename, "r") as fin:
            if ext == ".w2v":
                next(fin)  # skip the word2vec header line
            for line in file_tqdm(fin):
                ss = line.split()
                try:
                    emb.append([float(x) for x in ss[1:]])
                    vocab.append(ss[0])
                except ValueError:
                    print(f"Error loading the line: {line[:30]} ...")
        emb = np.array(emb)
    elif ext in (".pickle", ".pkl"):
        import pickle

        try:
            with open(filename, "rb") as bfin:
                vocab, emb = pickle.load(bfin)
        except UnicodeDecodeError:
            with open(filename, "rb") as bfin:
                vocab, emb = pickle.load(bfin, encoding="bytes")
    else:
        raise ValueError(f"Unsupported target vector file extension: {filename}")
    return vocab, emb
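# Minimal usage sketch for load_embedding; the path below is a placeholder,
# not a file shipped with the repo. For ".w2v" input, the first line is
# assumed to be the "<vocab_size> <dim>" header that next(fin) skips.
if __name__ == "__main__":
    vocab, emb = load_embedding("/tmp/vectors.w2v", show_progress=True)
    assert emb.shape[0] == len(vocab)
    print(f"loaded {len(vocab)} words of dimension {emb.shape[1]}")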
def main(): parser = argparse.ArgumentParser( description="Generate datapoints from AST") parser.add_argument("--input_fp", "-i", help="Filepath with the ASTs to be parsed") parser.add_argument( "--out_fp", "-o", default="/tmp/new_trees.json", help="Filepath with the output dps", ) args = parser.parse_args() if os.path.exists(args.out_fp): os.remove(args.out_fp) logging.info("Loading asts from: {}".format(args.input_fp)) with open(args.input_fp, "r") as f, open(args.out_fp, "w") as fout: for line in file_tqdm(f): dp = json.loads(line.strip()) print(json.dumps(convert(dp)), file=fout) logging.info("Wrote dps to: {}".format(args.out_fp))
add_subword_vocab_args(parser)
add_logging_args(parser)
args = parser.parse_args()
set_logging_config(args)
dump_args(args)

logger.info(f"building subword prob from `{args.prob_word_freq}`...")
if args.prob_word_freq.lower().startswith("unigram_freq"):
    word_freq_path = import_module("datasets.unigram_freq") \
        .prepare_unigram_freq_paths().word_freq_path
else:
    raise ValueError(
        f"args.prob_word_freq=`{args.prob_word_freq}` not supported.")

with open(word_freq_path) as fin:
    word_count_iter = (json.loads(line) for line in file_tqdm(fin))
    subword_counter = build_subword_counter(
        word_count_iter,
        min_count=args.subword_min_count,
        min_len=args.subword_min_len,
        max_len=args.subword_max_len,
        word_boundary=args.word_boundary,
        uniq_factor=args.subword_uniq_factor,
    )
subword_prob = build_subword_prob(
    subword_counter,
    normalize_prob=normalize_prob,
    min_prob=args.subword_prob_min_prob,
    take_root=args.subword_prob_take_root,
)
logger.info(f"subword prob size: {len(subword_prob)}")
def load_from_file(self, file_path):
    with open(file_path) as fin:
        for line in utils.file_tqdm(fin):
            self.inputs.append([int(s) for s in line.split()])
def prepare_glove_paths(
    dir_path=dir_path,
    zip_path=zip_path,
    raw_emb_path=raw_emb_path,
    txt_emb_path=txt_emb_path,
    word_freq_path=word_freq_path,
    w2v_emb_path=w2v_emb_path,
    raw_count_path=raw_count_path,
):
    if not os.path.exists(zip_path):
        logger.info("downloading zip file...")
        url = "http://nlp.stanford.edu/data/glove.840B.300d.zip"
        sp.run(f"wget -O {zip_path} {url}".split())

    if not os.path.exists(raw_emb_path):
        logger.info("unzipping...")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(dir_path)

    if not os.path.exists(txt_emb_path):
        logger.info("generating txt emb file...")
        with open(raw_emb_path, "r") as fin, open(txt_emb_path, "w") as fout:
            for line in file_tqdm(fin):
                ss = line.split()
                if len(ss) != emb_dim + 1:
                    logger.critical(
                        f'line "{line[:30]}"... might include word with space, skipped'
                    )
                    continue
                w = ss[0]
                # copied from `datasets/google/converter.py`
                aw = unicodedata.normalize("NFKD", w).encode("ASCII", "ignore")
                if 20 > len(aw) > 1 and not any(
                        c in w for c in " _./") and aw.islower():
                    fout.write(line)

    if not os.path.exists(w2v_emb_path):
        logger.info("generating w2v emb file...")
        # count the vocab here so this branch also works when the txt emb file
        # already exists (the original only set vocab_len while generating it)
        with open(txt_emb_path) as fin:
            vocab_len = sum(1 for _ in fin)
        with open(txt_emb_path) as fin, open(w2v_emb_path, "w") as fout:
            print(vocab_len, emb_dim, file=fout)
            for line in file_tqdm(fin):
                fout.write(line)

    if not os.path.exists(word_freq_path):
        logger.info("generating word freq jsonl file...")
        with open(txt_emb_path) as fin, open(word_freq_path, "w") as fout:
            for line in fin:
                print(json.dumps((line.split()[0], 1)), file=fout)

    if not os.path.exists(raw_count_path):
        logger.info("generating word freq txt file...")
        with open(txt_emb_path) as fin, open(raw_count_path, "w") as fout:
            for line in fin:
                print(line.split()[0], 1, file=fout, sep="\t")

    return dotdict(
        dir_path=dir_path,
        raw_emb_path=raw_emb_path,
        txt_emb_path=txt_emb_path,
        w2v_emb_path=w2v_emb_path,
        word_freq_path=word_freq_path,
        raw_count_path=raw_count_path,
    )
def main(args):
    set_logging_config(args)
    save_path = args.model_path.format(
        timestamp=datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    save_dir, _ = os.path.split(save_path)
    try:
        os.makedirs(save_dir)
    except FileExistsError:
        logger.warning(
            "Things will get overwritten for directory {}".format(save_dir))
    dump_args(args, logger, os.path.join(save_dir, "args.json"))

    logger.info(f"loading target vectors from `{args.target_vectors}`...")
    target_words, target_embs = \
        load_embedding(args.target_vectors, show_progress=True)
    logger.info(f"embeddings loaded with {len(target_words)} words")

    logger.info(f"loading subword vocab from `{args.subword_vocab}`...")
    with open(args.subword_vocab) as fin:
        subword_vocab = dict(json.loads(line) for line in file_tqdm(fin))
    logger.info(f"subword vocab size: {len(subword_vocab)}")

    if args.subword_prob:
        logger.info(f"loading subword prob from `{args.subword_prob}`...")
        with open(args.subword_prob) as fin:
            subword_prob = dict(json.loads(line) for line in file_tqdm(fin))
        subword_prob = subword_prob_post_process(
            subword_prob,
            min_prob=args.subword_prob_min_prob,
            # take_root=args.subword_prob_take_root,
        )
    else:
        subword_prob = None

    np.random.seed(args.random_seed)

    def MSE(pred, target):
        return sum((pred - target) ** 2) / 2

    def MSE_backward(pred, target):
        return pred - target

    model = PBoS(
        embedding_dim=len(target_embs[0]),
        subword_vocab=subword_vocab,
        subword_prob=subword_prob,
        weight_threshold=args.subword_weight_threshold,
        eps=args.subword_prob_eps,
        take_root=args.subword_prob_take_root,
        normalize_semb=args.normalize_semb,
    )
    start_time = time()
    for i_epoch in range(args.epochs):
        h = []
        h_epoch = []
        lr = args.lr / (1 + i_epoch) ** 0.5 if args.lr_decay else args.lr
        logger.info("epoch {:>2} / {} | lr {:.5f}".format(
            1 + i_epoch, args.epochs, lr))
        epoch_start_time = time()
        for i_inst, wi in enumerate(
            np.random.choice(len(target_words), len(target_words), replace=False),
            start=1,
        ):
            target_emb = target_embs[wi]
            word = target_words[wi]
            model_word = bound_word(word) if args.word_boundary else word
            model_emb = model.embed(model_word)
            grad = MSE_backward(model_emb, target_emb)
            if i_inst % 20 == 0:
                # average over dimensions for easy reading
                loss = MSE(model_emb, target_emb) / len(target_emb)
                h.append(loss)
            if i_inst % 10000 == 0:
                width = len(f"{len(target_words)}")
                fmt = "processed {:%d}/{:%d} | loss {:.5f}" % (width, width)
                logger.info(fmt.format(i_inst, len(target_words), np.average(h)))
                h_epoch.extend(h)
                h = []
            d = -lr * grad
            model.step(model_word, d)
        now_time = time()
        logger.info(
            "epoch {i_epoch:>2} / {n_epoch} | loss {loss:.5f} | time {epoch_time:.2f}s / {training_time:.2f}s"
            .format(
                i_epoch=1 + i_epoch,
                n_epoch=args.epochs,
                loss=np.average(h_epoch),
                epoch_time=now_time - epoch_start_time,
                training_time=now_time - start_time,
            ))

    logger.info("saving model...")
    model.dump(save_path)
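# Quick numerical check (not part of the training script) that the
# MSE/MSE_backward pair above is consistent: with loss = ||pred - target||^2 / 2,
# the gradient w.r.t. pred is exactly pred - target.
import numpy as np

pred, target = np.array([1.0, 2.0]), np.array([0.0, 0.0])
eps = 1e-6

def loss(p):
    return sum((p - target) ** 2) / 2

num_grad = (loss(pred + np.array([eps, 0.0])) - loss(pred)) / eps
assert abs(num_grad - (pred - target)[0]) < 1e-4  # matches MSE_backward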
import json

import utils

# This escapes special characters and removes whitespace from node
# types/values; the output is the base for the tokenizer.
# TODO: Add ext info and store as json
with open("output/dps.txt", "r") as fin, open("output/new_ast_raw.txt", "w") as fout:
    for line in utils.file_tqdm(fin):
        json_line = json.loads(line)
        nodes = []
        for node in json_line[0]:
            nodes.append(
                node.strip().replace(" ", "<spc>").encode("unicode_escape").decode()
            )
        fout.write(" ".join(nodes) + "\n")
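# Worked example of the escaping above (illustrative node values):
#   "foo bar" -> "foo<spc>bar"   (inner spaces become <spc>)
#   "a\nb"    -> "a\\nb"         (unicode_escape renders the newline literally)
# so every node becomes a single whitespace-free token on the output line.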