Example 1
    def _create_hardkuma_prior_table(self,
                                     prior_param_1,
                                     max_sentence_length,
                                     l=-0.1,
                                     r=1.1,
                                     N=10000):
        """
        Creates a prior table for the HardKuma. Fixes b=1.0
        """
        kuma_priors = []
        l = torch.Tensor([l])
        r = torch.Tensor([r])

        # Tabulate (a, b=1.0, P(A=0)) triples over a grid of a values.
        with torch.no_grad():
            b = torch.Tensor([1.0])
            # The point that the stretch (l, r) maps onto 0 is constant,
            # so compute it once outside the loop.
            k0 = -l / (r - l)
            for a in torch.linspace(start=epsilon, end=1., steps=N):
                pk = Kumaraswamy(a, b)
                kuma_priors.append((a.item(), b.item(), pk.cdf(k0).item()))
        kuma_priors = sorted(kuma_priors, key=lambda elem: elem[0])

        # Tabulate priors for every possible sentence length.
        # P(0) = 1 - (prior_param_1 / (l+1))
        self.hardkuma_prior_table = {}
        for length in range(1, max_sentence_length + 1):
            p0 = 1.0 - min(((1.0 + epsilon) / (length + 1.0)), 1.0 - epsilon)
            idx = 0
            cdf0 = float("inf")
            while cdf0 > p0 and idx != len(kuma_priors):
                a, b, cdf0 = kuma_priors[idx]
                idx += 1
            self.hardkuma_prior_table[length] = (a, b)
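The table above only needs the closed-form Kumaraswamy CDF, F(x; a, b) = 1 - (1 - x^a)^b, evaluated at the point where the stretch (l, r) crosses 0. A minimal self-contained sketch of that computation (plain Python, no project classes):

def kuma_cdf(x, a, b):
    # Closed-form Kumaraswamy CDF: F(x; a, b) = 1 - (1 - x^a)^b on (0, 1).
    return 1.0 - (1.0 - x ** a) ** b

l, r = -0.1, 1.1           # stretch bounds used throughout these examples
k0 = -l / (r - l)          # where the stretched support crosses 0 (~0.083)
for a in (0.1, 0.5, 1.0):  # with b = 1.0 fixed, P(A=0) shrinks as a grows
    print(f"a={a}: P(A=0) ~= {kuma_cdf(k0, a, 1.0):.3f}")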
Example 2
    def prior(self, seq_mask_x, seq_len_x, seq_mask_y):
        """
            Prior 1 / src_length.
        """

        prior_param_1, prior_param_2 = self.prior_params
        prior_shape = [
            seq_mask_x.size(0),
            seq_mask_y.size(1),
            seq_mask_x.size(1)
        ]

        if "bernoulli" in self.dist:
            if prior_param_1 > 0:
                # Expect about prior_param_1 aligned source words per sentence.
                probs = prior_param_1 * (seq_mask_x.float() + epsilon) / (
                    seq_len_x.unsqueeze(-1).float() + 1)
                probs = torch.clamp(probs, max=(1 - 0.01))
                probs = probs.unsqueeze(1).repeat(1, seq_mask_y.size(1), 1)
            elif prior_param_2 > 0:
                # fixed prior_param_2 probability of an alignment
                probs = seq_mask_x.float().new_full(prior_shape,
                                                    fill_value=prior_param_2)
            else:
                raise Exception(
                    f"Invalid prior params for Bernoulli ({prior_param_1}, {prior_param_2})"
                )

            return BernoulliREINFORCE(probs=probs,
                                      validate_args=True)  # [B, T_y, T_x]
        elif self.dist == "concrete":
            raise NotImplementedError()
        elif self.dist in ["kuma", "hardkuma"]:

            if prior_param_1 > 0 and prior_param_2 > 0:
                p = Kumaraswamy(
                    seq_mask_x.float().new_full(prior_shape,
                                                fill_value=prior_param_1),
                    seq_mask_x.float().new_full(prior_shape,
                                                fill_value=prior_param_2))
            elif self.dist == "hardkuma" and prior_param_1 > 0:
                seq_len_numpy = seq_len_x.cpu().numpy()
                a = seq_len_x.float().new_tensor([
                    self.hardkuma_prior_table[length][0]
                    for length in seq_len_numpy
                ])  # [B]
                a = a.unsqueeze(-1).unsqueeze(-1).repeat(
                    1, seq_mask_y.size(1), seq_mask_x.size(1))
                b = torch.ones_like(a)
                p = Kumaraswamy(a, b)
            else:
                raise Exception(
                    f"Invalid Kumaraswamy parameters a={prior_param_1}, b={prior_param_2}"
                )

            if self.dist == "kuma":
                return p
            else:
                return Rectified01(Stretched(p, lower=-0.1, upper=1.1))
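The length-based Bernoulli prior is easy to check by hand: each unpadded source position gets about prior_param_1 / (length + 1) probability, so the expected number of selected source words per target word stays close to prior_param_1. A small sketch with hypothetical shapes (a batch of two source sentences):

import torch

seq_mask_x = torch.tensor([[1., 1., 1., 0., 0.],   # length 3, padded to 5
                           [1., 1., 1., 1., 1.]])  # length 5
seq_len_x = torch.tensor([3., 5.])
prior_param_1, epsilon = 1.0, 1e-6  # "one aligned word per sentence"

probs = prior_param_1 * (seq_mask_x + epsilon) / (seq_len_x.unsqueeze(-1) + 1)
probs = torch.clamp(probs, max=1 - 0.01)
print(probs.sum(-1))  # ~= tensor([0.7500, 0.8333]), close to prior_param_1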
Example 3
def make_kumaraswamy(inputs, event_size):
    assert inputs.size(-1) == 2 * event_size, \
        "Expected [...,%d] got [...,%d]" % (2 * event_size, inputs.size(-1))
    a_params, b_params = torch.split(inputs, event_size, -1)
    # Softplus keeps the shape parameters positive; the floor and ceiling
    # keep them in a numerically safe range.
    return Kumaraswamy(a=torch.clamp(F.softplus(a_params) + 1e-2, max=10.),
                       b=torch.clamp(F.softplus(b_params) + 1e-2, max=10.))
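The softplus-plus-floor transform keeps both shape parameters strictly positive and within a numerically safe range, roughly (0.01, 10]. A quick check of just the parameter mapping (shapes are hypothetical; the Kumaraswamy class itself comes from the project's own distributions module):

import torch
import torch.nn.functional as F

event_size = 4
inputs = torch.randn(2, 2 * event_size)  # [..., 2 * event_size]
raw_a, raw_b = torch.split(inputs, event_size, -1)
a = torch.clamp(F.softplus(raw_a) + 1e-2, max=10.)
b = torch.clamp(F.softplus(raw_b) + 1e-2, max=10.)
assert a.min().item() > 1e-2 and a.max().item() <= 10.0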
Example 4
    def forward(self, x, seq_mask_x, seq_len_x, y, seq_mask_y, seq_len_y):
        # Returns a distribution over binary alignments, shaped [B, T_y, T_x].

        # Embed the source and target words.
        x_embed = self.src_embedder(x)  # [B, T_x, emb_size]
        y_embed = self.tgt_embedder(y)  # [B, T_y, emb_size]

        # Encode both sentences.
        x_enc, _ = self.src_encoder.unsorted_forward(
            x_embed)  # [B, T_x, enc_size]
        y_enc, _ = self.tgt_encoder.unsorted_forward(
            y_embed)  # [B, T_y, enc_size]
        # x_enc = x_embed
        # y_enc = y_embed

        if self.dist in ["bernoulli-RF", "bernoulli-ST", "concrete"]:

            # compute keys and queries.
            keys = self.key_layer(x_enc)  # [B, T_x, hidden_size]
            queries = self.query_layer(y_enc)  # [B, T_y, hidden_size]

            # Compute the scores as dot attention between source and target.
            logits = torch.bmm(queries, keys.transpose(1, 2))  # [B, T_y, T_x]

            if self.dist == "bernoulli-RF":
                return BernoulliREINFORCE(logits=logits, validate_args=True)
            elif self.dist == "bernoulli-ST":
                return BernoulliStraightThrough(logits=logits,
                                                validate_args=True)
            elif self.dist == "concrete":
                logits = torch.clamp(logits, -5., 5.)
                return BinaryConcrete(temperature=logits.new([1.0]),
                                      logits=logits,
                                      validate_args=True)  # TODO

        elif self.dist in ["kuma", "hardkuma"]:

            # Not supported at the moment; the code below is unreachable and
            # kept only as a sketch.
            raise NotImplementedError()

            # Compute a using attention.
            keys_a = self.kuma_a_key_layer(x_enc)  # [B, T_x, hidden_size]
            queries_a = self.kuma_a_query_layer(y_enc)  # [B, T_y, hidden_size]
            a = torch.bmm(queries_a, keys_a.transpose(1, 2))  # [B, T_y, T_x]
            # a = torch.clamp(F.softplus(a) + 0.7, 1e-5, 3.) # [B, T_y, T_x]
            # a = torch.tanh(a) + 1.1 # (0.1, 2.1)
            a = 0.01 + (0.98 * torch.sigmoid(a))

            # Compute b using attention.
            # keys_b = self.kuma_b_key_layer(x_enc) # [b, t_x, hidden_size]
            # queries_b = self.kuma_b_query_layer(y_enc) # [b, t_y, hidden_size]
            # b = torch.bmm(queries_b, keys_b.transpose(1, 2)) # [B, T_y, T_x]
            # # b = torch.clamp(F.softplus(b) + 0.7, 1e-5, 3.) # [B, T_y, T_x]
            # b = torch.tanh(b) + 1.1 # (0.1, 2.1)

            # q = Kumaraswamy(a, b)
            q = Kumaraswamy(a, 1.0 - a)
            if self.dist == "kuma":
                return q
            else:
                return Rectified01(Stretched(q, lower=-0.1, upper=1.1))
        else:
            raise Exception(f"Unknown dist option: {self.dist}")
Example 5
    def forward(self, input_features):
        a, b = self.compute_parameters(input_features)
        # HardKuma head: stretch the base Kumaraswamy to (-0.1, 1.1), rectify
        # into [0, 1], and treat the last dimension as the event dimension.
        return Independent(
            Rectified01(Stretched(Kumaraswamy(a=a, b=b), lower=-0.1,
                                  upper=1.1)), 1)
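Stretched followed by Rectified01 is the HardKuma construction: base samples in (0, 1) are affinely stretched to (-0.1, 1.1) and then clamped, which places point masses at exactly 0 and 1. A hedged sampling sketch using torch's built-in Kumaraswamy (Stretched and Rectified01 are project-specific classes; torch names the parameters concentration1/concentration0 rather than a/b):

import torch
from torch.distributions import Kumaraswamy

l, r = -0.1, 1.1
base = Kumaraswamy(torch.tensor(0.5), torch.tensor(0.5))  # U-shaped base
k = base.sample((10000,))            # base samples lie in (0, 1)
h = (l + (r - l) * k).clamp(0., 1.)  # stretch, then rectify
print((h == 0.).float().mean(), (h == 1.).float().mean())  # point masses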
Example 6
    def forward(self, input_features):
        a, b = self.compute_parameters(input_features)
        # Plain Kumaraswamy head over the last (event) dimension.
        return Independent(Kumaraswamy(a=a, b=b), 1)
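Wrapping a distribution in Independent(..., 1) reinterprets the last batch dimension as part of the event, so log_prob sums over it. The same pattern with torch's built-in Kumaraswamy:

import torch
from torch.distributions import Independent, Kumaraswamy

a = torch.rand(3, 4) + 0.5
b = torch.rand(3, 4) + 0.5
p = Independent(Kumaraswamy(a, b), 1)
x = p.sample()
assert x.shape == (3, 4)
assert p.log_prob(x).shape == (3,)  # the last dim is folded into the event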