Ejemplo n.º 1
0
    def forward(self, x):
        """Apply one LSTM step to `x`.

        The output and cell state stored by the previous step
        (`self.last_output` and `self.last_cell`) are consumed here and
        replaced with the newly built layers before returning.

        Args:
            x: Input layer for this step.

        Returns:
            The output layer of this step (also kept as
            `self.last_output`).

        """
        self.step += 1
        prefix = f'{self.name}_step{self.step}'
        layout = self.data_layout

        # Apply linearity to the input joined with the previous output
        fc_input = lbann.Concatenation([x, self.last_output],
                                       name=prefix + '_input',
                                       data_layout=layout)
        fc_out = self.fc(fc_input)

        # Get gates and cell update
        # NOTE(review): consumers of a Slice presumably receive its
        # pieces in the order they are created, so the statement order
        # below is kept exactly as it was -- confirm before reordering.
        fc_split = lbann.Slice(
            fc_out,
            slice_points=_str_list([0, self.size, 4*self.size]),
            name=prefix + '_fc_slice',
            data_layout=layout,
        )
        cell_update = lbann.Tanh(fc_split,
                                 name=prefix + '_cell_update',
                                 data_layout=layout)
        gates = lbann.Sigmoid(fc_split,
                              name=prefix + '_sigmoid',
                              data_layout=layout)
        gate_split = lbann.Slice(
            gates,
            slice_points=_str_list(
                [0, self.size, 2*self.size, 3*self.size]),
            name=prefix + '_sigmoid_slice',
            data_layout=layout,
        )
        forget_gate, input_gate, output_gate = (
            lbann.Identity(gate_split,
                           name=prefix + suffix,
                           data_layout=layout)
            for suffix in ('_forget_gate', '_input_gate', '_output_gate')
        )

        # Cell state: forget part of the old state, add the gated update
        kept_cell = lbann.Multiply([forget_gate, self.last_cell],
                                   name=prefix + '_cell_forget',
                                   data_layout=layout)
        new_input = lbann.Multiply([input_gate, cell_update],
                                   name=prefix + '_cell_input',
                                   data_layout=layout)
        cell = lbann.Add([kept_cell, new_input],
                         name=prefix + '_cell',
                         data_layout=layout)

        # Output: gated activation of the new cell state
        cell_activation = lbann.Tanh(cell,
                                     name=prefix + '_cell_activation',
                                     data_layout=layout)
        output = lbann.Multiply([output_gate, cell_activation],
                                name=prefix,
                                data_layout=layout)

        # Update state and return output
        self.last_cell = cell
        self.last_output = output
        return output
Ejemplo n.º 2
0
    def forward(self, x):
        """Run the VAE forward pass and build both loss terms.

        :param x: input layer; the range [0, self.input_feature_dims) is
            sliced off and looked up in the embedding table
        :return: layer holding the KL term of the loss
        :return: layer holding the reconstruction term of the loss
        """

        token_slice = lbann.Slice(
            x,
            slice_points=str_list([0, self.input_feature_dims]),
        )
        tokens = lbann.Identity(token_slice)
        embedded = lbann.Embedding(
            tokens,
            num_embeddings=self.dictionary_size,
            embedding_dim=self.embedding_size,
            name='emb',
            weights=self.emb_weights,
        )

        # Encoder: embeddings -> latent sample and KL term
        latent, kl_term = self.forward_encoder(embedded)

        # Decoder: embeddings + latent sample -> prediction, scored
        # against the raw (un-embedded) tokens
        prediction = self.forward_decoder(embedded, latent)
        recon_term = self.compute_loss(tokens, prediction)

        # Hack to remove blocking GPU allreduce in evaluation layer
        kl_on_cpu = lbann.Identity(kl_term, device='CPU')
        recon_on_cpu = lbann.Identity(recon_term, device='CPU')

        return kl_on_cpu, recon_on_cpu
Ejemplo n.º 3
0
def make_model(
        motif_size,
        walk_length,
        num_vertices,
        embed_dim,
        learn_rate,
        num_epochs,
        embeddings_dir,
):
    """Build the LBANN model wrapped around `model.gan.CommunityGAN`.

    Each data sample is sliced into `motif_size` entries followed by
    `walk_length` entries, which are handed to the GAN as motif and walk
    indices respectively.

    Args:
        motif_size: Number of entries in the motif part of a sample.
        walk_length: Number of entries in the walk part of a sample.
        num_vertices: Forwarded to `CommunityGAN`.
        embed_dim: Forwarded to `CommunityGAN`.
        learn_rate: Forwarded to `CommunityGAN`.
        num_epochs: Number of epochs for the model; also used as the
            epoch interval of the weight-dumping callback.
        embeddings_dir: Directory passed to the weight-dumping callback.

    Returns:
        The constructed `lbann.Model`.
    """

    # Layer graph
    sample = lbann.Input(data_field='samples')
    sample_slice = lbann.Slice(
        sample,
        slice_points=str_list([0, motif_size, motif_size+walk_length]),
    )
    motif_indices = lbann.Identity(sample_slice)
    walk_indices = lbann.Identity(sample_slice)
    community_gan = model.gan.CommunityGAN(
        num_vertices,
        motif_size,
        embed_dim,
        learn_rate,
    )
    loss, real_disc_prob, fake_disc_prob, gen_prob = community_gan(
        motif_indices,
        motif_size,
        walk_indices,
        walk_length,
    )

    # Metrics
    reported = (
        (real_disc_prob, 'D(real)'),
        (fake_disc_prob, 'D(fake)'),
        (gen_prob, 'G'),
    )
    metrics = [lbann.Metric(layer, name=label) for layer, label in reported]

    # Callbacks
    dump_weights = lbann.CallbackDumpWeights(directory=embeddings_dir,
                                             epoch_interval=num_epochs)
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer(), dump_weights]

    # Perform computation at double precision
    # No layers are created after this point, so the traversal result is
    # reused for the model below.
    layers = list(lbann.traverse_layer_graph(sample_slice))
    for layer in layers:
        layer.datatype = lbann.DataType.DOUBLE
        for weight in layer.weights:
            weight.datatype = lbann.DataType.DOUBLE

    # Construct model
    return lbann.Model(
        num_epochs,
        layers=layers,
        objective_function=loss,
        metrics=metrics,
        callbacks=callbacks,
    )
Ejemplo n.º 4
0
    def forward_encoder(self, x_emb):
        """Encoder step, emulating z ~ E(x) = q_E(z|x)

        Runs the encoder RNN on `x_emb`, keeps the slice
        [input_feature_dims - 1, input_feature_dims) of its output along
        axis 0, and derives `mu` and `logvar` from it.

        :param x_emb: layer with the embeddings of the input sentence x
        :return: layer with a sample of the latent vector z
        :return: layer with the KL term component of the loss
        """

        # _, h = self.encoder_rnn(x, None)
        rnn_out = self.encoder_rnn(x_emb, None)

        last_step = lbann.Slice(
            rnn_out,
            slice_points=str_list(
                [self.input_feature_dims - 1, self.input_feature_dims]),
            axis=0,
        )
        h = lbann.Identity(last_step)

        mu = self.q_mu(h)
        logvar = self.q_logvar(h)

        # Set datatype of previous layers
        # Note: Depth-first search from mu and logvar to x_emb
        untouched_types = (lbann.Slice, lbann.Reshape, lbann.Tessellate)
        pending = [mu, logvar]
        seen = set(pending)
        while pending:
            layer = pending.pop()
            if type(layer) not in untouched_types:
                layer.datatype = self.datatype
            for parent in layer.parents:
                if parent in seen or parent is x_emb:
                    continue
                pending.append(parent)
                seen.add(parent)

        # eps = torch.randn_like(mu)
        eps = lbann.Gaussian(mean=0, stdev=1, hint_layer=mu)

        # z = mu + (logvar / 2).exp() * eps
        half_logvar = lbann.WeightedSum(logvar, scaling_factors='0.5')
        std = lbann.Exp(half_logvar)
        noise = lbann.Multiply([std, eps])
        z = lbann.Add([mu, noise])

        # kl_loss = 0.5 * (logvar.exp() + mu ** 2 - 1 - logvar).sum(1).mean()
        exp_logvar = lbann.Exp(logvar)
        mu_squared = lbann.Square(mu)
        ones = self.constant(1, hint_layer=mu)
        kl_terms = lbann.WeightedSum(
            exp_logvar,
            mu_squared,
            ones,
            logvar,
            scaling_factors='0.5 0.5 -0.5 -0.5',
        )
        kl_loss = lbann.Reduction(kl_terms, mode='sum')

        return z, kl_loss
