Example #1
    def forward_encoder(self, x_emb):
        """Encoder step, emulating z ~ E(x) = q_E(z|x)

        :param x_emb: (n_batch, len(x), d_z) of floats, embeddings for input sentence x
        :return: (n_batch, d_z) of floats, sample of latent vector z
        :return: float, kl term component of loss
        """

        # _, h = self.encoder_rnn(x, None)
        h = self.encoder_rnn(x_emb, None)

        h = lbann.Slice(
            h,
            slice_points=str_list(
                [self.input_feature_dims - 1, self.input_feature_dims]),
            axis=0,
        )
        h = lbann.Identity(h)

        mu, logvar = self.q_mu(h), self.q_logvar(h)

        # Set datatype of previous layers
        # Note: Depth-first search from mu and logvar to x_emb
        stack = [mu, logvar]
        in_stack = {l: True for l in stack}
        while stack:
            l = stack.pop()
            if type(l) not in (lbann.Slice, lbann.Reshape, lbann.Tessellate):
                l.datatype = self.datatype
            for parent in l.parents:
                if parent not in in_stack and parent is not x_emb:
                    stack.append(parent)
                    in_stack[parent] = True

        # eps = torch.randn_like(mu)
        eps = lbann.Gaussian(mean=0, stdev=1, hint_layer=mu)

        # z = mu + (logvar / 2).exp() * eps
        z = lbann.Add([
            mu,
            (lbann.Multiply([
                lbann.Exp(lbann.WeightedSum(logvar, scaling_factors='0.5')),
                eps
            ]))
        ])

        # kl_loss = 0.5 * (logvar.exp() + mu ** 2 - 1 - logvar).sum(1).mean()
        kl_loss = lbann.Reduction(
            lbann.WeightedSum(
                lbann.Exp(logvar),
                lbann.Square(mu),
                self.constant(1, hint_layer=mu),
                logvar,
                scaling_factors='0.5 0.5 -0.5 -0.5',
            ),
            mode='sum',
        )

        return z, kl_loss
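
For reference, the commented torch lines in this example correspond to the following minimal PyTorch sketch (the function name is illustrative, not part of the LBANN model):

import torch

def reparameterize(mu, logvar):
    eps = torch.randn_like(mu)             # noise, matching lbann.Gaussian above
    z = mu + (logvar / 2).exp() * eps      # reparameterization trick
    kl = 0.5 * (logvar.exp() + mu ** 2 - 1 - logvar).sum(1).mean()
    return z, kl
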
Example #2
    def forward(self, motif_size, motif_log_embeddings):
        """Predict whether a motif is real.

        @todo Numerically accurate computation of both log(D) and
        log(1-D).

        """

        # D = 1 - exp(-sum_j(prod_i(d_ij)))
        # log(1-D) = -sum_j(exp(sum_i(log(d_ij))))
        x = lbann.MatMul(
            lbann.Constant(value=1, num_neurons=str_list([1, motif_size])),
            motif_log_embeddings,
        )
        x = lbann.Exp(x)
        x = lbann.Reduction(x, mode='sum')
        x = lbann.Negative(x)
        log_not_prob = x

        # Convert log-probability to linear space
        # Note: D=-expm1(x) is accurate when D~0. When D~1, prefer
        # 1-D=exp(x).
        prob = lbann.Negative(lbann.Expm1(log_not_prob))

        return prob, log_not_prob
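
A minimal NumPy sketch of the same identity, assuming log_d is an array of the log(d_ij) values with shape (motif_size, num_candidates); names here are illustrative:

import numpy as np

def motif_prob(log_d):
    # log(1-D) = -sum_j(exp(sum_i(log(d_ij))))
    log_not_prob = -np.exp(log_d.sum(axis=0)).sum()
    # D = -expm1(log(1-D)), accurate when D ~ 0
    prob = -np.expm1(log_not_prob)
    return prob, log_not_prob
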
Example #3
    def compute_loss(self, x, y):

        # y[:, :-1]
        y = lbann.Slice(
            y,
            axis=0,
            slice_points=str_list([0, self.input_feature_dims - 1]),
        )
        y = lbann.Identity(y)

        # x[:, 1:]
        x = lbann.Slice(
            x,
            slice_points=str_list([1, self.input_feature_dims]),
        )
        x = lbann.Identity(x)

        # Convert indices in x to one-hot representation
        # Note: Ignored indices result in zero vectors
        ignore_mask = lbann.Equal(
            x,
            self.constant(self.label_to_ignore, hint_layer=x),
        )
        keep_mask = lbann.LogicalNot(ignore_mask)
        length = lbann.Reduction(keep_mask, mode='sum')
        length = lbann.Max(length, self.constant(1, [1]))
        x = lbann.Add(
            lbann.Multiply(keep_mask, x),
            lbann.Multiply(ignore_mask, self.constant(-1, hint_layer=x)),
        )
        x = lbann.Slice(x,
                        slice_points=str_list(range(self.input_feature_dims)))
        x = [lbann.Identity(x) for _ in range(self.input_feature_dims - 1)]
        x = [lbann.OneHot(xi, size=self.dictionary_size) for xi in x]
        x = [
            lbann.Reshape(xi, dims=str_list([1, self.dictionary_size]))
            for xi in x
        ]
        x = lbann.Concatenation(x, axis=0)

        # recon_loss = F.cross_entropy(
        #     y[:, :-1].contiguous().view(-1, y.size(-1)),
        #     x[:, 1:].contiguous().view(-1),
        #     ignore_index=self.pad
        # )
        # Note: Ideally we'd shift y by y.max(-1) for numerical stability
        shifts = lbann.MatMul(
            lbann.Max(y, self.constant(0, hint_layer=y)),
            self.constant(
                1 / math.sqrt(self.dictionary_size),
                [self.dictionary_size, self.dictionary_size],
            ),
        )
        y = lbann.Subtract(y, shifts)
        z = lbann.MatMul(
            lbann.Exp(y),
            self.constant(1, [self.dictionary_size, 1]),
        )
        z = lbann.Log(z)
        z = lbann.MatMul(
            lbann.Reshape(keep_mask, dims=str_list([1, -1])),
            z,
        )
        recon_loss = lbann.MatMul(
            lbann.Reshape(y, dims=str_list([1, -1])),
            lbann.Reshape(x, dims=str_list([1, -1])),
            transpose_b=True,
        )
        recon_loss = lbann.Subtract(z, recon_loss)
        recon_loss = lbann.Reshape(recon_loss, dims=str_list([1]))
        recon_loss = lbann.Divide(recon_loss, length)

        return recon_loss
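
The matrix algebra above encodes the per-token identity cross_entropy = log(sum_k exp(y_k)) - y[target]; a minimal NumPy sketch of that identity, using the y.max(-1) shift that the note says would be ideal (names are illustrative):

import numpy as np

def token_cross_entropy(y_row, target):
    y_row = y_row - y_row.max()   # shift for numerical stability
    return np.log(np.exp(y_row).sum()) - y_row[target]
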
Example #4
encode3neuron = lbann.Relu(encode3, name="encode3neuron")

# Latent space
mu = lbann.FullyConnected(encode3neuron,
                          name="mu",
                          num_neurons=30,
                          has_bias=True)

logsd = lbann.FullyConnected(encode3,
                             name="logsd",
                             num_neurons=30,
                             has_bias=True)

# KL divergence
sd = lbann.Exp(logsd, name="sd")

var = lbann.Square(sd, name="var")

meansq = lbann.Square(mu, name="meansq")

kldiv_plus_half = lbann.WeightedSum([meansq, var, logsd],
                                    name="kldiv_plus_half",
                                    scaling_factors='0.5 0.5 -1')

# kldiv_full = kldiv_plus_half - 1/2, since
# KL(N(mu, sd^2) || N(0, 1)) = 0.5*mu^2 + 0.5*sd^2 - log(sd) - 0.5
kldiv_full = lbann.Subtract(
    kldiv_plus_half,
    lbann.Constant(value=0.5, hint_layer=kldiv_plus_half, name="half"),
    name="kldiv_full")

kldiv = lbann.Reduction(kldiv_full, name="kldiv", mode="sum")

# Generate sample
noise = lbann.Gaussian(name="noise", mean=0, stdev=1, hint_layer=mu)
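
The snippet ends at the noise layer; assuming the standard reparameterization z = mu + sd * noise, a plausible continuation (layer names here are hypothetical) would be:

# Assumed continuation: combine mu, sd, and noise into a latent sample
sdnoise = lbann.Multiply([sd, noise], name="sdnoise")
sample = lbann.Add([mu, sdnoise], name="sample")
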
Example #5
    def compute_loss(self, x, y):

        # y[:, :-1]
        y = lbann.Slice(
            y,
            axis=0,
            slice_points=str_list([0, self.input_feature_dims-1]),
        )
        y = lbann.Identity(y)

        # x[:, 1:]
        x = lbann.Slice(
            x,
            slice_points=str_list([1, self.input_feature_dims]),
        )
        x = lbann.Identity(x)

        # Figure out entries in x to ignore
        ignore_mask = lbann.Equal(
            x,
            self.constant(self.label_to_ignore, hint_layer=x),
        )
        keep_mask = lbann.LogicalNot(ignore_mask)
        length = lbann.Reduction(keep_mask, mode='sum')
        length = lbann.Max(length, self.constant(1, [1]))

        # Convert entries in x to indices in y
        # Note: Ignored entries correspond to an index of -1.
        offsets = [
            row*self.dictionary_size
            for row in range(self.input_feature_dims-1)
        ]
        offsets = lbann.Weights(
            initializer=lbann.ValueInitializer(values=str_list(offsets)),
            optimizer=lbann.NoOptimizer(),
        )
        offsets = lbann.WeightsLayer(
            dims=str_list([self.input_feature_dims-1]),
            weights=offsets,
        )
        y_inds = lbann.Add(x, offsets)
        y_inds = lbann.Add(
            lbann.Multiply(keep_mask, y_inds),
            lbann.Multiply(
                ignore_mask,
                self.constant(-1, hint_layer=y_inds),
            ),
        )

        # recon_loss = F.cross_entropy(
        #     y[:, :-1].contiguous().view(-1, y.size(-1)),
        #     x[:, 1:].contiguous().view(-1),
        #     ignore_index=self.pad
        # )

        # Shift y for numerical stability
        # Note: We'd prefer to shift by y.max(-1)
        shifts = lbann.MatMul(
            lbann.Max(y, self.constant(0, hint_layer=y)),
            self.constant(
                1 / math.sqrt(self.dictionary_size),
                [self.dictionary_size, self.dictionary_size],
            ),
        )
        y = lbann.Subtract(y, shifts)

        # Compute log of softmax denominator and sum
        z = lbann.MatMul(
            lbann.Exp(y),
            self.constant(1, [self.dictionary_size, 1]),
        )
        z = lbann.Log(z)
        z = lbann.MatMul(
            lbann.Reshape(keep_mask, dims=str_list([1, -1])),
            z,
        )
        z = lbann.Reshape(z, dims=str_list([1]))

        # Compute cross entropy
        recon_loss = lbann.Gather(
            lbann.Reshape(y, dims=str_list([-1])),
            y_inds,
        )
        recon_loss = lbann.Reduction(recon_loss, mode='sum')
        recon_loss = lbann.Subtract(z, recon_loss)
        recon_loss = lbann.Divide(recon_loss, length)

        return recon_loss
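
A minimal NumPy sketch of the offset-and-gather trick used above, assuming y has shape (input_feature_dims - 1, dictionary_size) and x holds per-row target indices (names are illustrative):

import numpy as np

def gather_target_logits(y, x):
    offsets = np.arange(y.shape[0]) * y.shape[1]   # row * dictionary_size, as above
    # y[row, x[row]] for each row; mask ignored entries (index -1) separately
    return y.reshape(-1)[x + offsets]
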