Example #1
 def __iter__(self) -> Iterator[MetricLearningBatch]:
     for df in super().__iter__():
         lang1 = df['lang1'].values
         lang2 = df['lang2'].values
         normalized_score = get_tensor(df[self.cats].values.astype('float32'))
         dist = get_tensor(df['dist'].values.astype('float32'))
         # Yield each batch; `return` here would stop after the first batch
         # and break the iterator protocol.
         yield MetricLearningBatch(lang1, lang2, normalized_score, dist).cuda()
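None of these snippets define get_tensor itself. Judging from the call sites (numpy arrays, plain lists, and existing tensors are all passed in, and results end up on the GPU), it behaves roughly like the sketch below. This is an inference, not the repositories' actual implementation:

import torch

# Hypothetical stand-in for the projects' get_tensor helper, inferred from
# the call sites above: accept lists, numpy arrays, or tensors, return a
# torch.Tensor, and move it to the GPU when one is available.
def get_tensor(data):
    tensor = torch.as_tensor(data)
    return tensor.cuda() if torch.cuda.is_available() else tensor

batch = get_tensor([[0.1, 0.9], [0.5, 0.5]])  # float32 tensor, possibly on GPU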
Example #2
File: lm_model.py Project: esantus/xib
    def score(self, batch) -> Dict[Cat, FT]:
        distr = self(batch)
        scores = dict()
        for name, output in distr.items():
            i = get_index(name, new_style=self.new_style)
            target = batch.target_feat[:, i]
            weight = batch.target_weight[:, i]

            if self.weighted_loss == '':
                log_probs = gather(output, target)
                score = -log_probs
            else:
                e = get_new_style_enum(i)
                mat = get_tensor(e.get_distance_matrix())
                mat = mat[target.rename(None)]
                if self.weighted_loss == 'mr':
                    mat_exp = torch.where(mat > 0, (mat + 1e-8).log(),
                                          get_zeros(mat.shape).fill_(-99.9))
                    logits = mat_exp + output
                    # NOTE(j_luo) For categories other than Ptype, the probabilities do not sum
                    # to 1.0 (they are conditioned on certain values of Ptype). As a result, we
                    # need to incur penalties based on the remaining probability mass as well.
                    # Specifically, the remaining mass incurs a penalty of 1.0, which is e^(0.0).
                    none_probs = (1.0 - output.exp().sum(dim=-1, keepdims=True)).clamp(min=0.0)
                    none_penalty = (1e-8 + none_probs).log().align_as(output)
                    logits = torch.cat([logits, none_penalty], dim=-1)
                    score = torch.logsumexp(logits, dim=-1).exp()
                elif self.weighted_loss == 'ot':
                    if not self.training:
                        raise RuntimeError('OT is only supported during training.')

                    probs = output.exp()
                    # We have to incur penalties based on the remaining probability mass as well.
                    none_probs = (1.0 - probs.sum(dim=-1, keepdims=True)).clamp(min=0.0)
                    mat = torch.cat(
                        [mat, get_tensor(torch.ones_like(none_probs.rename(None)))],
                        dim=-1)
                    probs = torch.cat([probs, none_probs], dim=-1)
                    score = (mat * probs).sum(dim=-1)
                else:
                    raise ValueError(f'Cannot recognize {self.weighted_loss}.')
            scores[name] = (score, weight)
        return scores
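For reference, the 'mr' branch computes an expected distance sum_i d_i * p_i in log space: logsumexp(log d + log p) = log(sum d * p), and the final .exp() recovers the plain sum while staying stable for tiny probabilities. A self-contained sanity check of that identity (all names below are illustrative, not from the repo):

import torch

# Check: exp(logsumexp(log(d) + log(p))) == sum(d * p), the expected
# distance the 'mr' branch computes (up to the 1e-8 smoothing term).
log_p = torch.log_softmax(torch.randn(4, 7), dim=-1)  # per-row log-probs
d = torch.rand(4, 7)                                  # per-row distances
via_log_space = torch.logsumexp((d + 1e-8).log() + log_p, dim=-1).exp()
direct = (d * log_p.exp()).sum(dim=-1)
assert torch.allclose(via_log_space, direct, atol=1e-4)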
Example #3
File: modules.py Project: esantus/xib
 def __init__(self, feat_emb_name, group_name, char_emb_name, num_features, dim, feat_groups, num_feature_groups, new_style):
     super().__init__()
     self.embed_layer = self._get_embeddings()
     self.register_buffer('c_idx', get_tensor(get_effective_c_idx(feat_groups)).refine_names('chosen_feat_group'))
     cat_enum_pairs = get_needed_categories(feat_groups, new_style=new_style, breakdown=new_style)
     if new_style:
         self.effective_num_feature_groups = sum([e.num_groups() for e in cat_enum_pairs])
         simple_conversions = np.zeros([num_features], dtype='int64')
         # `conversions` is presumably a module-level mapping from old-style
         # features to their new-style counterparts.
         max_len = max(len(new_feat.value) for new_feat in conversions.values() if new_feat.value.is_complex())
         complex_conversions = np.zeros([num_features, max_len], dtype='int64')
         for old_feat, new_feat in conversions.items():
             if new_feat.value.is_complex():
                 l = len(new_feat.value)
                 complex_conversions[old_feat.value.g_idx, :l] = [x.value.g_idx for x in new_feat.value]
             else:
                 simple_conversions[old_feat.value.g_idx] = new_feat.value.g_idx
         self.simple_conversions = get_tensor(simple_conversions)
         self.complex_conversions = get_tensor(complex_conversions)
     else:
         self.effective_num_feature_groups = len(cat_enum_pairs)
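The new_style branch above builds integer lookup tables in numpy and tensorizes them so later code can remap feature ids in one fancy-indexing step (as in mat[target.rename(None)] in example #2). A toy version of that pattern, with made-up ids:

import numpy as np
import torch

# Toy lookup table: old feature id -> new feature id. Indexing the table
# with a tensor of old ids remaps the whole batch in one vectorized step.
table = torch.as_tensor(np.array([2, 0, 1], dtype='int64'))
old_ids = torch.tensor([0, 0, 2, 1])
new_ids = table[old_ids]  # tensor([2, 2, 1, 0])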
Example #4
File: modules.py Project: esantus/xib
 def __init__(self, hidden_size, feat_groups, new_style):
     super().__init__()
     self.linear = nn.Linear(hidden_size, hidden_size)
     self.feat_predictors = nn.ModuleDict()
     for e in get_needed_categories(feat_groups, new_style=new_style, breakdown=new_style):
         # NOTE(j_luo) ModuleDict can only handle str as keys.
         self.feat_predictors[e.__name__] = nn.Linear(hidden_size, len(e))
     # If new_style, we need to get the necessary indices to convert the breakdown groups into the original feature groups.
     if new_style:
         self.conversion_idx = dict()
         # Use the `feat_groups` argument; self.feat_groups is never assigned here.
         for e in get_needed_categories(feat_groups, new_style=True, breakdown=False):
             if e.num_groups() > 1:
                 cat_idx = list()
                 for feat in e:
                     feat_cat_idx = list()
                     feat = feat.value
                     for basic_feat in feat:
                         auto_index = basic_feat.value
                         feat_cat_idx.append(auto_index.f_idx)
                     cat_idx.append(feat_cat_idx)
                 cat_idx = get_tensor(cat_idx).refine_names('new_style_idx', 'old_style_idx')
                 self.conversion_idx[e.__name__] = cat_idx
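As the NOTE in this snippet says, nn.ModuleDict only accepts string keys, which is why the enum classes are registered under their __name__. A minimal illustration (the class and the layer sizes are made up):

import torch
import torch.nn as nn

class Ptype:  # stand-in for one of the category enums
    pass

# ModuleDict keys must be strings, so the class name serves as the key.
predictors = nn.ModuleDict()
predictors[Ptype.__name__] = nn.Linear(16, 8)
out = predictors['Ptype'](torch.randn(2, 16))  # shape (2, 8)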
Example #5
 def __post_init__(self):
     self.normalized_score = get_tensor(self.normalized_score)  # .refine_names('batch', 'feat_group')
     self.dist = get_tensor(self.dist)  # .refine_names('batch')
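This __post_init__ idiom lets a dataclass accept raw numpy arrays and tensorize its fields right after construction (the commented-out refine_names calls would additionally attach dimension names). A self-contained sketch, with torch.as_tensor standing in for get_tensor:

import numpy as np
import torch
from dataclasses import dataclass

@dataclass
class Batch:
    normalized_score: torch.Tensor
    dist: torch.Tensor

    def __post_init__(self):
        # Tensorize whatever was passed in (numpy arrays, lists, tensors).
        self.normalized_score = torch.as_tensor(self.normalized_score)
        self.dist = torch.as_tensor(self.dist)

b = Batch(np.zeros([4, 3], dtype='float32'), np.ones([4], dtype='float32'))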
Example #6
File: verifier.py Project: j-luo93/XLM
    def get_graph_target(self,
                         data: Tensor,
                         lengths: Tensor,
                         lang: str,
                         split: str,
                         indices: List[int],
                         permutations: List[np.ndarray] = None,
                         keep: np.ndarray = None) -> GraphData:
        # NOTE(j_luo) If for some reason the first token is <s> or </s>, we need to offset the indices.
        max_len = max(lengths)
        if self.ae_noise_graph_mode == 'change':
            assert permutations is not None and keep is not None

        offsets = ((data[0] == self.dico.eos_index) |
                   (data[0] == self.dico.bos_index)).long()
        graphs = self.graphs[(lang, split)]
        graphs = [graphs[i] for i in indices]
        bs = len(graphs)
        if len(offsets) != bs:
            raise RuntimeError(f'Expected {bs} offsets but got {len(offsets)}.')

        ijkv = list()
        connected_vertices = get_zeros(len(graphs), max_len).bool()
        for batch_i, graph in enumerate(graphs):
            offset = offsets[batch_i].item()

            assert self.ae_noise_graph_mode != 'change', 'connected vertices cannot handle change for now.'
            vertices = np.asarray(graph.connected_vertices) + offset
            connected_vertices[batch_i, vertices] = True
            if offset > 0:
                connected_vertices[batch_i, 0] = True
            length = lengths[batch_i].item() - 1
            connected_vertices[batch_i, length] = True

            # Repeat the permutation and dropout processes and change the graph accordingly.
            if self.ae_noise_graph_mode == 'change':
                perm = permutations[batch_i].argsort()
                perm = np.arange(len(perm))[perm]
            for e in graph.edges:
                u = e.u + offset
                v = e.v + offset
                assert u < max_len and v < max_len
                if self.ae_noise_graph_mode == 'change':
                    u = perm[e.u]
                    v = perm[e.v]
                    if keep[u, batch_i] and keep[v, batch_i]:
                        ijkv.append((batch_i, u, v, e.t.value))
                else:
                    ijkv.append((batch_i, u, v, e.t.value))

        i, j, k, v = zip(*ijkv)
        v = get_tensor(v)
        edge_norm = get_zeros([bs, max_len, max_len])
        edge_type = get_zeros([bs, max_len, max_len]).long()
        # NOTE(j_luo) Edges are symmetric.
        edge_norm[i, j, k] = 1.0
        edge_norm[i, k, j] = 1.0
        edge_type[i, j, k] = v
        edge_type[i, k, j] = v
        edge_norm = edge_norm.view(-1)
        edge_type = edge_type.view(-1)
        return GraphData(None, None, edge_norm, edge_type, connected_vertices)
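The final block turns the (batch, u, v, type) tuples into dense [bs, max_len, max_len] tensors with advanced indexing, writing each edge in both directions so the adjacency is symmetric. A toy version with hard-coded edges:

import torch

# Three toy edges as (batch_i, u, v, edge_type) tuples.
ijkv = [(0, 1, 2, 3), (0, 0, 1, 5), (1, 2, 0, 4)]
i, j, k, v = map(list, zip(*ijkv))
v = torch.tensor(v)

bs, max_len = 2, 4
edge_norm = torch.zeros(bs, max_len, max_len)
edge_type = torch.zeros(bs, max_len, max_len, dtype=torch.long)
# Write each edge symmetrically: (u, v) and (v, u) get the same entries.
edge_norm[i, j, k] = 1.0
edge_norm[i, k, j] = 1.0
edge_type[i, j, k] = v
edge_type[i, k, j] = v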