def load_embeddings(self, embedding_file):
    """Load the embedding array stored in ``embedding_file``.

    An array with ``ndim == 2`` is returned unchanged.  Any other array
    must have ``ndim == 3`` and is reduced by taking its unweighted mean
    over axis 1 (the "avgEmb" strategy).

    Raises:
        AssertionError: if the loaded array is neither 2- nor 3-dimensional.
    """
    loaded = load_npy(embedding_file)
    if loaded.ndim != 2:
        # avgEmb (uniform mean): only a 3-dimensional array is accepted here.
        assert loaded.ndim == 3
        return loaded.mean(axis=1)
    return loaded
def __init__(self, embedding_file, vocab_file, oov, hidden1, hidden2, hidden3, win, alpha, cembedding_file, infer_side_win):
    """Set up the model and the extra state used for avgExp inference.

    The first eight arguments are forwarded unchanged to the parent
    initializer.  On top of that this loads the embeddings through
    ``self.load_embeddings``, loads the array in ``cembedding_file`` and
    precomputes the padding vectors used for out-of-vocabulary positions.

    ``self.oov`` and ``self.win`` are read here but not assigned in this
    method; presumably the parent initializer sets them (verify there).

    Raises:
        AssertionError: if ``infer_side_win`` is not positive or the array
            loaded from ``cembedding_file`` is not 2-dimensional.
    """
    super().__init__(
        embedding_file, vocab_file, oov, hidden1, hidden2, hidden3, win, alpha
    )

    self.embedding_matrix = self.load_embeddings(embedding_file)
    self.cembedding_matrix = load_npy(cembedding_file)

    # Context size to each side for avgExp.
    self.infer_side_win = infer_side_win

    # Checks run only after the attributes above have been assigned.
    assert self.infer_side_win > 0
    assert self.cembedding_matrix.ndim == 2

    # Use the uniform average as the padding vector for OOV positions.
    oov_rows = self.embedding_matrix[self.oov]
    self.oov_vec = oov_rows.mean(axis=0)

    # The same vector object is repeated once per half-window slot.
    single_pad = [self.oov_vec]
    self.oov_vecs = single_pad * (self.win // 2)
# Register the remaining command-line options, parse them, then load the word
# index, the embedding matrices, the optional bias and the dataset.
# `parser`, `log`, `load_json`, `load_npy` and `Dataset` are defined outside
# this fragment; `args.input_dir` comes from an option registered earlier.
parser.add_argument(
    "-f",
    default="W_w",
    help="Model file (optional), meant for models per epoch.",
)
parser.add_argument(
    "-data_path",
    help="Filepath containing the SCWS dataset.",
    default="data/SCWS/ratings.txt",
)
parser.add_argument(
    "-win_size",
    default=3,
    type=int,
    help="Context window size (n words to the left and n to the right).",
)
parser.add_argument(
    "-n_most_freq",
    type=int,
    help="Only consider n most freq. words from vocabulary.",
)
args = parser.parse_args()

w_index_path = "{}/w_index.json".format(args.input_dir)

log.info("Loading model.")
w_index = load_json(w_index_path)
if args.n_most_freq:
    # Keep only the entries whose index is below n_most_freq + 1.
    w_index = {w: i for w, i in w_index.items() if i < args.n_most_freq + 1}
    # NOTE(review): the size is printed only when the filter is applied;
    # confirm this matches the intended placement of the print.
    print(len(w_index))

embs = load_npy("{}/{}.npy".format(args.input_dir, args.f))
c_embs = load_npy("{}/W_c.npy".format(args.input_dir))

# Work out which bias file belongs to the selected model file.
if args.f == "W_w":
    n = ""
else:
    # The last character of the model name selects the bias file (presumably
    # the epoch number).  It used to be parsed with eval(), which executes
    # command-line text as code; int() only accepts digits and raises
    # ValueError for anything else.
    n = int(args.f[-1])
    if not 0 <= n < 9:
        # Explicit check instead of assert, which is stripped under -O.
        raise ValueError(
            "Model file suffix must be a digit in [0, 9), got {}".format(n)
        )

# The bias is optional: only a missing file is tolerated.
try:
    bias = load_npy("{}/Wb{}.npy".format(args.input_dir, n))
except FileNotFoundError:
    bias = None

log.info("Loading dataset.")
d = Dataset()
d.create(args.data_path, w_index)
def load_embeddings(self, embedding_file):
    """Load the embedding array stored in ``embedding_file`` without reducing it.

    The array is returned exactly as loaded; this variant (the "avgExp"
    strategy, weighted mean) only checks that it is 3-dimensional.

    Raises:
        AssertionError: if the loaded array does not have ``ndim == 3``.
    """
    loaded = load_npy(embedding_file)
    # avgExp (weighted mean): no averaging is done in this method.
    rank = loaded.ndim
    assert rank == 3
    return loaded