Ejemplo n.º 5
0
 def forward(self, _):
     """Sum the L2 norms of the slices of the weights along axis 0.

     The weights are exposed as a layer with dims
     `(self.width, self.height)`, cut into `self.width` slices along
     axis 0, and the square root of each slice's squared L2 norm is
     summed.

     Args:
         _: Ignored; the result depends only on `self.weights`.

     Returns:
         An `lbann.Sum` layer over the per-slice norms.
     """
     # '%d %d'.format(...) never substitutes anything, so the dims string
     # must be built with the % operator.
     w = lbann.WeightsLayer(weights=self.weights,
                            dims='%d %d' % (self.width, self.height))
     # str.join only accepts strings, so the integer slice points are
     # converted explicitly.
     slice_points = ' '.join(str(i) for i in range(self.width + 1))
     w_slice = lbann.Slice(w, axis=0, slice_points=slice_points)
     # One consumer per slice; each new child of the Slice layer is
     # expected to receive the next slice.
     cols = [lbann.Sqrt(lbann.L2Norm2(w_slice)) for _i in range(self.width)]
     return lbann.Sum(cols)
Ejemplo n.º 6
0
def construct_model():
    """Build a small convolutional classifier as an LBANN model.

    The input sample is sliced along axis 0 into `dims - 1` data entries
    followed by one label entry. The data goes through two convolutions
    and three fully-connected modules; the label is one-hot encoded and
    compared with the softmax prediction. `dims`, `num_classes` and
    `str_list` are taken from the enclosing module.

    Returns:
        An `lbann.Model` trained for 10 epochs with a cross-entropy
        objective and a categorical-accuracy metric.
    """
    import lbann
    import lbann.modules

    FcModule = lbann.modules.FullyConnectedModule
    ConvModule = lbann.modules.Convolution2dModule

    conv1 = ConvModule(20, 3, stride=1, padding=1, name='conv1')
    conv2 = ConvModule(20, 3, stride=1, padding=1, name='conv2')
    fc1 = FcModule(100, name='fc1')
    fc2 = FcModule(20, name='fc2')
    fc3 = FcModule(num_classes, name='fc3')

    # Layer graph: split each sample into data and label
    data_in = lbann.Input(name='inp_tensor', target_mode='classification')
    sample_slice = lbann.Slice(
        data_in,
        axis=0,
        slice_points=str_list([0, dims - 1, dims]),
        name='inp_slice',
    )
    features = lbann.Identity(sample_slice)
    label = lbann.Identity(sample_slice, name='gt_y')

    # NHWC to NCHW (note carried over from the original author)
    image = lbann.Reshape(features, dims='14 13 13')
    conv_out = conv2(conv1(image))
    flat = lbann.Reshape(conv_out, dims='3380')
    hidden1 = lbann.Dropout(lbann.Relu(fc1(flat)), keep_prob=0.5)
    hidden2 = lbann.Dropout(fc2(hidden1), keep_prob=0.5)
    pred = lbann.Softmax(fc3(hidden2))
    gt_label = lbann.OneHot(label, size=num_classes)
    loss = lbann.CrossEntropy([pred, gt_label], name='loss')
    acc = lbann.CategoricalAccuracy([pred, gt_label])

    layers = list(lbann.traverse_layer_graph(data_in))

    # Setup objective function
    weights = {w for layer in layers for w in layer.weights}
    obj = lbann.ObjectiveFunction(loss)

    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
    metrics = [lbann.Metric(acc, name='accuracy', unit='%')]

    # Construct model
    num_epochs = 10
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       metrics=metrics,
                       objective_function=obj,
                       callbacks=callbacks)
Ejemplo n.º 7
0
    def forward(self, x, z):
        """Do the WAE forward step

        :param x: input layer; the range [0, self.input_feature_dims) is
            sliced off and looked up in the embedding table
        :param z: layer reshaped to (1, self.zdim) and tiled to
            (self.input_feature_dims, self.zdim) before being paired with
            the embeddings as the "real" discriminator input
        :return: reconstruction loss layer (moved to CPU)
        :return: output of discriminator0 on the embeddings paired with z
        :return: output of discriminator0 on the gradient-stopped
            embeddings paired with the encoder sample
        :return: output of discriminator1 on the embeddings paired with
            the encoder sample
        :return: second value returned by the decoder (`arg_max`)
        """

        def pair_with_embeddings(latent):
            # Tile the latent vector over the sequence and join it with
            # the embeddings along axis 1.
            tiled = lbann.Tessellate(
                lbann.Reshape(latent, dims=str_list([1, self.zdim])),
                dims=str_list([self.input_feature_dims, self.zdim]),
            )
            return lbann.Concatenation([x_emb, tiled], axis=1)

        token_slice = lbann.Slice(
            x,
            slice_points=str_list([0, self.input_feature_dims]),
        )
        tokens = lbann.Identity(token_slice)
        x_emb = lbann.Embedding(tokens,
                                num_embeddings=self.dictionary_size,
                                embedding_dim=self.embedding_size,
                                name='emb',
                                weights=self.emb_weights)

        # Encoder: embeddings -> latent sample, perturbed with Gaussian noise
        encoded = self.forward_encoder(x_emb)
        noise = lbann.Gaussian(mean=self.gmean,
                               stdev=self.gstd,
                               hint_layer=encoded)
        z_sample = lbann.Add([encoded, noise])

        # Decoder: embeddings + latent sample -> reconstruction loss
        pred, arg_max = self.forward_decoder(x_emb, z_sample)
        recon_raw = self.compute_loss(tokens, pred)

        # Hack to remove blocking GPU allreduce in evaluation layer
        recon_loss = lbann.Identity(recon_raw, device='CPU')

        # Discriminator on the "real" pair (embeddings, z)
        d_real = self.discriminator0(pair_with_embeddings(z))

        # Discriminators on the "fake" pair (embeddings, encoder sample)
        y_z_sample = pair_with_embeddings(z_sample)
        d_fake = self.discriminator0(lbann.StopGradient(y_z_sample))
        d_adv = self.discriminator1(y_z_sample)  # freeze

        return recon_loss, d_real, d_fake, d_adv, arg_max
Ejemplo n.º 8
0
    def forward(
        self,
        motif_indices,
        motif_size,
        walk_indices,
        walk_length,
    ):
        """Build the combined GAN loss for one batch of motifs and walks.

        Args:
            motif_indices: Layer with the real motif indices.
            motif_size: Number of entries per motif.
            walk_indices: Layer with the walk indices fed to the generator.
            walk_length: Number of entries per walk.

        Returns:
            Tuple `(loss, real_disc_prob, fake_disc_prob, gen_prob)`,
            where `loss` is the sum of the discriminator and generator
            loss layers.
        """

        # Apply generator
        fake_motif_indices, gen_prob, gen_log_prob = self.generator(
            walk_length,
            walk_indices,
            motif_size,
        )

        # Get discriminator embeddings in log-space
        # Real and fake motifs are looked up together and split afterwards.
        stacked_indices = lbann.Concatenation(motif_indices,
                                              fake_motif_indices)
        stacked_log_emb = self.discriminator.get_log_embeddings(
            stacked_indices)
        log_emb_slice = lbann.Slice(
            stacked_log_emb,
            slice_points=str_list([0, motif_size, 2 * motif_size]),
        )
        real_log_emb = lbann.Identity(log_emb_slice)
        fake_log_emb = lbann.Identity(log_emb_slice)

        # Apply discriminator
        real_disc_prob, _ = self.discriminator(motif_size, real_log_emb)
        fake_disc_prob, fake_disc_log_not_prob = self.discriminator(
            motif_size, fake_log_emb)

        # Loss function
        # L_disc = - log(D(real)) - log(1-D(fake))
        # L_gen = - log(G) * stop_gradient(log(1-D(fake)))
        clamped_real_prob = lbann.Clamp(real_disc_prob, min=1e-37, max=1)
        real_disc_log_prob = lbann.Log(clamped_real_prob)
        disc_loss = lbann.WeightedSum(
            real_disc_log_prob,
            fake_disc_log_not_prob,
            scaling_factors=str_list([-1, -1]),
        )
        detached_log_not_prob = lbann.StopGradient(fake_disc_log_not_prob)
        gen_loss = lbann.Multiply(gen_log_prob, detached_log_not_prob)
        loss = lbann.Add(disc_loss, gen_loss)

        return loss, real_disc_prob, fake_disc_prob, gen_prob
Ejemplo n.º 9
0
 def forward(self, x):
     """Predict a single response value from a concatenated input sample.

     The sample is sliced along axis 0 at points 0, 921, 4750 and 8579.
     The three pieces are consumed, in that creation order, by the gene,
     drug1 and drug2 tracks; their outputs are concatenated, passed
     through `self.concatT`, and reduced to one neuron.

     Args:
         x: Input layer holding the full sample.

     Returns:
         The final `lbann.FullyConnected` layer (one neuron, with bias).
     """
     pieces = lbann.Slice(x,
                          axis=0,
                          slice_points="0 921 4750 8579",
                          name='inp_slice')
     # Each track gets its own Identity child of the slice; the order of
     # the tuple below fixes which piece each track reads.
     track_outputs = [
         track(lbann.Identity(pieces))
         for track in (self.geneT, self.drug1T, self.drug2T)
     ]
     merged = lbann.Concatenation(track_outputs, name=self.name + 'concat')
     combined = self.concatT(merged)
     return lbann.FullyConnected(combined, num_neurons=1, has_bias=True)
