Exemplo n.º 1
0
    def forward(self, x, z):
        """Run the WAE forward step and build the discriminator branches.

        :param x: input layer; it is sliced at points
            ``[0, self.input_feature_dims]`` and the result is used as the
            indices for the embedding lookup (presumably token ids --
            confirm against the data reader)
        :param z: layer that is reshaped to ``(1, self.zdim)`` and tiled to
            ``(self.input_feature_dims, self.zdim)`` for the "real"
            discriminator input (presumably a sample from the prior --
            confirm against the caller)
        :return: tuple ``(recon_loss, d_real, d_fake, d_adv, arg_max)``
        """

        # Keep only the leading part of the sample and embed it.
        token_slice = lbann.Slice(
            x, slice_points=str_list([0, self.input_feature_dims]))
        tokens = lbann.Identity(token_slice)
        x_emb = lbann.Embedding(
            tokens,
            num_embeddings=self.dictionary_size,
            embedding_dim=self.embedding_size,
            name='emb',
            weights=self.emb_weights,
        )

        # Encoder: x -> z
        encoded = self.forward_encoder(x_emb)

        # Add Gaussian noise shaped like the encoder output.
        noise = lbann.Gaussian(
            mean=self.gmean,
            stdev=self.gstd,
            hint_layer=encoded,
        )
        z_sample = lbann.Add([encoded, noise])

        # Decoder: x, z -> prediction, from which the reconstruction loss
        # is computed against the sliced input.
        pred, arg_max = self.forward_decoder(x_emb, z_sample)
        recon_loss = self.compute_loss(tokens, pred)

        # Hack to remove blocking GPU allreduce in evaluation layer
        recon_loss = lbann.Identity(recon_loss, device='CPU')

        # Both latent inputs are tiled the same way before being paired with
        # the embeddings.
        row_dims = str_list([1, self.zdim])
        tiled_dims = str_list([self.input_feature_dims, self.zdim])

        # "Real" pair: embeddings together with the tiled z argument.
        z_row = lbann.Reshape(z, dims=row_dims)
        z_prior = lbann.Tessellate(z_row, dims=tiled_dims)
        real_pair = lbann.Concatenation([x_emb, z_prior], axis=1)
        d_real = self.discriminator0(real_pair)

        # "Fake" pair: embeddings together with the tiled encoder sample.
        sample_row = lbann.Reshape(z_sample, dims=row_dims)
        z_sample0 = lbann.Tessellate(sample_row, dims=tiled_dims)
        y_z_sample = lbann.Concatenation([x_emb, z_sample0], axis=1)

        # discriminator0 receives the fake pair through StopGradient, while
        # discriminator1 receives it directly. The original code tagged the
        # discriminator1 call with "freeze"; presumably discriminator1 is a
        # frozen copy of discriminator0 -- confirm in the class constructor.
        detached_pair = lbann.StopGradient(y_z_sample)
        d_fake = self.discriminator0(detached_pair)
        d_adv = self.discriminator1(y_z_sample)

        return recon_loss, d_real, d_fake, d_adv, arg_max
Exemplo n.º 2
0
Arquivo: vae.py Projeto: oyamay/lbann
    def forward_decoder(self, x_emb, z):
        """Decoder step, emulating x ~ G(z).

        :param x_emb: embeddings of the input sequence (noted by the
            original author as (n_batch, len(x), d_z) floats)
        :param z: latent vector (noted by the original author as
            (n_batch, d_z) floats); it is reshaped to ``(1, 128)`` here
        :return: output of the channel-wise fully-connected layer with
            ``self.dictionary_size`` output channels
        """

        # Tile z and concatenate it with the embeddings.
        # PyTorch equivalent:
        #   z_0 = z.unsqueeze(1).repeat(1, x_emb.size(1), 1)
        #   x_input = torch.cat([x_emb, z_0], dim=-1)
        latent_row = lbann.Reshape(z, dims=str_list([1, 128]))
        z_0 = lbann.Tessellate(
            latent_row,
            dims=str_list([self.input_feature_dims, 128]),
        )
        x_input = lbann.Concatenation(x_emb, z_0, axis=1)

        # Initial hidden state: project z, then tile it to (3, 512).
        # PyTorch equivalent:
        #   h_0 = h_0.unsqueeze(0).repeat(self.decoder_rnn.num_layers, 1, 1)
        hidden = self.decoder_lat(z)
        hidden = lbann.Reshape(hidden, dims=str_list([1, 512]))
        hidden = lbann.Tessellate(hidden, dims=str_list((3, 512)))

        rnn_out = self.decoder_rnn(x_input, hidden)

        # Output projection, reusing the name and weights of self.decoder_fc.
        y = lbann.ChannelwiseFullyConnected(
            rnn_out,
            output_channel_dims=self.dictionary_size,
            bias=True,
            name=f'{self.decoder_fc.name}',
            weights=self.decoder_fc.weights,
        )

        # Set datatype of layers
        # Note: Depth-first search from y towards x_emb and z. The two inputs
        # themselves are never visited, and Slice/Reshape/Tessellate layers
        # are traversed but keep their datatype.
        untouched_types = (lbann.Slice, lbann.Reshape, lbann.Tessellate)
        boundary = (x_emb, z)
        pending = [y]
        seen = {y}
        while pending:
            layer = pending.pop()
            if type(layer) not in untouched_types:
                layer.datatype = self.datatype
            for parent in layer.parents:
                if parent in seen or parent in boundary:
                    continue
                pending.append(parent)
                seen.add(parent)

        return y
