예제 #1
0
파일: data.py 프로젝트: zeta1999/RE2RNN
    def __init__(self,
                 query,
                 query_inverse,
                 lengths,
                 intent_gold,
                 intent_re,
                 re_out,
                 shots=None):
        """Store the query data and decide which intent labels to expose.

        ``query``, ``query_inverse``, ``lengths`` and ``re_out`` are stored
        unchanged as ``self.dataset``, ``self.dataset_inverse``,
        ``self.lengths`` and ``self.re_out``.

        ``self.intent`` depends on ``shots``:

        * ``None`` or greater than ``len(query)``: ``intent_gold`` is used
          as-is.
        * ``0``: ``intent_re`` is used as-is.
        * otherwise: a copy of ``intent_re`` in which the positions returned
          by ``evan_select_from_total_number(len(query), shots)`` are
          overwritten with the corresponding ``intent_gold`` entries.

        Raises:
            AssertionError: if ``intent_gold`` or ``intent_re`` does not have
                the same length as ``query``.
        """
        assert len(query) == len(intent_gold)
        assert len(query) == len(intent_re)
        self.dataset = query
        self.dataset_inverse = query_inverse
        self.lengths = lengths
        self.re_out = re_out

        # `None` is a singleton, so test identity rather than equality; this
        # also avoids surprising `__eq__` overloads on `shots`.
        if shots is None or shots > len(query):
            self.intent = intent_gold
        elif shots == 0:
            self.intent = intent_re
        else:
            # NOTE(review): presumably `idxs` holds `shots` positions into the
            # dataset -- confirm against evan_select_from_total_number.
            idxs = evan_select_from_total_number(len(query), shots)
            new_intent = np.array(intent_re)
            # Flatten the selected labels before writing them over the copy.
            new_intent[idxs] = np.array(intent_gold)[idxs].reshape(-1)
            self.intent = list(new_intent)
예제 #2
0
파일: data.py 프로젝트: zeta1999/RE2RNN
 def __init__(self, query, lengths, intent, shots=None):
     """Keep the whole dataset or only the subset picked for ``shots``.

     When ``shots`` is falsy (``None`` or ``0``) or larger than
     ``len(query)``, the three inputs are stored unchanged. Otherwise the
     positions returned by ``evan_select_from_total_number(len(query),
     shots)`` are used to index ``query``, ``intent`` and ``lengths``, and
     each result is stored as a plain list.

     Raises:
         AssertionError: if ``query`` and ``intent`` differ in length.
     """
     assert len(query) == len(intent)

     keep_everything = (not shots) or (shots > len(query))
     if keep_everything:
         self.dataset, self.intent, self.lengths = query, intent, lengths
         return

     picked = evan_select_from_total_number(len(query), shots)

     def subset(seq):
         # Index through a NumPy array, then hand back a plain list.
         return list(np.array(seq)[picked])

     self.dataset = subset(query)
     self.intent = subset(intent)
     self.lengths = subset(lengths)
예제 #3
0
파일: data.py 프로젝트: zeta1999/RE2RNN
 def __init__(self, query, query_inverse, lengths, intent, shots=None):
     """Keep all samples, or only those picked for ``shots``.

     When ``shots`` is falsy (``None`` or ``0``) or larger than
     ``len(query)``, ``query``, ``query_inverse``, ``intent`` and
     ``lengths`` are stored unchanged. Otherwise each of them is indexed
     with the positions returned by
     ``evan_select_from_total_number(len(query), shots)`` and stored as a
     plain list.

     Raises:
         AssertionError: if ``query`` and ``intent`` differ in length.
     """
     assert len(query) == len(intent)

     fields = (
         ("dataset", query),
         ("dataset_inverse", query_inverse),
         ("intent", intent),
         ("lengths", lengths),
     )

     if (not shots) or (shots > len(query)):
         # No sub-sampling requested (or more shots than samples).
         for attr, values in fields:
             setattr(self, attr, values)
         return

     chosen = evan_select_from_total_number(len(query), shots)
     for attr, values in fields:
         setattr(self, attr, list(np.array(values)[chosen]))
예제 #4
0
파일: data.py 프로젝트: zeta1999/RE2RNN
 def __init__(self, query, lengths, intent, re_out, shots=None):
     """Keep the whole dataset or only the subset picked for ``shots``.

     When ``shots`` is ``None`` or larger than ``len(query)``, ``query``,
     ``intent``, ``lengths`` and ``re_out`` are stored unchanged. Any other
     value (including ``0``) is passed to
     ``evan_select_from_total_number(len(query), shots)``; the returned
     positions are used to index each input, and each result is stored as a
     plain list.

     Raises:
         AssertionError: if ``query`` and ``intent`` differ in length.
     """
     assert len(query) == len(intent)
     # `None` is a singleton, so test identity rather than equality.
     if shots is None or shots > len(query):
         self.dataset = query
         self.intent = intent
         self.lengths = lengths
         self.re_out = re_out
     else:
         idxs = evan_select_from_total_number(len(query), shots)
         self.dataset = list(np.array(query)[idxs])
         self.intent = list(np.array(intent)[idxs])
         self.lengths = list(np.array(lengths)[idxs])
         self.re_out = list(np.array(re_out)[idxs])