Ejemplo n.º 10
0
    def matrix_to_graph(cls, mat_layer, num_vertices, num_features):
        """Split a (num_vertices, num_features) matrix layer into per-vertex layers.

        The matrix is flattened, sliced every `num_features` entries, and
        each piece is reshaped to (1, num_features).

        Args:
            mat_layer: Layer holding the matrix.
            num_vertices: Number of rows / vertices.
            num_features: Number of features per vertex.

        Returns:
            `cls(list_of_layers, num_features)`, built from one layer per
            vertex (described by the original author as a GraphVertexData
            object).
        """
        total_size = num_vertices * num_features
        boundaries = str_list(range(0, total_size + 1, num_features))
        flat = lbann.Reshape(mat_layer, dims=str(total_size))
        rows = lbann.Slice(flat, axis=0, slice_points=boundaries)

        row_dims = str_list([1, num_features])
        vertex_layers = [
            lbann.Reshape(lbann.Identity(rows), dims=row_dims)
            for _ in range(num_vertices)
        ]
        return cls(vertex_layers, num_features)
Ejemplo n.º 11
0
 def encoder_cnn(self, y):
     """Encode `y` with the convolutional encoder.

     `y` is sliced along axis 0 at points 0, 16384 and 16399. The slice
     layer is first consumed by a reshape to 4 x 64 x 64 that feeds three
     convolution modules, and then consumed again when it is concatenated
     with the flattened convolution output.

     Args:
         y: Input layer.

     Returns:
         The output of `self.enc_out` on the concatenated features.
     """
     y_slice = lbann.Slice(y,
                           axis=0,
                           slice_points="0 16384 16399",
                           name=self.name + '_y_slice')
     # assume C first, is data C first? (open question from the original)
     image = lbann.Reshape(y_slice,
                           dims='4 64 64',
                           name=self.name + 'enc_reshape0')
     conv_out = self.enc_conv[0](image)
     conv_out = self.enc_conv[1](conv_out)
     conv_out = self.enc_conv[2](conv_out)
     flat = lbann.Reshape(conv_out,
                          dims=str(16 * 8 * 8),
                          name=self.name + 'enc_reshape1')
     # NOTE(review): this second use of the slice presumably picks up its
     # second piece (16384..16399) -- confirm against Slice semantics.
     stacked = lbann.Concatenation([flat, y_slice], axis=0)
     return self.enc_out(stacked)
Ejemplo n.º 12
0
def Graph_Data_Parser(_lbann_input_,
                      num_nodes,
                      node_feature_size,
                      max_edges,
                      num_classes=1):
    """ A parser for graph structured data with node
        features, source and target node indices (COO)
        format, and a target

    The input is sliced into four consecutive pieces of sizes
    num_nodes * node_feature_size, max_edges, max_edges and num_classes.

    Args:
        _lbann_input_ (Layer): The input layer of the LBANN model
        num_nodes (int): The maximum number of nodes in the dataset
        node_feature_size (int): The dimensionality of the node features matrix
        max_edges (int): The maximum number of edges in the dataset
        num_classes (int): The number of classes in the target or 1 for
                           regression (default : 1)
    Returns:
        (dictionary) Returns a dictionary with the keys: node_features,
                     source_indices, target_indices, and target
    """
    piece_sizes = (
        num_nodes * node_feature_size,
        max_edges,
        max_edges,
        num_classes,
    )
    # Running totals turn the piece sizes into absolute slice points.
    boundaries = list(accumulate((0,) + piece_sizes))
    pieces = lbann.Slice(_lbann_input_,
                         slice_points=str_list(boundaries),
                         name="Sliced_Graph_Input")

    # The children below are created in the same order as the pieces.
    flat_features = lbann.Identity(pieces)
    node_features = lbann.Reshape(
        flat_features,
        dims=str_list([num_nodes, node_feature_size]),
        name="Node_Feature_Matrix",
    )
    source_indices = lbann.Identity(pieces)
    target_indices = lbann.Identity(pieces)
    targets = lbann.Identity(pieces)

    return {
        "node_features": node_features,
        "source_indices": source_indices,
        "target_indices": target_indices,
        "target": targets,
    }
Ejemplo n.º 13
0
 def forward(self, hidden_states):
     """Pool `hidden_states` through a dense layer followed by tanh.

     Args:
         hidden_states: Layer sliced along axis 1 at points 0 and 1.

     Returns:
         The `lbann.Tanh` layer applied to the dense projection.
     """

     def scoped(suffix):
         # Names are namespaced under this module's name.
         return ".".join((self.name, suffix))

     # We "pool" the model by simply taking the hidden state corresponding
     # to the first token.
     first_token = lbann.Slice(hidden_states,
                               axis=1,
                               slice_points=str_list([0, 1]))
     in_shape = (self.input_shape[0], self.input_shape[-1])
     dense_weights = _load_pretrained_weights(
         scoped("dense.weight"),
         scoped("dense.bias"),
         load_weights=self.load_weights,
     )
     dense_out = lbann.modules.PytorchLinear(
         first_token,
         in_shape,
         self.hidden_size,
         weights=dense_weights,
         name=scoped("dense"),
     )
     return lbann.Tanh(dense_out, name=scoped("activation"))
Ejemplo n.º 14
0
    def forward_encoder(self, x_emb):
        """Encoder step, emulating z ~ E(x) = q_E(z|x)

        Runs the encoder RNN on `x_emb`, keeps the slice
        [input_feature_dims - 1, input_feature_dims) of its output along
        axis 0, and maps it through `self.q_mu`.

        :param x_emb: layer with the embeddings of the input sentence x
        :return: layer with the latent vector z (no KL term is produced)
        """

        # _, h = self.encoder_rnn(x, None)
        rnn_out = self.encoder_rnn(x_emb, None)

        last_step = lbann.Slice(
            rnn_out,
            slice_points=str_list(
                [self.input_feature_dims - 1, self.input_feature_dims]),
            axis=0,
        )
        return self.q_mu(lbann.Identity(last_step))
Ejemplo n.º 15
0
def setup(num_patches=3,
          mini_batch_size=512,
          num_epochs=75,
          learning_rate=0.005,
          bn_statistics_group_size=2,
          fc_data_layout='model_parallel',
          warmup=True,
          checkpoint_interval=None):
    """Build the siamese model, its data reader and its optimizer.

    Each data sample is sliced into `num_patches` patches followed by a
    label vector. All patches go through one shared ResNet head; the
    head outputs are concatenated and classified by three
    fully-connected modules.

    Args:
        num_patches: Number of patches per sample.
        mini_batch_size: Only used to scale the warm-up target learning
            rate (`learning_rate * mini_batch_size / 128`).
        num_epochs: Number of training epochs.
        learning_rate: Learning rate of the SGD optimizer.
        bn_statistics_group_size: Forwarded to the ResNet head and the
            FcBnRelu modules.
        fc_data_layout: Data layout of the classification modules.
        warmup: Whether to add the linear-growth learning-rate callback.
        checkpoint_interval: If truthy, a checkpoint callback is added.

    Returns:
        Tuple `(model, data_reader, opt)`.
    """

    # Data dimensions
    patch_dims = patch_generator.patch_dims
    num_labels = patch_generator.num_labels(num_patches)

    # Extract tensors from data sample
    data_in = lbann.Input()
    offsets = [0]
    for _ in range(num_patches):
        offsets.append(
            offsets[-1] + functools.reduce(operator.mul, patch_dims))
    offsets.append(offsets[-1] + num_labels)
    sample = lbann.Slice(data_in, slice_points=str_list(offsets))
    patches = []
    for _ in range(num_patches):
        patches.append(lbann.Reshape(sample, dims=str_list(patch_dims)))
    labels = lbann.Identity(sample)

    # Siamese network: one head shared by every patch
    head_cnn = modules.ResNet(
        bn_statistics_group_size=bn_statistics_group_size)
    heads_concat = lbann.Concatenation([head_cnn(p) for p in patches])

    # Classification network
    class_fc1 = modules.FcBnRelu(
        4096,
        statistics_group_size=bn_statistics_group_size,
        name='siamese_class_fc1',
        data_layout=fc_data_layout)
    class_fc2 = modules.FcBnRelu(
        4096,
        statistics_group_size=bn_statistics_group_size,
        name='siamese_class_fc2',
        data_layout=fc_data_layout)
    class_fc3 = lbann.modules.FullyConnectedModule(
        num_labels,
        activation=lbann.Softmax,
        name='siamese_class_fc3',
        data_layout=fc_data_layout)
    probs = class_fc3(class_fc2(class_fc1(heads_concat)))

    # Setup objective function
    cross_entropy = lbann.CrossEntropy([probs, labels])
    regularized_types = (lbann.Convolution, lbann.FullyConnected)
    l2_reg_weights = set()
    for layer in lbann.traverse_layer_graph(data_in):
        if type(layer) in regularized_types:
            l2_reg_weights.update(layer.weights)
    l2_reg = lbann.L2WeightRegularization(weights=l2_reg_weights, scale=0.0002)
    obj = lbann.ObjectiveFunction([cross_entropy, l2_reg])

    # Setup model
    accuracy = lbann.CategoricalAccuracy([probs, labels])
    metrics = [lbann.Metric(accuracy, name='accuracy', unit='%')]
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
    if checkpoint_interval:
        # NOTE(review): the value of checkpoint_interval is not used;
        # checkpoints are always written every 5 epochs -- confirm
        # whether checkpoint_epochs should follow the argument.
        callbacks.append(
            lbann.CallbackCheckpoint(checkpoint_dir='ckpt',
                                     checkpoint_epochs=5))

    # Learning rate schedules
    if warmup:
        warmup_target = learning_rate * mini_batch_size / 128
        callbacks.append(
            lbann.CallbackLinearGrowthLearningRate(target=warmup_target,
                                                   num_epochs=5))
    drop_epochs = list(range(0, 100, 15))
    callbacks.append(
        lbann.CallbackDropFixedLearningRate(drop_epoch=drop_epochs,
                                            amt=0.25))

    # Construct model
    # The graph is traversed again on purpose: the accuracy layer was
    # created after the traversal used for regularization.
    model = lbann.Model(num_epochs,
                        layers=lbann.traverse_layer_graph(data_in),
                        objective_function=obj,
                        metrics=metrics,
                        callbacks=callbacks)

    # Setup optimizer
    opt = lbann.SGD(learn_rate=learning_rate, momentum=0.9)

    # Setup data reader
    data_reader = make_data_reader(num_patches)

    # Return experiment objects
    return model, data_reader, opt
