예제 #1
0
def get_data_as_vanilla(num_samples, prefix='train'):
    """Load up to ``num_samples`` shuffled sentence pairs as in-memory data.

    Reads the pair list from ``<prefix>_data_pairs.pkl`` via
    ``gutils.load_data_pkl``, shuffles it with a fixed seed (42), looks up the
    token rows of both sides of each selected pair in the PyTables file
    ``cnt.SENT_TOKENS_FILE`` and converts them with ``gutils.get_wv_siamese``
    using the module-level ``wv_model``.

    Args:
        num_samples: Maximum number of pairs to return; capped at the number
            of available pairs.
        prefix: Prefix of the pickle file that holds the pairs.

    Returns:
        A tuple ``([sent_data_1, sent_data_2], labels)`` where ``labels`` is
        the matching slice of the NumPy label array.

    Raises:
        ValueError: If the pickle file contains no pairs.
    """
    pairs_file = prefix + "_data_pairs.pkl"
    # ``with`` guarantees the HDF5 file is closed. The previous try/finally
    # raised UnboundLocalError (hiding the real error) when open_file failed.
    with tables.open_file(os.path.join(cnt.DATA_FOLDER, cnt.SENT_TOKENS_FILE),
                          mode='r') as sent_tokens_file:
        sent_tokens = sent_tokens_file.root.data

        # Seeds the *global* ``random`` module so the shuffle is repeatable.
        random.seed(42)

        data_pairs = gutils.load_data_pkl(pairs_file)
        if len(data_pairs) == 0:
            # zip(*[]) below would otherwise fail with an opaque unpacking error.
            raise ValueError("no data pairs found in " + pairs_file)
        random.shuffle(data_pairs)

        items1, items2, labels = zip(*data_pairs)
        items1, items2, labels = np.array(items1), np.array(items2), np.array(labels)

        n = min(num_samples, len(data_pairs))

        tokens1 = [sent_tokens[i] for i in items1[:n]]
        tokens2 = [sent_tokens[i] for i in items2[:n]]

        sent_data_1 = gutils.get_wv_siamese(wv_model, tokens1)
        sent_data_2 = gutils.get_wv_siamese(wv_model, tokens2)

        return [sent_data_1, sent_data_2], labels[:n]
예제 #2
0
def get_data_as_generator(num_samples, prefix='train'):
    """Yield ``([sent_data_1, sent_data_2], labels)`` batches indefinitely.

    Loads the sentence pairs from ``<prefix>_data_pairs.pkl`` via
    ``gutils.load_data_pkl`` and shuffles them once with a fixed seed (42).
    It then cycles over them in batches of ``cnt.SIAMESE_BATCH_SIZE``; the
    last batch of each pass may be shorter. Token rows are looked up in the
    PyTables file ``cnt.SENT_TOKENS_FILE`` and converted with
    ``gutils.get_wv_siamese`` using the module-level ``wv_model``.

    The tokens file stays open for the generator's lifetime. It is closed
    when the generator is closed (``close()`` or garbage collection).

    Args:
        num_samples: Currently unused. NOTE(review): unlike
            ``get_data_as_vanilla`` this does not cap the number of pairs;
            confirm whether that is intended.
        prefix: Prefix of the pickle file that holds the pairs.

    Raises:
        ValueError: If the pickle file contains no pairs.
    """
    pairs_file = prefix + "_data_pairs.pkl"
    # ``with`` replaces try/finally, whose ``finally`` raised UnboundLocalError
    # (hiding the real error) when open_file itself failed.
    with tables.open_file(os.path.join(cnt.DATA_FOLDER, cnt.SENT_TOKENS_FILE),
                          mode='r') as sent_tokens_file:
        sent_tokens = sent_tokens_file.root.data

        # Seeds the *global* ``random`` module so every run shuffles identically.
        random.seed(42)

        data_pairs = gutils.load_data_pkl(pairs_file)
        if len(data_pairs) == 0:
            # Fail clearly instead of on zip(*[]) unpacking / modulo by zero.
            raise ValueError("no data pairs found in " + pairs_file)
        random.shuffle(data_pairs)

        items1, items2, labels = zip(*data_pairs)
        items1, items2, labels = np.array(items1), np.array(items2), np.array(labels)

        n = len(data_pairs)
        batch_size = cnt.SIAMESE_BATCH_SIZE
        num_batches = int(math.ceil(float(n) / batch_size))

        m = 0  # index of the next batch within the current pass over the data
        while True:
            start, end = m * batch_size, min((m + 1) * batch_size, n)

            tokens1 = [sent_tokens[i] for i in items1[start:end]]
            tokens2 = [sent_tokens[i] for i in items2[start:end]]

            sent_data_1 = gutils.get_wv_siamese(wv_model, tokens1)
            sent_data_2 = gutils.get_wv_siamese(wv_model, tokens2)

            # Wrap around to the first batch after the last one.
            m = (m + 1) % num_batches

            yield [sent_data_1, sent_data_2], labels[start:end]