Exemplo n.º 3
0
def random_projection(indices, num_projections, projection_dim):
    """Build layers that turn indices into hash-derived projection vectors.

    Each index is expanded into ``projection_dim`` distinct hash inputs,
    hashed with ``lbann.UniformHash``, passed through ``lbann.ErfInv``,
    instance-normalized and finally scaled by ``1 / projection_dim``.

    :param indices: layer holding the indices; it is reshaped to
        ``(num_projections, 1)``, so it is assumed to have
        ``num_projections`` entries
    :param num_projections: number of rows in the result
    :param projection_dim: number of columns in the result
    :return: layer with dims ``(num_projections, projection_dim)``
    """

    grid_dims = utils.str_list([num_projections, projection_dim])

    # Expand input indices to get an index for each vector entry
    # Note: proj_indices(i) = index*projection_dim + i
    scaled_indices = lbann.WeightedSum(
        indices,
        scaling_factors=utils.str_list(projection_dim),
    )
    iota_weights = lbann.Weights(
        initializer=lbann.ValueInitializer(
            values=utils.str_list(range(projection_dim))),
        optimizer=lbann.NoOptimizer(),
    )
    iota = lbann.WeightsLayer(
        dims=utils.str_list(projection_dim),
        weights=iota_weights,
    )
    index_column = lbann.Reshape(
        scaled_indices,
        dims=utils.str_list([num_projections, 1]),
    )
    index_grid = lbann.Tessellate(index_column, dims=grid_dims)
    iota_row = lbann.Reshape(iota, dims=utils.str_list([1, projection_dim]))
    iota_grid = lbann.Tessellate(iota_row, dims=grid_dims)
    proj_indices = lbann.Sum(index_grid, iota_grid)

    # Apply hash function and convert to Gaussian distribution
    hashed = lbann.UniformHash(proj_indices)
    ones = lbann.Constant(value=1, num_neurons=grid_dims)
    # Computes 2*(1-eps)*hash - (1-eps). Presumably the hash lies in [0, 1],
    # in which case eps keeps the ErfInv argument strictly inside (-1, 1)
    # -- confirm against the UniformHash documentation.
    eps = 0.001
    half_width = 1 - eps
    centered = lbann.WeightedSum(
        hashed,
        ones,
        scaling_factors=utils.str_list([2 * half_width, -half_width]),
    )
    gaussian = lbann.ErfInv(centered)
    normalized = lbann.InstanceNorm(gaussian)
    return lbann.WeightedSum(
        normalized,
        scaling_factors=utils.str_list(1 / projection_dim),
    )
Exemplo n.º 4
0
 def create_position_ids_from_inputs_embeds(self, input_embeds):
     """Build a constant layer of sequential position ids.

     The ids run from ``self.padding_idx + 1`` up to
     ``self.padding_idx + sequence_length`` inclusive, where
     ``sequence_length`` is ``self.input_shape[1]``. They are stored in
     non-trainable weights and tiled to ``self.input_shape[:-1]``.

     :param input_embeds: not used by this method
     :return: layer holding the tiled position ids
     """
     seq_len = self.input_shape[1]
     first_id = self.padding_idx + 1
     id_values = range(first_id, seq_len + first_id)

     # NoOptimizer keeps the ids fixed during training.
     id_weights = lbann.Weights(
         initializer=lbann.ValueInitializer(values=str_list(id_values)),
         optimizer=lbann.NoOptimizer(),
     )
     flat_ids = lbann.WeightsLayer(
         weights=id_weights,
         dims=str_list([seq_len]),
     )
     id_row = lbann.Reshape(flat_ids, dims=str_list([1, seq_len]))
     return lbann.Tessellate(id_row, dims=str_list(self.input_shape[:-1]))