Ejemplo n.º 16
0
    def forward(self, queries, keys, values, mask=None):
        """Apply multi-head attention.

        The input and output tensors are interpreted as sequences of
        vectors, where the first tensor dimension is the sequence
        dimension.

        Args:
            queries (lbann.Layer): Sequence of query vectors.
            keys (lbann.Layer): Sequence of key vectors.
            values (lbann.Layer): Sequence of value vectors.
            mask (lbann.Layer, optional): Additive attention mask. If
                the (i,j) entry is very negative (e.g. -1e9), then the
                ith query does not attend to the jth key/value pair.

        Returns:
            lbann.Layer: Sequence of output vectors. The sequence
                length is the same as `queries`.

        """
        self.instance += 1
        name = f'{self.name}_instance{self.instance}'

        # Apply fully-connected layers to input sequences
        inputs = (
            ('queries', queries, self.query_weights),
            ('keys', keys, self.key_weights),
            ('values', values, self.value_weights),
        )
        projected = [
            lbann.ChannelwiseFullyConnected(
                sequence,
                weights=weights,
                output_channel_dims=[self.embed_dim],
                name=f'{name}_{label}_fc',
            )
            for label, sequence, weights in inputs
        ]

        # Slice embedding vectors for each head
        head_bounds = str_list(
            [self.head_dim * i for i in range(self.num_heads + 1)])
        queries_slice, keys_slice, values_slice = (
            lbann.Slice(
                fc,
                axis=1,
                slice_points=head_bounds,
                name=f'{name}_{label}_slice',
            )
            for (label, _, _), fc in zip(inputs, projected)
        )

        # Compute scaled dot-product attention for each head
        head_outputs = []
        for head in range(self.num_heads):
            head_name = f'{name}_head{head}'

            # Attention inputs: each Identity takes the next piece of
            # its slice layer
            q = lbann.Identity(queries_slice)
            k = lbann.Identity(keys_slice)
            v = lbann.Identity(values_slice)

            # Multiply queries and keys
            # Note: num_queries x num_keys
            scores = lbann.MatMul(
                q,
                k,
                transpose_b=True,
                name=f'{head_name}_matmul',
            )
            scores = lbann.WeightedSum(
                scores,
                scaling_factors=str(1 / math.sqrt(self.head_dim)),
                name=f'{head_name}_scale',
            )
            if mask:
                scores = lbann.Add(scores, mask, name=f'{head_name}_mask')
            attention = lbann.ChannelwiseSoftmax(
                scores, name=f'{head_name}_softmax')

            # Attention output
            # Note: num_queries x head_dim
            head_outputs.append(lbann.MatMul(attention, v, name=head_name))

        # Concatenate heads and apply fully-connected layer
        heads_concat = lbann.Concatenation(head_outputs,
                                           axis=1,
                                           name=f'{name}_heads_concat')
        return lbann.ChannelwiseFullyConnected(
            heads_concat,
            weights=self.output_weights,
            output_channel_dims=[self.embed_dim],
            name=name,
        )
Ejemplo n.º 17
0
Archivo: main.py Proyecto: oyamay/lbann
                    type=int,
                    help='latent space dimensions (default: 128)',
                    metavar='NUM')
# Parse the command-line options declared on `parser` above.
args = parser.parse_args()

# ----------------------------------
# Construct layer graph
# ----------------------------------

# Dataset properties
vocab_size = dataset.corpus.vocab_size
# First sample dimension is used as the number of tokens per sequence.
sequence_length = dataset.sample_dims()[0]

# Input is a sequence of token IDs
input_ = lbann.Identity(lbann.Input())
# Slice points 0, 1, ..., sequence_length cut the input into
# sequence_length pieces of one entry each.
input_slice = lbann.Slice(input_,
                          slice_points=str_list(range(sequence_length + 1)))
# One Identity layer per piece of the slice.
tokens_list = [lbann.Identity(input_slice) for _ in range(sequence_length)]

# Get sequence of embedding vectors
# The whole token sequence is embedded at once into vectors of size
# args.latent_dim, then split along axis 0 into one piece per position.
embeddings = lbann.Embedding(input_,
                             num_embeddings=vocab_size,
                             embedding_dim=args.latent_dim)
embeddings_slice = lbann.Slice(embeddings,
                               axis=0,
                               slice_points=str_list(range(sequence_length +
                                                           1)))
# One Reshape layer per piece; dims='-1' presumably flattens each piece
# to a vector -- TODO confirm against the Reshape layer documentation.
embeddings_list = [
    lbann.Reshape(embeddings_slice, dims='-1') for _ in range(sequence_length)
]

# Layer modules
Ejemplo n.º 18
0
def construct_macc_surrogate_model(xdim, ydim, zdim, wae_mcf, surrogate_mcf,
                                   lambda_cyc, useCNN, dump_models,
                                   pretrained_dir, ltfb_batch_interval,
                                   num_epochs):
    """Construct MACC surrogate model.

    See https://arxiv.org/pdf/1912.08113.pdf model architecture and other details

    Args:
        xdim: Size of the parameter part of each sample.
        ydim: Size of the output part of each sample.
        zdim: Latent dimension passed to the WAE and forward networks.
        wae_mcf: `cf` argument of the WAE network.
        surrogate_mcf: `cf` argument of the inverse and forward networks.
        lambda_cyc: Weight of the cycle term in both generator losses.
        useCNN: Forwarded to the WAE as `use_CNN`.
        dump_models: Directory for the save-model callback.
        pretrained_dir: Directory for the load-model callback.
        ltfb_batch_interval: If positive, an LTFB callback with this
            batch interval is added.
        num_epochs: Number of training epochs.

    Returns:
        The constructed `lbann.Model`.

    """
    # Layer graph
    sample = lbann.Input(data_field='samples', name='inp_data')
    # data is 64*64*4 images + 15 scalar + 5 param
    sample_slice = lbann.Slice(sample,
                               axis=0,
                               slice_points=str_list([0, ydim, ydim + xdim]),
                               name='inp_slice')
    gt_y = lbann.Identity(sample_slice, name='gt_y')
    gt_x = lbann.Identity(sample_slice, name='gt_x')  # param not used

    # These three layers are not connected to anything below; they are
    # still constructed exactly as before.
    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')
    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="20")

    # Networks: the WAE is pretrained and frozen further down
    wae = macc_network_architectures.MACCWAE(zdim,
                                             ydim,
                                             cf=wae_mcf,
                                             use_CNN=useCNN)
    inv = macc_network_architectures.MACCInverse(xdim, cf=surrogate_mcf)
    fwd = macc_network_architectures.MACCForward(zdim, cf=surrogate_mcf)

    y_pred_fwd = wae.encoder(gt_y)

    # Output -> latent -> parameters -> latent -> output cycle
    param_pred_ = wae.encoder(gt_y)
    input_fake = inv(param_pred_)
    output_cyc = fwd(input_fake)
    y_image_re2 = wae.decoder(output_cyc)

    # **** Train cycleGAN input params <--> latent space of (images, scalars) ****
    # Parameters -> latent -> output -> latent -> parameters cycle
    output_fake = fwd(gt_x)
    y_image_re = wae.decoder(output_fake)
    param_pred2_ = wae.encoder(y_image_re)
    input_cyc = inv(param_pred2_)

    L_l2_x = lbann.MeanSquaredError(input_fake, gt_x)
    L_cyc_x = lbann.MeanSquaredError(input_cyc, gt_x)

    L_l2_y = lbann.MeanSquaredError(output_fake, y_pred_fwd)
    L_cyc_y = lbann.MeanSquaredError(output_cyc, y_pred_fwd)

    # @todo slice here to separate scalar from image
    img_sca_loss = lbann.MeanSquaredError(y_image_re, gt_y)
    # L_cyc = L_cyc_y + L_cyc_x
    L_cyc = lbann.Add(L_cyc_y, L_cyc_x)

    cycle_scaling = f'1 {lambda_cyc}'
    # loss_gen0 = L_l2_y + lambda_cyc * L_cyc
    loss_gen0 = lbann.WeightedSum([L_l2_y, L_cyc],
                                  scaling_factors=cycle_scaling)
    # loss_gen1 = L_l2_x + lambda_cyc * L_cyc_y
    loss_gen1 = lbann.WeightedSum([L_l2_x, L_cyc_y],
                                  scaling_factors=cycle_scaling)

    layers = list(lbann.traverse_layer_graph(sample))
    weights = set()
    # Freeze appropriate (pretrained) weights: any layer whose name
    # contains one of these tags gets a no-op optimizer on its weights
    pretrained_tags = ["wae"]  # add macc?
    for layer in layers:
        for tag in pretrained_tags:
            if layer.weights and tag in layer.name:
                for frozen in layer.weights:
                    frozen.optimizer = lbann.NoOptimizer()
        weights.update(layer.weights)

    l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)
    # Setup objective function
    obj = lbann.ObjectiveFunction([loss_gen0, loss_gen1, l2_reg])

    # Initialize check metric callback
    tracked = (
        (img_sca_loss, 'fw_loss'),
        (L_l2_x, 'inverse loss'),
        (L_cyc_y, 'output cycle loss'),
        (L_cyc_x, 'param cycle loss'),
    )
    metrics = [lbann.Metric(layer, name=label) for layer, label in tracked]

    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackSaveModel(dir=dump_models),
        lbann.CallbackLoadModel(dirs=str(pretrained_dir)),
        lbann.CallbackTimer(),
    ]
    if ltfb_batch_interval > 0:
        ltfb = lbann.CallbackLTFB(batch_interval=ltfb_batch_interval,
                                  metric='fw_loss',
                                  low_score_wins=True,
                                  exchange_hyperparameters=True)
        callbacks.append(ltfb)

    # Construct model
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       metrics=metrics,
                       objective_function=obj,
                       callbacks=callbacks)
