Example #1
    def __init__(self, hidden_size: Optional[int] = None):
        hidden_size = hidden_size or g.hidden_size
        super().__init__()
        # Shared projection applied before the per-category predictors.
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.feat_predictors = nn.ModuleDict()
        for e in get_needed_categories(g.feat_groups,
                                       new_style=g.new_style,
                                       breakdown=g.new_style):
            # NOTE(j_luo) ModuleDict can only handle str as keys.
            self.feat_predictors[e.__name__] = nn.Linear(hidden_size, len(e))
        # If new_style, we need the indices that convert the breakdown groups
        # back into the original feature groups.
        if g.new_style:
            self.conversion_idx = dict()
            for e in get_needed_categories(g.feat_groups,
                                           new_style=True,
                                           breakdown=False):
                if e.num_groups() > 1:
                    cat_idx = list()
                    for feat in e:
                        feat_cat_idx = list()
                        feat = feat.value
                        for basic_feat in feat:
                            auto_index = basic_feat.value
                            feat_cat_idx.append(auto_index.f_idx)
                        cat_idx.append(feat_cat_idx)
                    cat_idx = get_tensor(cat_idx).refine_names(
                        'new_style_idx', 'old_style_idx')
                    self.conversion_idx[e.__name__] = cat_idx
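This constructor leans on project-level helpers (`g`, `get_needed_categories`, `get_tensor`) that are not shown here. A minimal, self-contained sketch of the same pattern, with a plain dict of hypothetical category sizes standing in for the enum classes, could look like:

import torch.nn as nn

class FeatPredictor(nn.Module):
    """A shared projection followed by one linear head per category."""

    def __init__(self, hidden_size: int, categories: dict):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        # ModuleDict keys must be strings, mirroring the NOTE above.
        self.feat_predictors = nn.ModuleDict({
            name: nn.Linear(hidden_size, num_values)
            for name, num_values in categories.items()
        })

# Hypothetical category names and sizes, made up for illustration.
predictor = FeatPredictor(64, {'Manner': 8, 'Voicing': 3})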
Example #2
    def forward(self, h: FT) -> Dict[str, FT]:
        # One shared representation feeds every per-category predictor.
        shared_h = nn.functional.leaky_relu(
            self.linear(h).refine_names(..., 'shared_repr'),
            negative_slope=0.1)
        ret = dict()
        for name, layer in self.feat_predictors.items():
            out = layer(shared_h).refine_names(..., name)
            # Mask the NONE index with a large negative logit when this
            # category should never predict it.
            if not should_predict_none(name, new_style=g.new_style):
                f_idx = get_none_index(name)
                out[:, f_idx] = -999.9
            ret[Name(name, 'camel')] = out

        # Compose probs for complex feature groups if possible.
        if g.new_style:
            for e in get_needed_categories(g.feat_groups,
                                           new_style=True,
                                           breakdown=False):
                if e.num_groups() > 1:
                    # `ret` is keyed by Name objects, so check the composed
                    # category's name rather than the enum class itself.
                    assert e.get_name() not in ret
                    part_tensors = [
                        ret[part_enum.get_name()] for part_enum in e.parts()
                    ]
                    parts = list()
                    for i, part_tensor in enumerate(part_tensors):
                        conversion = self.conversion_idx[e.get_name().value][:, i]
                        bs = len(part_tensor)
                        # Reorder part i's logits so that position j lines up
                        # with composed (old-style) value j.
                        part = part_tensor.rename(None).gather(
                            1, conversion.rename(None).expand(bs, -1))
                        parts.append(part)
                    parts = torch.stack(parts, dim=-1)
                    dim_name = e.get_name().value
                    ret[e.get_name()] = parts.sum(dim=-1).refine_names(
                        'batch', dim_name)
                    for part_cat in e.parts():
                        del ret[part_cat.get_name()]
        for name in ret:
            ret[name] = torch.log_softmax(ret[name], dim=-1)

        # Deal with conditions for some categories.
        for cat, index in conditions.items():
            if should_include(g.feat_groups, cat):
                # Find out the exact value to be conditioned on.
                # TODO(j_luo) ugly Category call.
                condition_e = get_enum_by_cat(Category(index.c_idx))
                suffix = 'X' if g.new_style else ''
                condition_name = Name(condition_e.__name__ + suffix, 'camel')
                cat_name = Name(get_enum_by_cat(cat).__name__ + suffix, 'camel')
                # Adding the condition's log prob conditions this category's
                # distribution on the conditioned value.
                condition_log_probs = ret[condition_name][..., index.f_idx]
                ret[cat_name] = ret[cat_name] + condition_log_probs.rename(
                    None).unsqueeze(dim=-1)
        return ret
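The composition step above hinges on one tensor trick: `gather` reorders each breakdown part's logits so that position j of every part corresponds to composed value j, and summing over the stacked parts before the softmax corresponds to multiplying the parts' unnormalized scores. A toy sketch with made-up shapes (two parts, four composed values):

import torch

bs = 2
part_a = torch.randn(bs, 3)  # logits for breakdown part A (3 values)
part_b = torch.randn(bs, 5)  # logits for breakdown part B (5 values)
# conversion[j, i]: which value of part i composed value j uses.
conversion = torch.tensor([[0, 1],
                           [0, 4],
                           [2, 1],
                           [2, 3]])

parts = []
for i, part in enumerate([part_a, part_b]):
    idx = conversion[:, i].expand(bs, -1)  # (bs, 4)
    parts.append(part.gather(1, idx))
composed = torch.stack(parts, dim=-1).sum(dim=-1)  # (bs, 4) composed logits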
Example #3
    def __init__(self, feat_emb_name, group_name, char_emb_name,
                 num_features, dim, feat_groups, num_feature_groups,
                 new_style):
        super().__init__()
        self.embed_layer = self._get_embeddings()
        # Indices of the feature groups actually in use.
        self.register_buffer(
            'c_idx',
            get_tensor(get_effective_c_idx(feat_groups)).refine_names(
                'chosen_feat_group'))
        cat_enum_pairs = get_needed_categories(feat_groups,
                                               new_style=new_style,
                                               breakdown=new_style)
        if new_style:
            self.effective_num_feature_groups = sum(
                e.num_groups() for e in cat_enum_pairs)
            # Lookup tables mapping old-style feature indices to new-style
            # group indices: a flat table for simple features and a
            # zero-padded 2D table for complex (multi-part) features.
            simple_conversions = np.zeros([num_features], dtype='int64')
            max_len = max(
                len(new_feat.value) for new_feat in conversions.values()
                if new_feat.value.is_complex())
            complex_conversions = np.zeros([num_features, max_len],
                                           dtype='int64')
            for old_feat, new_feat in conversions.items():
                if new_feat.value.is_complex():
                    l = len(new_feat.value)
                    complex_conversions[old_feat.value.g_idx, :l] = [
                        x.value.g_idx for x in new_feat.value
                    ]
                else:
                    simple_conversions[old_feat.value.g_idx] = new_feat.value.g_idx
            self.simple_conversions = get_tensor(simple_conversions)
            self.complex_conversions = get_tensor(complex_conversions)
        else:
            self.effective_num_feature_groups = len(cat_enum_pairs)
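Once built, the two tables translate old-style group indices with plain indexing. A sketch assuming small made-up sizes (the real tables are filled from `conversions` as above):

import torch

num_features, max_len = 10, 3  # hypothetical sizes for illustration
simple_conversions = torch.zeros(num_features, dtype=torch.int64)
complex_conversions = torch.zeros(num_features, max_len, dtype=torch.int64)

old_idx = torch.tensor([2, 5, 7])            # old-style g_idx values
new_simple = simple_conversions[old_idx]     # (3,) new-style indices
new_complex = complex_conversions[old_idx]   # (3, max_len), zero-padded rows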