Exemplo n.º 5
0
def make_model(
    num_epochs,
    embed_dim,
    num_heads,
    label_smoothing,
):
    """Construct an LBANN transformer model with a cross-entropy objective.

    Relies on the module-level names ``vocab_size``, ``sequence_length`` and
    ``pad_index``. Per the original notes, each sample is two sequences of
    token IDs; the decoder input is shifted right, so it is one token
    shorter than the encoder input.

    :param num_epochs: passed as the first argument to ``lbann.Model``
    :param embed_dim: embedding dimension and transformer hidden size
    :param num_heads: number of attention heads for the transformer
    :param label_smoothing: weight given to a uniform distribution when
        mixing it with the one-hot labels; no mixing is done unless it is
        greater than zero
    :return: the constructed ``lbann.Model``
    """

    # The decoder works on one token fewer than the encoder.
    decoder_length = sequence_length - 1
    embedded_length = sequence_length + decoder_length

    # Embedding weights
    glorot_std = math.sqrt(2 / (embed_dim + vocab_size))  # Glorot init
    embedding_weights = lbann.Weights(
        name='embeddings',
        initializer=lbann.NormalInitializer(standard_deviation=glorot_std),
    )

    # Input is two sequences of token IDs
    input_ = lbann.Input(data_field='samples')

    # Get sequences of embedding vectors
    # Note: Scale embeddings by sqrt(embed_dim).
    # Note: Decoder input is shifted right, so embedding for last
    # token isn't needed.
    token_slice = lbann.Slice(
        input_,
        axis=0,
        slice_points=str_list([0, embedded_length]),
    )
    embeddings_tokens = lbann.Identity(token_slice)
    raw_embeddings = lbann.Embedding(
        embeddings_tokens,
        weights=embedding_weights,
        num_embeddings=vocab_size,
        embedding_dim=embed_dim,
        padding_idx=pad_index,
    )
    scaled_embeddings = lbann.WeightedSum(
        raw_embeddings,
        scaling_factors=str(math.sqrt(embed_dim)),
    )
    embeddings_slice = lbann.Slice(
        scaled_embeddings,
        axis=0,
        slice_points=str_list([0, sequence_length, embedded_length]),
    )
    # The two Identity children are created in this order on purpose: the
    # first picks up the encoder slice, the second the decoder slice.
    encoder_input = lbann.Identity(embeddings_slice)
    decoder_input = lbann.Identity(embeddings_slice)

    # Apply transformer model
    transformer = lbann.models.Transformer(
        hidden_size=embed_dim,
        num_heads=num_heads,
        name='transformer',
    )
    result = transformer(
        encoder_input,
        sequence_length,
        decoder_input,
        decoder_length,
    )

    # Reconstruct decoder input
    # Note: the output projection reuses the embedding weights, transposed.
    logits = lbann.ChannelwiseFullyConnected(
        result,
        weights=embedding_weights,
        output_channel_dims=[vocab_size],
        bias=False,
        transpose=True,
    )
    probs = lbann.ChannelwiseSoftmax(logits)
    probs_slice = lbann.Slice(
        probs,
        axis=0,
        slice_points=str_list(range(sequence_length)),
    )
    preds = [lbann.Identity(probs_slice) for _ in range(decoder_length)]

    # Count number of non-pad tokens
    label_slice = lbann.Slice(
        input_,
        slice_points=str_list([sequence_length + 1, 2 * sequence_length]),
    )
    label_sequence = lbann.Identity(label_slice)
    pads = lbann.Constant(value=pad_index, num_neurons=str(decoder_length))
    is_not_pad = lbann.NotEqual(label_sequence, pads)
    num_not_pad = lbann.Reduction(is_not_pad, mode='sum')

    # Cross entropy loss with label smoothing
    per_token_slice = lbann.Slice(
        label_sequence,
        slice_points=str_list(range(sequence_length)),
    )
    label_tokens = [
        lbann.Identity(per_token_slice) for _ in range(decoder_length)
    ]
    use_smoothing = label_smoothing > 0
    if use_smoothing:
        uniform_label = lbann.Constant(
            value=1 / vocab_size,
            num_neurons=str_list([1, vocab_size]),
        )
    token_losses = []
    for pred, token in zip(preds, label_tokens):
        one_hot = lbann.OneHot(token, size=vocab_size)
        label = lbann.Reshape(one_hot, dims=str_list([1, vocab_size]))
        if use_smoothing:
            label = lbann.WeightedSum(
                label,
                uniform_label,
                scaling_factors=str_list(
                    [1 - label_smoothing, label_smoothing]),
            )
        token_losses.append(lbann.CrossEntropy(pred, label))
    loss = lbann.Concatenation(token_losses)

    # Average cross entropy over non-pad tokens
    tiled_count = lbann.Tessellate(num_not_pad, hint_layer=is_not_pad)
    loss_scales = lbann.Divide(is_not_pad, tiled_count)
    loss = lbann.Multiply(loss, loss_scales)
    loss = lbann.Reduction(loss, mode='sum')

    # Construct model
    metrics = []
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
    return lbann.Model(
        num_epochs,
        layers=lbann.traverse_layer_graph(input_),
        objective_function=loss,
        metrics=metrics,
        callbacks=callbacks,
    )