Ejemplo n.º 19
0
def construct_jag_wae_model(ydim, zdim, mcf, useCNN, dump_models,
                            ltfb_batch_interval, num_epochs):
    """Construct the LBANN model for the JAG Wasserstein autoencoder.

    Args:
        ydim: Number of leading entries of each sample that form the
            reconstruction target; also passed to ``MACCWAE``.
        zdim: Latent space dimension. Passed to ``MACCWAE`` and used as
            the size of the Gaussian prior sample.
        mcf: Forwarded to ``MACCWAE`` as ``cf``.
        useCNN: Forwarded to ``MACCWAE`` as ``use_CNN``.
        dump_models: Directory handed to ``CallbackSaveModel``.
        ltfb_batch_interval: Batch interval for the LTFB callback, which
            is only added when this value is positive.
        num_epochs: Number of epochs passed to ``lbann.Model``.

    Returns:
        The ``lbann.Model`` built from the layer graph, objective
        function, metric and callbacks defined here.

    """

    # Layer graph
    input_layer = lbann.Input(data_field='samples', name='inp_data')
    # Each sample is split into its first ``ydim`` entries followed by 5
    # more entries (original note: 64*64*4 images + 15 scalars + 5 params).
    inp_slice = lbann.Slice(input_layer,
                            axis=0,
                            slice_points=str_list([0, ydim, ydim + 5]),
                            name='inp_slice')
    gt_y = lbann.Identity(inp_slice, name='gt_y')
    # Not used below, but the layer stays attached to the second slice output.
    gt_x = lbann.Identity(inp_slice, name='gt_x')

    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')

    # Sample the prior with the same latent size that the autoencoder is
    # built with, instead of a hard-coded 20.
    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims=str(zdim))
    model = macc_network_architectures.MACCWAE(zdim,
                                               ydim,
                                               cf=mcf,
                                               use_CNN=useCNN)
    d1_real, d1_fake, d_adv, pred_y = model(z, gt_y)

    d1_real_bce = lbann.SigmoidBinaryCrossEntropy([d1_real, one],
                                                  name='d1_real_bce')
    d1_fake_bce = lbann.SigmoidBinaryCrossEntropy([d1_fake, zero],
                                                  name='d1_fake_bce')
    d_adv_bce = lbann.SigmoidBinaryCrossEntropy([d_adv, one], name='d_adv_bce')
    img_loss = lbann.MeanSquaredError([pred_y, gt_y])
    rec_error = lbann.L2Norm2(
        lbann.WeightedSum([pred_y, gt_y], scaling_factors="1 -1"))

    layers = list(lbann.traverse_layer_graph(input_layer))

    # Collect all weights, record the layers whose weights are copied by
    # the replace-weights callback, and freeze the "disc1" layers.
    weights = set()
    src_layers = []
    dst_layers = []
    for layer in layers:
        if layer.weights and "disc0" in layer.name and "instance1" in layer.name:
            src_layers.append(layer.name)
        # Freeze weights in disc1
        if layer.weights and "disc1" in layer.name:
            dst_layers.append(layer.name)
            for w in layer.weights:
                w.optimizer = lbann.NoOptimizer()
        weights.update(layer.weights)

    # Setup objective function
    l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)
    d_adv_term = lbann.LayerTerm(d_adv_bce, scale=0.01)
    obj = lbann.ObjectiveFunction(
        [d1_real_bce, d1_fake_bce, d_adv_term, img_loss, rec_error, l2_reg])

    # Metric name is referenced by the LTFB callback below
    metrics = [lbann.Metric(img_loss, name='recon_error')]

    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackTimer(),
        lbann.CallbackPrintModelDescription(),
        lbann.CallbackSaveModel(dir=dump_models),
        lbann.CallbackReplaceWeights(source_layers=list2str(src_layers),
                                     destination_layers=list2str(dst_layers),
                                     batch_interval=2)
    ]

    if ltfb_batch_interval > 0:
        callbacks.append(
            lbann.CallbackLTFB(batch_interval=ltfb_batch_interval,
                               metric='recon_error',
                               low_score_wins=True,
                               exchange_hyperparameters=True))

    # Construct model
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       metrics=metrics,
                       objective_function=obj,
                       callbacks=callbacks)