예제 #3
0
    def get_prediction(self, sentence1, sentence2):
        """Score a pair of raw sentences with the siamese model.

        Each sentence has its non-ASCII characters dropped, is tokenized with
        ``gutils.get_tokens`` and padded with ``gutils.padd_fn``. Both are then
        converted with ``gutils.get_wv_siamese`` and fed together to
        ``self.model.predict``. Returns element ``[0][0]`` of the prediction.
        """
        padded = []
        for raw in (sentence1, sentence2):
            ascii_text = raw.encode("ascii", errors="ignore").decode()
            padded.append(gutils.padd_fn(gutils.get_tokens(ascii_text)))

        model_inputs = [gutils.get_wv_siamese(self.wv_model, [tokens])
                        for tokens in padded]
        scores = self.model.predict(model_inputs)
        return scores[0][0]
예제 #4
0
    def insert_embeddings_pytables(self):
        """Compute siamese embeddings for all tokenized sentences and store them.

        Loads the model (``get_model``, ``model.init_model``, ``model.load``).
        Reads ``cnt.SENT_TOKENS_FILE`` in batches of
        ``cnt.PYTABLES_INSERT_BATCH_SIZE`` rows and converts each batch with
        ``gutils.get_wv_siamese``. The result of ``self.model.get_embeddings``
        is appended to a float32 extendable array ``/data`` with
        ``cnt.SIAMESE_EMBEDDING_SIZE`` columns. That array lives in
        ``cnt.SIAMESE_EMBEDDINGS_FILE``, which is overwritten (``mode='w'``).
        """
        # Load the model before touching any file. A failure here previously
        # reached the ``finally`` with unbound file handles and was masked by
        # an UnboundLocalError.
        self.get_model()
        self.model.init_model()
        self.model.load()

        # Open the read-only input first so a missing tokens file no longer
        # truncates an existing embeddings file. The nested ``with`` blocks
        # close both files even if one close() raises.
        with tables.open_file(os.path.join(
                cnt.DATA_FOLDER, cnt.SENT_TOKENS_FILE), mode='r') as sent_tokens_file:
            sent_tokens = sent_tokens_file.root.data

            with tables.open_file(os.path.join(
                    cnt.DATA_FOLDER, cnt.SIAMESE_EMBEDDINGS_FILE), mode='w') as embeds_file:
                # First dimension 0 = extendable; rows are added with append().
                embeds_arr = embeds_file.create_earray(
                    embeds_file.root, 'data', tables.Float32Atom(),
                    (0, cnt.SIAMESE_EMBEDDING_SIZE))

                n, batch_size = len(sent_tokens), cnt.PYTABLES_INSERT_BATCH_SIZE
                for start in range(0, n, batch_size):
                    end = min(start + batch_size, n)
                    tokens_arr_input = gutils.get_wv_siamese(
                        self.wv_model, sent_tokens[start:end, :])
                    embeds_arr.append(self.model.get_embeddings(tokens_arr_input))
예제 #5
0
 def get_representations(self, sentences):
     """Embed a list of raw sentences with the siamese model.

     Each sentence has its non-ASCII characters dropped, is tokenized with
     ``gutils.get_tokens`` and padded with ``gutils.padd_fn``. The padded
     batch is converted with ``gutils.get_wv_siamese`` and passed to
     ``self.model.get_embeddings``, whose result is returned.
     """
     padded_sentences = []
     for raw in sentences:
         ascii_text = raw.encode("ascii", errors="ignore").decode()
         padded_sentences.append(gutils.padd_fn(gutils.get_tokens(ascii_text)))
     word_vectors = gutils.get_wv_siamese(self.wv_model, padded_sentences)
     return self.model.get_embeddings(word_vectors)
예제 #6
0
 def get_representation(self, sentence):
     """Embed a single raw sentence and return the first embedding row.

     The sentence has its non-ASCII characters dropped, is tokenized with
     ``gutils.get_tokens``, padded with ``gutils.padd_fn`` and wrapped in a
     one-element batch for ``gutils.get_wv_siamese`` and
     ``self.model.get_embeddings``.
     """
     ascii_text = sentence.encode("ascii", errors="ignore").decode()
     padded = gutils.padd_fn(gutils.get_tokens(ascii_text))
     word_vectors = gutils.get_wv_siamese(self.wv_model, [padded])
     embeddings = self.model.get_embeddings(word_vectors)
     return embeddings[0]