Ejemplo n.º 20
0
def make_model(
    num_epochs,
    embed_dim,
    num_heads,
    label_smoothing,
):
    """Build the LBANN transformer model over paired token sequences.

    Relies on the module-level ``vocab_size``, ``sequence_length`` and
    ``pad_index``.

    Args:
        num_epochs: Number of epochs passed to ``lbann.Model``.
        embed_dim: Embedding dimension, also used as the transformer
            hidden size.
        num_heads: Number of attention heads for the transformer.
        label_smoothing: Weight of the uniform distribution mixed into
            the one-hot labels; smoothing is applied only when positive.

    Returns:
        The ``lbann.Model`` whose objective is the cross entropy averaged
        over non-pad label tokens.

    """
    num_targets = sequence_length - 1
    use_smoothing = label_smoothing > 0

    # Embedding weights (Glorot initialization)
    glorot_std = math.sqrt(2 / (embed_dim + vocab_size))
    token_embedding_weights = lbann.Weights(
        name='embeddings',
        initializer=lbann.NormalInitializer(standard_deviation=glorot_std),
    )

    # Input is two sequences of token IDs
    samples = lbann.Input(data_field='samples')

    # Embed every token except the last one: the decoder input is
    # shifted right, so the final token's embedding isn't needed.
    token_slice = lbann.Slice(
        samples,
        axis=0,
        slice_points=str_list([0, 2 * sequence_length - 1]),
    )
    tokens = lbann.Identity(token_slice)
    embedded = lbann.Embedding(
        tokens,
        weights=token_embedding_weights,
        num_embeddings=vocab_size,
        embedding_dim=embed_dim,
        padding_idx=pad_index,
    )
    # Scale embeddings by sqrt(embed_dim)
    scaled = lbann.WeightedSum(
        embedded,
        scaling_factors=str(math.sqrt(embed_dim)),
    )
    split = lbann.Slice(
        scaled,
        axis=0,
        slice_points=str_list([0, sequence_length, 2 * sequence_length - 1]),
    )
    encoder_input = lbann.Identity(split)
    decoder_input = lbann.Identity(split)

    # Apply transformer model
    transformer = lbann.models.Transformer(
        hidden_size=embed_dim,
        num_heads=num_heads,
        name='transformer',
    )
    decoded = transformer(
        encoder_input,
        sequence_length,
        decoder_input,
        num_targets,
    )

    # Project back onto the vocabulary with the (transposed) embedding
    # weights and split into one prediction per target position.
    vocab_scores = lbann.ChannelwiseFullyConnected(
        decoded,
        weights=token_embedding_weights,
        output_channel_dims=[vocab_size],
        bias=False,
        transpose=True,
    )
    vocab_probs = lbann.ChannelwiseSoftmax(vocab_scores)
    probs_slice = lbann.Slice(vocab_probs,
                              axis=0,
                              slice_points=str_list(range(sequence_length)))
    preds = [lbann.Identity(probs_slice) for _ in range(num_targets)]

    # Count number of non-pad tokens among the labels
    label_slice = lbann.Slice(
        samples,
        slice_points=str_list([sequence_length + 1, 2 * sequence_length]),
    )
    labels = lbann.Identity(label_slice)
    pads = lbann.Constant(value=pad_index, num_neurons=str(num_targets))
    is_not_pad = lbann.NotEqual(labels, pads)
    num_not_pad = lbann.Reduction(is_not_pad, mode='sum')

    # Split the labels into one token per target position
    per_label_slice = lbann.Slice(
        labels,
        slice_points=str_list(range(sequence_length)),
    )
    label_tokens = [lbann.Identity(per_label_slice) for _ in range(num_targets)]

    # Cross entropy loss with label smoothing
    if use_smoothing:
        uniform_label = lbann.Constant(value=1 / vocab_size,
                                       num_neurons=str_list([1, vocab_size]))
        smoothing_factors = str_list([1 - label_smoothing, label_smoothing])
    token_losses = []
    for pred, token in zip(preds, label_tokens):
        label = lbann.OneHot(token, size=vocab_size)
        label = lbann.Reshape(label, dims=str_list([1, vocab_size]))
        if use_smoothing:
            label = lbann.WeightedSum(
                label,
                uniform_label,
                scaling_factors=smoothing_factors,
            )
        token_losses.append(lbann.CrossEntropy(pred, label))
    loss = lbann.Concatenation(token_losses)

    # Average cross entropy over non-pad tokens
    tiled_count = lbann.Tessellate(num_not_pad, hint_layer=is_not_pad)
    loss_scales = lbann.Divide(is_not_pad, tiled_count)
    loss = lbann.Multiply(loss, loss_scales)
    loss = lbann.Reduction(loss, mode='sum')

    # Construct model
    return lbann.Model(
        num_epochs,
        layers=lbann.traverse_layer_graph(samples),
        objective_function=loss,
        metrics=[],
        callbacks=[lbann.CallbackPrint(), lbann.CallbackTimer()],
    )
Ejemplo n.º 21
0
def construct_model(run_args):
    """Construct the LBANN model that runs the ATOM molecular WAE encoder.

    Args:
        run_args: Parsed arguments. Reads ``dump_model_dir``, ``pad_index``,
            ``sequence_length``, ``embedding_dim``, ``num_embeddings``,
            ``dump_outputs_dir``, ``dump_outputs_interval`` and
            ``num_epochs``.

    Returns:
        The ``lbann.Model`` with a callback that dumps the concatenation
        of the input and the encoder output in test mode.

    """
    import lbann

    print("Dump model dir ", run_args.dump_model_dir)
    assert run_args.dump_model_dir, "evaluate script asssumes a pretrained WAE model"

    pad_index = run_args.pad_index
    assert pad_index is not None

    sequence_length = run_args.sequence_length
    assert sequence_length is not None
    print("sequence length is {}".format(sequence_length))

    # Layer graph
    raw_input = lbann.Input(name='inp', target_mode="N/A")
    input_ = lbann.Identity(raw_input, name='inp1')
    input_feature_dims = sequence_length

    embedding_size = run_args.embedding_dim
    dictionary_size = run_args.num_embeddings
    assert embedding_size is not None
    assert dictionary_size is not None

    # Saving outputs inside the model is disabled for this script
    save_output = False
    print("save output? ", save_output, "out dir ", run_args.dump_outputs_dir)

    # Gaussian sample that the encoder output is compared against
    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="128")

    token_slice = lbann.Slice(input_,
                              slice_points=str_list([0, input_feature_dims]))
    x = lbann.Identity(token_slice)
    # NOTE(review): save_output is passed as the fifth positional
    # argument here -- confirm this matches the MolWAE signature in use.
    waemodel = molwae.MolWAE(input_feature_dims, dictionary_size,
                             embedding_size, pad_index, save_output)
    x_emb = lbann.Embedding(x,
                            num_embeddings=waemodel.dictionary_size,
                            embedding_dim=waemodel.embedding_size,
                            name='emb',
                            weights=waemodel.emb_weights)

    latentz = waemodel.forward_encoder(x_emb)

    fake_loss = lbann.MeanAbsoluteError(latentz, z)

    # Gather layers and every weights object they hold
    layers = list(lbann.traverse_layer_graph(input_))
    weights = {w for layer in layers for w in layer.weights}

    # Setup objective function
    obj = lbann.ObjectiveFunction(fake_loss)

    # Dump output (activation) for post processing
    conc_out = lbann.Concatenation([input_, latentz], name='conc_out')
    dump_outputs = lbann.CallbackDumpOutputs(
        batch_interval=run_args.dump_outputs_interval,
        execution_modes='test',
        directory=run_args.dump_outputs_dir,
        layers=f'{conc_out.name}')
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer(), dump_outputs]

    # Construct model
    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       callbacks=callbacks)
Ejemplo n.º 22
0
    def compute_loss(self, x, y):
        """Build the layers computing the masked reconstruction loss.

        Mirrors the reference expression kept from the original code::

            F.cross_entropy(
                y[:, :-1].contiguous().view(-1, y.size(-1)),
                x[:, 1:].contiguous().view(-1),
                ignore_index=self.pad
            )

        Args:
            x (Layer): Target token indices.
            y (Layer): Predicted scores for each position.

        Returns:
            Layer: Sum over kept positions of the log-normalizer minus
                the target score, divided by the number of kept
                positions (at least 1).

        """
        seq_len = self.input_feature_dims
        vocab = self.dictionary_size

        # y[:, :-1]
        scores = lbann.Identity(
            lbann.Slice(
                y,
                axis=0,
                slice_points=str_list([0, seq_len - 1]),
            ))

        # x[:, 1:]
        targets = lbann.Identity(
            lbann.Slice(
                x,
                slice_points=str_list([1, seq_len]),
            ))

        # Mask of positions to ignore / keep, and the number kept
        ignore_value = self.constant(self.label_to_ignore, hint_layer=targets)
        ignore_mask = lbann.Equal(targets, ignore_value)
        keep_mask = lbann.LogicalNot(ignore_mask)
        num_kept = lbann.Reduction(keep_mask, mode='sum')
        length = lbann.Max(num_kept, self.constant(1, [1]))

        # Replace ignored indices by -1 so that their one-hot
        # representation is a zero vector
        kept_targets = lbann.Multiply(keep_mask, targets)
        minus_one = self.constant(-1, hint_layer=targets)
        ignored_targets = lbann.Multiply(ignore_mask, minus_one)
        targets = lbann.Add(kept_targets, ignored_targets)

        # Convert indices to a stacked one-hot representation
        target_slice = lbann.Slice(targets,
                                   slice_points=str_list(range(seq_len)))
        per_token = [lbann.Identity(target_slice) for _ in range(seq_len - 1)]
        one_hots = [lbann.OneHot(token, size=vocab) for token in per_token]
        one_hot_rows = [
            lbann.Reshape(one_hot, dims=str_list([1, vocab]))
            for one_hot in one_hots
        ]
        one_hot_targets = lbann.Concatenation(one_hot_rows, axis=0)

        # Shift the scores before exponentiating
        # Note: Ideally we'd shift y by y.max(-1) for numerical stability
        zero = self.constant(0, hint_layer=scores)
        positive_scores = lbann.Max(scores, zero)
        shift_matrix = self.constant(1 / math.sqrt(vocab), [vocab, vocab])
        shifts = lbann.MatMul(positive_scores, shift_matrix)
        scores = lbann.Subtract(scores, shifts)

        # Log of the summed exponentials per position, summed over the
        # kept positions
        exp_scores = lbann.Exp(scores)
        ones_column = self.constant(1, [vocab, 1])
        log_norm = lbann.Log(lbann.MatMul(exp_scores, ones_column))
        keep_row = lbann.Reshape(keep_mask, dims=str_list([1, -1]))
        log_norm_sum = lbann.MatMul(keep_row, log_norm)

        # Inner product of the flattened scores with the one-hot targets
        flat_scores = lbann.Reshape(scores, dims=str_list([1, -1]))
        flat_targets = lbann.Reshape(one_hot_targets, dims=str_list([1, -1]))
        target_score_sum = lbann.MatMul(
            flat_scores,
            flat_targets,
            transpose_b=True,
        )

        total = lbann.Subtract(log_norm_sum, target_score_sum)
        total = lbann.Reshape(total, dims=str_list([1]))
        return lbann.Divide(total, length)
Ejemplo n.º 23
0
    def forward(self, x, prev_state):
        """Apply one GRU step channelwise.

        Args:
            x (Layer): Input (shape: (num_channels, *))
            prev_state (Layer): State from previous GRU step
                (shape: (num_channels, size))

        Returns:
            (Layer, Layer): The output and the state, which are the same
                layer. The state can be passed directly into the next
                GRU step.
        """
        self.step += 1
        name = f"{self.name}_step{self.step}"
        layout = self.data_layout

        channel_dims = str_list([self.num_channels, self.size])
        gate_bounds = str_list([0, self.size, 2 * self.size, 3 * self.size])

        def split_gates(fc_out, suffixes):
            # Slice a fully-connected output into its three gate terms
            # and reshape each one to (num_channels, size).
            pieces = lbann.Slice(fc_out, axis=1, slice_points=gate_bounds)
            return [
                lbann.Reshape(lbann.Identity(pieces),
                              dims=channel_dims,
                              name=name + suffix)
                for suffix in suffixes
            ]

        prev_state = lbann.Reshape(prev_state,
                                   dims=channel_dims,
                                   name=name + "_prev_state_reshape")

        input_fc = self.ih_fc(x)
        hidden_fc = self.hh_fc(prev_state)

        in_reset, in_update, in_new = split_gates(
            input_fc, ('_Wir_x', '_Wiz_z', '_Win_x'))
        hid_reset, hid_update, hid_new = split_gates(
            hidden_fc, ('_Whr_x', '_Whz_z', '_Whn_x'))

        # Reset gate
        reset_sum = lbann.Add(in_reset, hid_reset, data_layout=layout)
        rt = lbann.Sigmoid(reset_sum,
                           name=name + '_reset_gate',
                           data_layout=layout)

        # Update gate
        update_sum = lbann.Add(in_update, hid_update, data_layout=layout)
        zt = lbann.Sigmoid(update_sum,
                           name=name + '_update_gate',
                           data_layout=layout)

        # Candidate state: the reset gate scales the hidden contribution
        gated_hidden = lbann.Multiply(rt, hid_new, data_layout=layout)
        candidate_sum = lbann.Add(in_new, gated_hidden, data_layout=layout)
        nt = lbann.Tanh(candidate_sum,
                        name=name + '_new_gate',
                        data_layout=layout)

        # ht = (1 - zt) * nt + zt * prev_state
        one_minus_zt = lbann.WeightedSum(self.ones,
                                         zt,
                                         scaling_factors='1 -1',
                                         data_layout=layout)
        new_part = lbann.Multiply(one_minus_zt, nt, data_layout=layout)
        old_part = lbann.Multiply(zt, prev_state, data_layout=layout)
        ht = lbann.Add(new_part,
                       old_part,
                       name=name + '_output',
                       data_layout=layout)

        ht = lbann.Reshape(ht, dims=channel_dims)

        return ht, ht
Ejemplo n.º 24
0
def construct_model():
    """Construct the LBANN model for the JAG Wasserstein autoencoder.

    Returns:
        The ``lbann.Model`` trained for 100 epochs on the WAE objective
        (discriminator, adversarial and reconstruction terms plus L2
        weight regularization).

    """
    import lbann

    y_dim = 16399  # image+scalar shape
    z_dim = 20  # Latent space dim
    num_epochs = 100

    # Layer graph
    # data is 64*64*4 images + 15 scalar + 5 param
    samples = lbann.Input(target_mode='N/A', name='inp_data')
    inp_slice = lbann.Slice(samples,
                            axis=0,
                            slice_points=f"0 {y_dim} {y_dim + 5}",
                            name='inp_slice')
    gt_y = lbann.Identity(inp_slice, name='gt_y')
    # param part is not used, but the layer stays attached to the slice
    gt_x = lbann.Identity(inp_slice, name='gt_x')

    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')

    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims=str(z_dim))
    wae = jag_models.WAE(z_dim, y_dim)
    d1_real, d1_fake, d_adv, pred_y = wae(z, gt_y)

    d1_real_bce = lbann.SigmoidBinaryCrossEntropy([d1_real, one],
                                                  name='d1_real_bce')
    d1_fake_bce = lbann.SigmoidBinaryCrossEntropy([d1_fake, zero],
                                                  name='d1_fake_bce')
    d_adv_bce = lbann.SigmoidBinaryCrossEntropy([d_adv, one],
                                                name='d_adv_bce')

    img_loss = lbann.MeanSquaredError([pred_y, gt_y])
    residual = lbann.WeightedSum([pred_y, gt_y], scaling_factors="1 -1")
    rec_error = lbann.L2Norm2(residual)

    layers = list(lbann.traverse_layer_graph(samples))

    # Collect weights; remember the layers involved in weight replacement
    # and freeze the weights of the "disc1" layers.
    weights = set()
    src_layers = []
    dst_layers = []
    for layer in layers:
        has_weights = bool(layer.weights)
        if has_weights and "disc0" in layer.name and "instance1" in layer.name:
            src_layers.append(layer.name)
        if has_weights and "disc1" in layer.name:
            dst_layers.append(layer.name)
            for w in layer.weights:
                w.optimizer = lbann.NoOptimizer()
        weights.update(layer.weights)

    # Setup objective function
    l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)
    d_adv_term = lbann.LayerTerm(d_adv_bce, scale=0.01)
    obj = lbann.ObjectiveFunction([
        d1_real_bce, d1_fake_bce, d_adv_term, img_loss, rec_error, l2_reg
    ])

    # Initialize check metric callback
    metrics = [lbann.Metric(img_loss, name='recon_error')]

    replace_weights = lbann.CallbackReplaceWeights(
        source_layers=list2str(src_layers),
        destination_layers=list2str(dst_layers),
        batch_interval=2)
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer(), replace_weights]

    # Construct model
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       metrics=metrics,
                       objective_function=obj,
                       callbacks=callbacks)
Ejemplo n.º 25
0
# ----------------------------------
# Construct layer graph
# ----------------------------------

# Properties of graph and random walk
# All values are read from the `dataset` module/object defined earlier in
# the script.
num_graph_nodes = dataset.max_graph_node_id() + 1  # node IDs run from 0 to the max ID
walk_length = dataset.walk_context_length
num_negative_samples = dataset.num_negative_samples
input_size = dataset.sample_dims()[0]  # first entry of the sample dimensions

# Embedding vectors, including negative sampling
# Note: Input is sequence of graph node IDs
input_ = lbann.Identity(lbann.Input())
# Split each sample into two pieces: the first `num_negative_samples + 1`
# entries and the remaining entries up to `input_size`.
input_slice = lbann.Slice(
    input_,
    slice_points=f'0 {num_negative_samples+1} {input_size}'
)
# Both embedding layers are attached directly to the slice layer.
# NOTE(review): presumably the decoder embeddings receive the first slice
# output and the encoder embeddings the second, based on attachment
# order -- confirm against LBANN's Slice semantics.
decoder_embeddings = lbann.Embedding(
    input_slice,
    weights=decoder_embeddings_weights,
    num_embeddings=num_graph_nodes,
    embedding_dim=args.latent_dim,
)
encoder_embeddings = lbann.Embedding(
    input_slice,
    weights=encoder_embeddings_weights,
    num_embeddings=num_graph_nodes,
    embedding_dim=args.latent_dim,
)

# Skip-Gram with negative sampling
Ejemplo n.º 26
0
        _reader.synth_dimensions = '1'
        _reader.percent_of_data_to_use = 1.0

    add_data_reader('train')
    add_data_reader('test')
    input_ = lbann.Input()

    # Radial profile
    x = lbann.WeightsLayer(
        weights=lbann.Weights(
            lbann.ValueInitializer(values=str_list(image.flatten())), ),
        dims=str_list(image.shape),
    )
    max_r = image.shape[-1] // 2
    rprof = RadialProfile()(x, image.shape, max_r)
    rprof_slice = lbann.Slice(rprof, slice_points=str_list([0, 1, 2, 3]))
    red = lbann.Identity(rprof_slice, name='red')
    green = lbann.Identity(rprof_slice, name='green')
    blue = lbann.Identity(rprof_slice, name='blue')

    # Construct model
    callbacks = [
        lbann.CallbackDumpOutputs(layers=str_list(['red', 'green', 'blue'])),
    ]
    model = lbann.Model(
        epochs=0,
        layers=lbann.traverse_layer_graph([input_, rprof]),
        callbacks=callbacks,
    )

    # Run LBANN
Ejemplo n.º 27
0
def construct_model(run_args):
    """Construct the LBANN model that runs the ATOM molecular WAE decoder.

    The input is expected to be the concatenation of the input SMILES
    tokens and a latent vector ``z`` (as produced by the encoder script).

    Args:
        run_args: Parsed arguments. Reads ``pad_index``,
            ``sequence_length``, ``z_dim``, ``embedding_dim``,
            ``num_embeddings``, ``dump_outputs_dir``,
            ``dump_outputs_interval`` and ``num_epochs``.

    Returns:
        The ``lbann.Model`` with the reconstruction loss as objective and
        metric, and a callback dumping the concatenation of the input
        and the decoder predictions in test mode.

    """
    import lbann

    pad_index = run_args.pad_index
    assert pad_index is not None

    sequence_length = run_args.sequence_length
    assert sequence_length is not None, 'should be training seq len + bos + eos'

    print("sequence length is {}, which is training sequence len + bos + eos".format(sequence_length))

    # Layer graph
    input_ = lbann.Input(data_field='samples', name='inp_data')
    # Split the sample into the token sequence and the latent vector
    latent_dim = run_args.z_dim
    inp_slice = lbann.Slice(
        input_,
        axis=0,
        slice_points=str_list(
            [0, sequence_length, sequence_length + latent_dim]),
        name='inp_slice')
    inp_smile = lbann.Identity(inp_slice, name='inp_smile')
    z = lbann.Identity(inp_slice, name='z')
    input_feature_dims = sequence_length

    embedding_size = run_args.embedding_dim
    dictionary_size = run_args.num_embeddings
    assert embedding_size is not None
    assert dictionary_size is not None

    save_output = bool(run_args.dump_outputs_dir)
    print("save output? ", save_output, "out dir ",  run_args.dump_outputs_dir)

    # For random sampling, replace z with:
    # lbann.Gaussian(mean=0.0,stdev=1.0, neuron_dims=str(run_args.z_dim))
    token_slice = lbann.Slice(inp_smile,
                              slice_points=str_list([0, input_feature_dims]))
    x = lbann.Identity(token_slice)
    waemodel = molwae.MolWAE(input_feature_dims,
                             dictionary_size,
                             embedding_size,
                             pad_index,
                             run_args.z_dim,
                             save_output=save_output)
    x_emb = lbann.Embedding(x,
                            num_embeddings=waemodel.dictionary_size,
                            embedding_dim=waemodel.embedding_size,
                            name='emb',
                            weights=waemodel.emb_weights)

    pred, arg_max = waemodel.forward_decoder(x_emb, z)
    recon = waemodel.compute_loss(x, pred)
    wae_loss = [recon]

    # Gather layers and every weights object they hold
    layers = list(lbann.traverse_layer_graph(input_))
    weights = {w for layer in layers for w in layer.weights}

    print("LEN wae loss ", len(wae_loss))

    # Setup objective function
    obj = lbann.ObjectiveFunction(wae_loss)

    # Initialize check metric callback
    metrics = [lbann.Metric(recon, name='recon')]

    # Dump output (activation) for post processing
    pred_tensor = lbann.Concatenation(arg_max, name='pred_tensor')
    conc_out = lbann.Concatenation([input_, pred_tensor], name='conc_out')
    dump_outputs = lbann.CallbackDumpOutputs(
        batch_interval=run_args.dump_outputs_interval,
        execution_modes='test',
        directory=run_args.dump_outputs_dir,
        layers=f'{conc_out.name}')
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer(), dump_outputs]

    # Construct model
    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       metrics=metrics,
                       callbacks=callbacks)
Ejemplo n.º 28
0
    def forward(self, x, prev_state):
        """Apply one GRU step.

        Args:
            x (Layer): Input.
            prev_state: State from previous GRU step.

        Returns:
            (Layer, Layer): The output and the state, which are the same
                layer. The state can be passed directly into the next
                GRU step.

        """
        self.step += 1
        name = '{0}_step{1}'.format(self.name, self.step)
        layout = self.data_layout
        gate_bounds = str_list([0, self.size, 2*self.size, 3*self.size])

        def split_gates(fc_out, slice_suffix, gate_suffixes):
            # Slice a fully-connected output into its reset, update and
            # new-gate terms, in that order.
            pieces = lbann.Slice(fc_out,
                                 slice_points=gate_bounds,
                                 name=name + slice_suffix,
                                 data_layout=layout)
            return [
                lbann.Identity(pieces, name=name + suffix, data_layout=layout)
                for suffix in gate_suffixes
            ]

        input_fc = self.ih_fc(x)
        hidden_fc = self.hh_fc(prev_state)

        # Get gate terms from the input and hidden projections
        in_reset, in_update, in_new = split_gates(
            input_fc, '_fc1_slice', ('_Wrx', '_Wzx', '_Wnx'))
        hid_reset, hid_update, hid_new = split_gates(
            hidden_fc, '_fc2_slice', ('_Wrh', '_Wzh', '_Wnh'))

        # Reset gate
        reset_sum = lbann.Add(in_reset, hid_reset, data_layout=layout)
        rt = lbann.Sigmoid(reset_sum,
                           name=name + '_reset_gate',
                           data_layout=layout)

        # Update gate
        update_sum = lbann.Add(in_update, hid_update, data_layout=layout)
        zt = lbann.Sigmoid(update_sum,
                           name=name + '_update_gate',
                           data_layout=layout)

        # Candidate state: the reset gate scales the hidden contribution
        gated_hidden = lbann.Multiply(rt, hid_new, data_layout=layout)
        candidate_sum = lbann.Add(in_new, gated_hidden, data_layout=layout)
        nt = lbann.Tanh(candidate_sum,
                        name=name + '_new_gate',
                        data_layout=layout)

        # ht = (1 - zt) * nt + zt * prev_state
        one_minus_zt = lbann.WeightedSum(self.ones,
                                         zt,
                                         scaling_factors='1 -1',
                                         data_layout=layout)
        new_part = lbann.Multiply(one_minus_zt, nt, data_layout=layout)
        old_part = lbann.Multiply(zt, prev_state, data_layout=layout)
        ht = lbann.Add(new_part,
                       old_part,
                       name=name + '_output',
                       data_layout=layout)

        # Return output and state
        return ht, ht
Ejemplo n.º 29
0
            standard_deviation=1 / args.latent_dim,
        ),
        name='embeddings',
    )
    embeddings = lbann.Embedding(
        input_,
        weights=embeddings_weights,
        num_embeddings=num_vertices,
        embedding_dim=args.latent_dim,
    )
else:
    raise RuntimeError(
        f'unknown method to get embedding vectors ({args.embeddings})')
embeddings_slice = lbann.Slice(
    embeddings,
    axis=0,
    slice_points=utils.str_list(
        [0, num_negative_samples, num_negative_samples + walk_length]),
)
negative_samples_embeddings = lbann.Identity(embeddings_slice)
walk_embeddings = lbann.Identity(embeddings_slice)

# Skip-Gram objective function
positive_loss = model.skip_gram.positive_samples_loss(
    walk_length,
    lbann.Identity(walk_embeddings),
    lbann.Identity(walk_embeddings),
    scale_decay=0.8,
)
negative_loss = model.skip_gram.negative_samples_loss(
    walk_embeddings,
    negative_samples_embeddings,
Ejemplo n.º 30
0
    def forward(self, x, prev_state):
        """Apply one LSTM step.

        Args:
            x (Layer): Input.
            prev_state (tuple with two `Layer`s): State from previous
                LSTM step, i.e. the previous output and cell state.

        Returns:
            (Layer, (Layer, Layer)): The output and the new state (the
                output and the cell state). The state can be passed
                directly into the next LSTM step.

        """
        self.step += 1
        name = '{0}_step{1}'.format(self.name, self.step)
        layout = self.data_layout
        size = self.size

        # Unpack the state from the previous step
        prev_output, prev_cell = prev_state

        # Apply linearity to the concatenated input and previous output
        fc_input = lbann.Concatenation(x, prev_output,
                                       name=name + '_input',
                                       data_layout=layout)
        fc_out = self.fc(fc_input)

        # First piece is the cell update, the rest holds the three gates
        fc_pieces = lbann.Slice(fc_out,
                                slice_points=str_list([0, size, 4*size]),
                                name=name + '_fc_slice',
                                data_layout=layout)
        cell_update = lbann.Tanh(fc_pieces,
                                 name=name + '_cell_update',
                                 data_layout=layout)
        gates = lbann.Sigmoid(fc_pieces,
                              name=name + '_sigmoid',
                              data_layout=layout)
        gate_pieces = lbann.Slice(
            gates,
            slice_points=str_list([0, size, 2*size, 3*size]),
            name=name + '_sigmoid_slice',
            data_layout=layout)
        forget_gate, input_gate, output_gate = [
            lbann.Identity(gate_pieces, name=name + suffix, data_layout=layout)
            for suffix in ('_forget_gate', '_input_gate', '_output_gate')
        ]

        # Cell state
        kept_cell = lbann.Multiply(forget_gate, prev_cell,
                                   name=name + '_cell_forget',
                                   data_layout=layout)
        added_cell = lbann.Multiply(input_gate, cell_update,
                                    name=name + '_cell_input',
                                    data_layout=layout)
        cell = lbann.Add(kept_cell, added_cell,
                         name=name + '_cell',
                         data_layout=layout)

        # Output
        cell_act = lbann.Tanh(cell,
                              name=name + '_cell_activation',
                              data_layout=layout)
        output = lbann.Multiply(output_gate, cell_act,
                                name=name,
                                data_layout=layout)

        # Return output and state
        return output, (output, cell)