Exemplo n.º 1
0
    def forward(self, x, z):
        """Build the WAE forward graph: embed, encode, decode, discriminate.

        :param x: input layer; its leading ``self.input_feature_dims``
            entries are sliced off and used as embedding indices
        :param z: layer that is reshaped here to ``[1, self.zdim]``; it is
            paired with the embeddings for the "real" discriminator branch
        :return: tuple ``(recon_loss, d_real, d_fake, d_adv, arg_max)``
        """
        # Keep only the leading input_feature_dims entries of the input.
        sliced = lbann.Slice(
            x, slice_points=str_list([0, self.input_feature_dims]))
        x = lbann.Identity(sliced)
        x_emb = lbann.Embedding(
            x,
            num_embeddings=self.dictionary_size,
            embedding_dim=self.embedding_size,
            name='emb',
            weights=self.emb_weights,
        )

        # Encoder: x -> z. The code is then perturbed with Gaussian noise
        # whose layer uses the encoder output as its hint layer.
        encoded = self.forward_encoder(x_emb)
        noise = lbann.Gaussian(mean=self.gmean,
                               stdev=self.gstd,
                               hint_layer=encoded)
        z_sample = lbann.Add([encoded, noise])

        # Decoder: x, z -> prediction, scored against the sliced input.
        pred, arg_max = self.forward_decoder(x_emb, z_sample)
        recon_loss = self.compute_loss(x, pred)

        # Hack to remove blocking GPU allreduce in evaluation layer
        recon_loss = lbann.Identity(recon_loss, device='CPU')

        # "Real" pair: embeddings next to the supplied z, tiled to
        # [input_feature_dims, zdim] so it can be concatenated on axis 1.
        prior_row = lbann.Reshape(z, dims=str_list([1, self.zdim]))
        z_prior = lbann.Tessellate(
            prior_row,
            dims=str_list([self.input_feature_dims, self.zdim]),
        )
        real_pair = lbann.Concatenation([x_emb, z_prior], axis=1)
        d_real = self.discriminator0(real_pair)

        # "Fake" pair: embeddings next to the encoder sample, tiled alike.
        sample_row = lbann.Reshape(z_sample, dims=str_list([1, self.zdim]))
        z_sample0 = lbann.Tessellate(
            sample_row,
            dims=str_list([self.input_feature_dims, self.zdim]),
        )
        y_z_sample = lbann.Concatenation([x_emb, z_sample0], axis=1)

        # discriminator0 sees the fake pair behind a StopGradient;
        # discriminator1 sees it directly (its weights are meant to be
        # frozen by the caller).
        detached_pair = lbann.StopGradient(y_z_sample)
        d_fake = self.discriminator0(detached_pair)
        d_adv = self.discriminator1(y_z_sample)

        return recon_loss, d_real, d_fake, d_adv, arg_max
Exemplo n.º 2
0
    def forward(self, z, y):
        """Wire the encoder, decoder and the three discriminator branches.

        :param z: layer concatenated with ``y`` for the "real" branch
        :param y: layer fed to the encoder and paired with both codes
        :return: tuple ``(d_real, d_fake, d_adv, y_recon)``
        """
        z_sample = self.encoder(y)
        y_recon = self.decoder(z_sample)

        # d_real and d_fake share discriminator0. Its weights are meant to
        # be copied into discriminator1 (replace-weights callback), which
        # stays frozen.
        real_pair = lbann.Concatenation([y, z], axis=0)
        d_real = self.discriminator0(real_pair)

        y_z_sample = lbann.Concatenation([y, z_sample], axis=0)
        detached_pair = lbann.StopGradient(y_z_sample)
        d_fake = self.discriminator0(detached_pair)
        d_adv = self.discriminator1(y_z_sample)

        return d_real, d_fake, d_adv, y_recon
Exemplo n.º 3
0
def dense_layer(statistics_group_size, current_block_num, current_layer_num,
                cumulative_layer_num, parent_nodes, batch_norm_size,
                growth_rate):
    """Build one dense layer: concatenate the parents, then two conv blocks.

    Args:
        statistics_group_size: forwarded unchanged to ``conv_block``.
        current_block_num: dense block number, used for logging and
            forwarded to ``conv_block``.
        current_layer_num: layer number within the block, used likewise.
        cumulative_layer_num: running layer counter; incremented once for
            the concatenation and then updated by each ``conv_block``.
        parent_nodes: layers concatenated to form this layer's input.
        batch_norm_size: multiplier applied to ``growth_rate`` for the
            first conv block's output channels.
        growth_rate: output channel count of the second conv block.

    Returns:
        tuple: ``(output node of the second conv block, updated
        cumulative_layer_num)``.
    """
    node = lbann.Concatenation(parent_nodes)
    cumulative_layer_num += 1
    log('dense_block={b} dense_layer={l} Concatenation. cumulative_layer_num={n}'
        .format(b=current_block_num,
                l=current_layer_num,
                n=cumulative_layer_num))

    # (conv_dims_i, conv_pads_i, num_output_channels) for each conv block,
    # applied in sequence.
    conv_specs = (
        (1, 0, batch_norm_size * growth_rate),
        (3, 1, growth_rate),
    )
    for conv_dims, conv_pads, out_channels in conv_specs:
        node, cumulative_layer_num = conv_block(
            statistics_group_size,
            current_block_num,
            current_layer_num,
            cumulative_layer_num,
            node,
            conv_dims_i=conv_dims,
            conv_pads_i=conv_pads,
            num_output_channels=out_channels)
    return node, cumulative_layer_num
Exemplo n.º 4
0
    def forward(self, img, z, mcr):
        '''
        Build the GAN graph for one step.

        Steps:
        - If mcr is set, append a rescaled second channel to the real image
        - D1 on the real image -> d1_real
        - Generator on noise z -> gen_img
        - D1 on gen_img behind StopGradient -> d1_fake
        - D2 (adversarial copy) on gen_img -> d_adv
        Returns (d1_real, d1_fake, d_adv, gen_img, img), where img is the
        reshaped (and, with mcr, two-channel) real image.
        '''

        print('MCR in forward', mcr)
        if not mcr:
            img = lbann.Reshape(img, dims='1 128 128')
        else:
            # Multi-channel rescaling: real images get an extra channel here;
            # generated images are rescaled inside the generator.
            inv_scale = 1 / self.linear_scaler
            rescaled = lbann.WeightedSum(
                self.inv_transform(lbann.Identity(img)),
                scaling_factors=str(inv_scale))
            ch2 = lbann.Tanh(rescaled)
            stacked = lbann.Concatenation(lbann.Identity(img), ch2, axis=0)
            img = lbann.Reshape(stacked, dims='2 128 128')

        # The two discriminator1 calls share weights. Those weights are meant
        # to be copied into discriminator2 (replace-weights callback), which
        # must stay frozen.
        d1_real = self.forward_discriminator1(img)
        gen_img = self.forward_generator(z, mcr=mcr)

        detached_gen = lbann.StopGradient(gen_img)
        d1_fake = self.forward_discriminator1(detached_gen)
        d_adv = self.forward_discriminator2(gen_img)

        return d1_real, d1_fake, d_adv, gen_img, img
Exemplo n.º 5
0
    def forward(self, node_feature_mat, source_indices, target_indices):
        """Call the GatedGraphConv
        Args:
            node_feature_mat (Layer): Node feature matrix with the shape of (num_nodes,input_channels)
            source_indices (Layer): Source node indices of the edges; passed to GraphReduce
            target_indices (Layer): Target node indices of the edges; passed to GraphExpand
        Returns:
            (Layer) : The output after kernel ops. The output can passed into another Graph Conv layer
                          directly
        Raises:
            ValueError: If ``self.input_channel_size`` is larger than
                ``self.output_channel_size``.
        """
        # Validate first so a misconfigured layer fails before any graph
        # node is created. The original referenced an undefined name here
        # and never raised the error it constructed.
        if self.input_channel_size > self.output_channel_size:
            raise ValueError('The feature size of the nodes {} cannot be greater than the output dimension {}'.
                             format(self.input_channel_size, self.output_channel_size))

        if self.input_channel_size < self.output_channel_size:
            # Zero-pad the feature axis up to the output width.
            num_zeros = self.output_channel_size - self.input_channel_size
            print(num_zeros)
            zeros = lbann.Constant(value=0,
                                   num_neurons=str_list([self.num_nodes, num_zeros]),
                                   name=self.name + '_padded')
            node_feature_mat = lbann.Concatenation(node_feature_mat, zeros, axis=1)

        for layer in range(self.num_layers):
            # Per-layer transform, expand by target indices, reduce by
            # source indices, then fold the aggregate into the node state.
            messages = self.nns[layer](node_feature_mat)
            neighborhoods = GraphExpand(messages, target_indices)
            aggregate = GraphReduce(neighborhoods, source_indices,
                                    [self.num_nodes, self.output_channel_size])

            node_feature_mat = self.rnn(aggregate, node_feature_mat)

        return node_feature_mat
Exemplo n.º 6
0
    def forward_generator(self, z, mcr):
        '''
        Build the Generator.

        z is passed through g_fc1, reshaped to '512 8 8', run through every
        layer in g_convT (each followed by batch norm + ReLU) and finally
        through g_convT3. With mcr set, a rescaled second channel is
        appended, so the result is reshaped to '2 128 128' instead of
        '1 128 128'.
        '''
        def bn_relu(layer):
            # Batch normalization followed by ReLU, shared by every stage.
            normed = lbann.BatchNormalization(layer,
                                              decay=0.9,
                                              scale_init=1.0,
                                              epsilon=1e-5)
            return lbann.Relu(normed)

        hidden = bn_relu(self.g_fc1(z))
        hidden = lbann.Reshape(hidden, dims='512 8 8')  # channel first

        for up_conv in self.g_convT:
            hidden = bn_relu(up_conv(hidden))

        img = self.g_convT3(hidden)

        if not mcr:
            return lbann.Reshape(img, dims='1 128 128')

        # Multi-channel rescaling: add an extra channel to the output image.
        inv_scale = 1 / self.linear_scaler
        rescaled = lbann.WeightedSum(self.inv_transform(img),
                                     scaling_factors=str(inv_scale))
        ch2 = lbann.Tanh(rescaled)
        stacked = lbann.Concatenation(img, ch2, axis=0)
        return lbann.Reshape(stacked, dims='2 128 128')
Exemplo n.º 7
0
def densenet(version,
             cumulative_layer_num,
             images_node
             ):
    """Assemble a DenseNet graph and return its classification output.

    Args:
        version: 121 or 161; selects growth rate, layers per dense block
            and the initial feature count.
        cumulative_layer_num: running layer counter, threaded through every
            helper and reported in the log messages.
        images_node: input layer handed to ``initial_layer``.

    Returns:
        The value returned by ``classification_layer`` for the final
        batch-normalized node.

    Raises:
        Exception: If ``version`` is neither 121 nor 161.
    """
    # (growth rate "k" in the paper, layers per dense block, initial features)
    if version == 121:
        growth_rate, layers_per_block, num_initial_features = (
            32, (6, 12, 24, 16), 64)
    elif version == 161:
        # NOTE(review): the DenseNet paper lists (6, 12, 36, 24) for
        # DenseNet-161 -- confirm that (96, 48, 36, 24) is intentional.
        growth_rate, layers_per_block, num_initial_features = (
            48, (96, 48, 36, 24), 96)
    else:
        raise Exception('Invalid version={v}.'.format(v=version))
    batch_norm_size = 4
    last_block_num = len(layers_per_block)

    parent_node, cumulative_layer_num = initial_layer(
        cumulative_layer_num, images_node,
        num_initial_features)
    num_features = num_initial_features

    # Dense blocks are numbered from 1.
    for current_block_num, num_layers in enumerate(layers_per_block, start=1):
        parent_nodes, cumulative_layer_num = dense_block(
            cumulative_layer_num,
            parent_node,
            batch_norm_size=batch_norm_size,
            current_block_num=current_block_num,
            growth_rate=growth_rate,
            num_layers=num_layers,
            num_initial_channels=num_initial_features
        )
        # Every node after the first contributes its output channels.
        num_features += sum(
            node.num_output_channels for node in parent_nodes[1:])
        parent_node = lbann.Concatenation(parent_nodes)
        cumulative_layer_num += 1
        log('densenet Concatenation. cumulative_layer_num={n}'.format(
            b=current_block_num, n=cumulative_layer_num))

        if current_block_num == last_block_num:
            continue
        # A transition layer halves the feature count between blocks
        # (integer division).
        num_features //= 2
        parent_node, cumulative_layer_num = transition_layer(
            current_block_num,
            cumulative_layer_num,
            parent_node,
            num_output_channels=num_features,
        )

    batch_normalization_node = standard_batchnorm(parent_node)
    cumulative_layer_num += 1
    log('densenet BatchNormalization. cumulative_layer_num={n}'.format(
        b=current_block_num, n=cumulative_layer_num))

    return classification_layer(
        cumulative_layer_num,
        batch_normalization_node
    )
Exemplo n.º 8
0
    def forward(self, x):
        """Perform one LSTM step using the state stored on this object.

        Reads ``self.last_output`` and ``self.last_cell``, builds the layers
        for the next step and stores the new output and cell state back on
        ``self``. ``self.step`` is incremented and used in layer names.

        Args:
            x (Layer): Input for this step.

        Returns:
            Layer: The step output (also stored as ``self.last_output``).

        """
        self.step += 1
        step_name = f'{self.name}_step{self.step}'
        layout = self.data_layout

        # Apply linearity to the input joined with the previous output
        joined = lbann.Concatenation([x, self.last_output],
                                     name=step_name + '_input',
                                     data_layout=layout)
        pre_activation = self.fc(joined)

        # Split into the cell update ([0, size)) and the gates
        # ([size, 4*size)). NOTE(review): this relies on each consumer of a
        # Slice receiving the next segment in creation order -- keep the
        # order of the statements below.
        fc_split = lbann.Slice(
            pre_activation,
            slice_points=_str_list([0, self.size, 4 * self.size]),
            name=step_name + '_fc_slice',
            data_layout=layout)
        cell_update = lbann.Tanh(fc_split,
                                 name=step_name + '_cell_update',
                                 data_layout=layout)
        gates = lbann.Sigmoid(fc_split,
                              name=step_name + '_sigmoid',
                              data_layout=layout)
        gate_split = lbann.Slice(
            gates,
            slice_points=_str_list(
                [0, self.size, 2 * self.size, 3 * self.size]),
            name=step_name + '_sigmoid_slice',
            data_layout=layout)
        forget_gate = lbann.Identity(gate_split,
                                     name=step_name + '_forget_gate',
                                     data_layout=layout)
        input_gate = lbann.Identity(gate_split,
                                    name=step_name + '_input_gate',
                                    data_layout=layout)
        output_gate = lbann.Identity(gate_split,
                                     name=step_name + '_output_gate',
                                     data_layout=layout)

        # Cell state: forget part of the old cell, add the gated update
        kept = lbann.Multiply([forget_gate, self.last_cell],
                              name=step_name + '_cell_forget',
                              data_layout=layout)
        added = lbann.Multiply([input_gate, cell_update],
                               name=step_name + '_cell_input',
                               data_layout=layout)
        new_cell = lbann.Add([kept, added],
                             name=step_name + '_cell',
                             data_layout=layout)

        # Output: gated tanh of the new cell state
        cell_activation = lbann.Tanh(new_cell,
                                     name=step_name + '_cell_activation',
                                     data_layout=layout)
        new_output = lbann.Multiply([output_gate, cell_activation],
                                    name=step_name,
                                    data_layout=layout)

        # Update state and return output
        self.last_cell = new_cell
        self.last_output = new_output
        return new_output
Exemplo n.º 9
0
 def forward_discriminator2(self,y):
     """Run the second discriminator on ``y`` plus a transformed channel.

     ``y`` and ``self.inv_transform(y)`` are concatenated on axis 0 and
     reshaped to '2 128 128'. The result passes through the first four
     ``self.d2_conv`` layers (each followed by LeakyRelu with slope 0.2),
     is flattened to 32768 entries and fed to ``self.d2_fc``.
     """
     second_channel = self.inv_transform(lbann.Identity(y))
     stacked = lbann.Concatenation(lbann.Identity(y), second_channel, axis=0)
     x = lbann.Reshape(stacked, dims='2 128 128')
     for idx in range(4):
         x = lbann.LeakyRelu(self.d2_conv[idx](x), negative_slope=0.2)
     flattened = lbann.Reshape(x, dims='32768')
     return self.d2_fc(flattened)
Exemplo n.º 10
0
    def forward(self, x, x_concat=None):
        """Apply every module in ``self.convs`` to ``x`` in order.

        ``self.instance`` is incremented on each call. When ``x_concat`` is
        given, it is concatenated to ``x`` along axis 0 before the first
        module runs.

        :param x: input to the first module
        :param x_concat: optional extra input joined to ``x``
        :return: output of the last module (``x`` itself if there are none)
        """
        self.instance += 1
        out = x if x_concat is None else lbann.Concatenation(
            [x, x_concat], axis=0)

        for conv in self.convs:
            out = conv(out)
        return out
Exemplo n.º 11
0
    def forward(
        self,
        motif_indices,
        motif_size,
        walk_indices,
        walk_length,
    ):
        """Build the combined generator/discriminator loss graph.

        The generator produces fake motif indices. Real and fake indices are
        embedded together by the discriminator, split again, and scored
        separately. The returned loss is the sum of

            L_disc = - log(D(real)) - log(1-D(fake))
            L_gen  = log(G) * stop_gradient(log(1-D(fake)))

        as built below (the generator term is a plain product).

        :return: tuple ``(loss, real_disc_prob, fake_disc_prob, gen_prob)``
        """

        # Apply generator
        fake_motif_indices, gen_prob, gen_log_prob = self.generator(
            walk_length,
            walk_indices,
            motif_size,
        )

        # Embed real and fake motifs in one pass (log-space), then split
        # them. The first consumer of the slice is the real half and the
        # second is the fake half, so the Identity order matters.
        stacked_indices = lbann.Concatenation(motif_indices,
                                              fake_motif_indices)
        stacked_log_emb = self.discriminator.get_log_embeddings(
            stacked_indices)
        split_log_emb = lbann.Slice(
            stacked_log_emb,
            slice_points=str_list([0, motif_size, 2 * motif_size]),
        )
        real_log_emb = lbann.Identity(split_log_emb)
        fake_log_emb = lbann.Identity(split_log_emb)

        # Apply discriminator
        real_disc_prob, _real_disc_log_not_prob = self.discriminator(
            motif_size, real_log_emb)
        fake_disc_prob, fake_disc_log_not_prob = self.discriminator(
            motif_size, fake_log_emb)

        # Discriminator loss; the clamp keeps the log argument positive.
        clamped_real_prob = lbann.Clamp(real_disc_prob, min=1e-37, max=1)
        real_disc_log_prob = lbann.Log(clamped_real_prob)
        disc_loss = lbann.WeightedSum(
            real_disc_log_prob,
            fake_disc_log_not_prob,
            scaling_factors=str_list([-1, -1]),
        )

        # Generator loss; the discriminator term is detached.
        detached_fake_term = lbann.StopGradient(fake_disc_log_not_prob)
        gen_loss = lbann.Multiply(gen_log_prob, detached_fake_term)

        loss = lbann.Add(disc_loss, gen_loss)
        return loss, real_disc_prob, fake_disc_prob, gen_prob
Exemplo n.º 12
0
    def get_mat(self, cols = None):
        """Generates a matrix representation of the graph data.

           Concatenates ``self.layers`` and reshapes the result to
           ``[self.shape[0], cols]``.

           args: cols (int): column count of the result; a falsy value
                 (None or 0) falls back to ``self.shape[1]``.
        """
        stacked = lbann.Concatenation(self.layers)
        num_cols = cols if cols else self.shape[1]
        return lbann.Reshape(stacked,
                             dims=str_list([self.shape[0], num_cols]))
Exemplo n.º 13
0
 def decoder_cnn(self, z):
     """Decode ``z`` into a flattened image joined with a scalar branch.

     ``self.dec_cnn_fc(z)`` feeds two branches: ``self.dec_fc_sca`` and an
     image branch reshaped to "16 8 8" and passed through the three
     ``self.dec_convT`` layers with ReLU between them. The image is
     flattened to 64*64*4 entries and concatenated with the scalar branch
     on axis 0 (common interface; the caller slices the output).
     """
     hidden = self.dec_cnn_fc(z)
     sca = self.dec_fc_sca(lbann.Identity(hidden))

     img = lbann.Reshape(lbann.Identity(hidden),
                         dims="16 8 8",
                         name=self.name + 'dec_reshape0')
     img = self.dec_convT[0](img)
     img = self.dec_convT[1](lbann.Relu(img))
     img = self.dec_convT[2](lbann.Relu(img))

     # TODO: check tensor shape, and that the concatenated size equals
     # the decoder output dimension.
     flat_img = lbann.Reshape(img,
                              dims=str(64 * 64 * 4),
                              name=self.name + 'dec_reshape1')
     return lbann.Concatenation([flat_img, sca], axis=0)
Exemplo n.º 14
0
 def forward(self, x):
     """Split ``x`` into three segments, transform each, predict one value.

     The input is sliced on axis 0 at "0 921 4750 8579". The segments go,
     in order, to ``self.geneT``, ``self.drug1T`` and ``self.drug2T``. Their
     outputs are concatenated, passed through ``self.concatT`` and reduced
     to a single neuron by a fully-connected layer with bias.
     """
     x_slice = lbann.Slice(x,
                           axis=0,
                           slice_points="0 921 4750 8579",
                           name='inp_slice')

     # Each Identity takes the next segment of the slice, so the order of
     # the transforms below fixes which segment each one sees.
     branches = []
     for transform in (self.geneT, self.drug1T, self.drug2T):
         branches.append(transform(lbann.Identity(x_slice)))

     merged = lbann.Concatenation(branches, name=self.name + 'concat')
     concat = self.concatT(merged)
     return lbann.FullyConnected(concat,
                                 num_neurons=1,
                                 has_bias=True)
Exemplo n.º 15
0
Arquivo: vae.py Projeto: oyamay/lbann
    def forward_decoder(self, x_emb, z):
        """Decoder step, emulating x ~ G(z)

        :param x_emb: layer with the embeddings of the input sentence x
        :param z: latent vector layer; reshaped here to ``[1, 128]``
        :return: the channel-wise fully-connected output layer with
            ``self.dictionary_size`` output channels
        """
        # Tile z along the sequence axis and append it to every embedding
        # (torch equivalent: z.unsqueeze(1).repeat(...), then cat on dim -1).
        z_row = lbann.Reshape(z, dims=str_list([1, 128]))
        z_0 = lbann.Tessellate(
            z_row,
            dims=str_list([self.input_feature_dims, 128]),
        )
        x_input = lbann.Concatenation(x_emb, z_0, axis=1)

        # Initial hidden state from z, repeated three times
        # (torch equivalent: h_0.unsqueeze(0).repeat(num_layers, 1, 1)).
        h_0 = self.decoder_lat(z)
        h_0 = lbann.Reshape(h_0, dims=str_list([1, 512]))
        h_0 = lbann.Tessellate(h_0, dims=str_list((3, 512)))

        output = self.decoder_rnn(x_input, h_0)

        # Same role as self.decoder_fc(output): reuse its name and weights.
        y = lbann.ChannelwiseFullyConnected(
            output,
            output_channel_dims=self.dictionary_size,
            bias=True,
            name=f'{self.decoder_fc.name}',
            weights=self.decoder_fc.weights,
        )

        # Set datatype of layers
        # Note: Depth-first search from y back towards x_emb and z. The two
        # inputs are never visited, and Slice/Reshape/Tessellate layers are
        # traversed but keep their datatype.
        untouched_types = (lbann.Slice, lbann.Reshape, lbann.Tessellate)
        graph_inputs = (x_emb, z)
        pending = [y]
        seen = {y}
        while pending:
            node = pending.pop()
            if type(node) not in untouched_types:
                node.datatype = self.datatype
            for parent in node.parents:
                if parent in seen or parent in graph_inputs:
                    continue
                pending.append(parent)
                seen.add(parent)

        return y
Exemplo n.º 16
0
 def encoder_cnn(self, y):
     """Encode ``y`` (image segment plus trailing segment) into ``z``.

     ``y`` is sliced on axis 0 at "0 16384 16399". The first segment is
     reshaped to '4 64 64' and run through the three ``self.enc_conv``
     layers. The flattened result (16*8*8 entries) is concatenated with the
     second segment and passed to ``self.enc_out``.
     """
     img_sca = lbann.Slice(y,
                           axis=0,
                           slice_points="0 16384 16399",
                           name=self.name + '_y_slice')

     # First consumer of the slice: the image segment.
     # NOTE(review): this assumes channel-first data -- confirm.
     features = lbann.Reshape(img_sca,
                              dims='4 64 64',
                              name=self.name + 'enc_reshape0')
     for idx in range(3):
         features = self.enc_conv[idx](features)
     flat = lbann.Reshape(features,
                          dims=str(16 * 8 * 8),
                          name=self.name + 'enc_reshape1')

     # Second consumer of the slice: the trailing segment.
     h_stack = lbann.Concatenation([flat, img_sca], axis=0)
     return self.enc_out(h_stack)
Exemplo n.º 17
0
    def forward(self, X, A):
        """Call the GatedGraphConv
        Args:
            X (GraphVertexData): LBANN Data object, which is a collection of Layers. Each Layer is of
                                 the shape (1,input_channels)
            A (Layer): Adjacency matrix input with shape (num_nodes, num_nodes)
        Returns:
            LBANN_Data_Mat: The output after Gated Graph Kernel.
                        The output can passed into another Graph Conv layer directly
        Raises:
            ValueError: If the node feature size of ``X`` is larger than
                ``self.output_channels``.

        """

        input_features = X.size(1)
        num_nodes = X.size(0)

        # Fail before touching the graph. The original built this error
        # but never raised it, so oversized inputs passed through silently.
        if input_features > self.output_channels:
            raise ValueError(
                'The feature size of the nodes {} cannot be greater than the output dimension {}'
                .format(input_features, self.output_channels))

        if input_features < self.output_channels:
            # Zero-pad every node's features up to the output width.
            num_zeros = self.output_channels - input_features
            for i in range(num_nodes):
                zeros = lbann.Constant(value=0,
                                       num_neurons=str_list([1, num_zeros]),
                                       name=self.name + '_zero_' + str(i))
                X[i] = lbann.Concatenation(X[i], zeros, axis=1)

        X.update_num_features(self.output_channels)

        for layer in range(self.num_layers):
            # Message passing: transform node features with this layer's
            # weights, then aggregate through the adjacency matrix.
            X_mat = X.get_mat()
            messages = lbann.MatMul(X_mat, self.weights[layer])
            aggregate = lbann.MatMul(A, messages)

            M = GraphVertexData.matrix_to_graph(aggregate, num_nodes,
                                                self.output_channels)

            # Per-node recurrent update; element [1] of the rnn result is
            # kept and reshaped back to a (1, output_channels) row.
            for i in range(num_nodes):
                X[i] = lbann.Reshape(X[i], dims=str(self.output_channels))
                X[i] = lbann.Reshape(self.rnn(M[i], X[i])[1],
                                     dims=str_list([1, self.output_channels]))

        return X
Exemplo n.º 18
0
    def forward_generator(self, z, mcr):
        '''
        Build the Generator.

        z goes through g_fc1, is reshaped to '512 8 8', then through
        g_convT[0..2] (each followed by batch norm + ReLU) and g_convT3.
        With mcr set, a rescaled second channel is appended and the image
        is reshaped to '2 128 128'; otherwise to '1 128 128'. Prints the
        reshape dims and the final layer's attributes along the way.
        '''
        bn_args = dict(decay=0.9, scale_init=1.0, epsilon=1e-5)

        x = lbann.Relu(lbann.BatchNormalization(self.g_fc1(z), **bn_args))
        dims = '512 8 8'

        print("dims", dims)
        x = lbann.Reshape(x, dims=dims)  # channel first
        for idx in range(3):
            upsampled = self.g_convT[idx](x)
            x = lbann.Relu(lbann.BatchNormalization(upsampled, **bn_args))
        img = self.g_convT3(x)

        if not mcr:
            img = lbann.Reshape(img, dims='1 128 128')
        else:
            # Multi-channel rescaling: add an extra channel to the output
            # image, i.e. tanh(inv_transform(img) * (1 / linear_scaler)).
            inv_scale = 1 / self.linear_scaler
            rescaled = lbann.WeightedSum(self.inv_transform(img),
                                         scaling_factors=str(inv_scale))
            ch2 = lbann.Tanh(rescaled)
            stacked = lbann.Concatenation(img, ch2, axis=0)
            img = lbann.Reshape(stacked, dims='2 128 128')

        print('Gen Img in GAN', img.__dict__)
        return img
Exemplo n.º 19
0
    def forward(self, x, prev_state):
        """Apply one LSTM step with explicitly passed state.

        ``self.step`` is incremented and used in the layer names; no other
        state is kept on the object.

        Args:
            x (Layer): Input.
            prev_state (tuple with two `Layer`s): State from previous
                LSTM step, as ``(output, cell state)``.

        Returns:
            (Layer, (Layer, Layer)): The output and the new state (the
                output and cell state). The state can be passed directly
                into the next LSTM step.

        """
        self.step += 1
        step_name = f'{self.name}_step{self.step}'
        layout = self.data_layout

        prev_output, prev_cell = prev_state

        # Apply linearity to the input joined with the previous output
        joined = lbann.Concatenation(x, prev_output,
                                     name=step_name + '_input',
                                     data_layout=layout)
        pre_activation = self.fc(joined)

        # Split into the cell update ([0, size)) and the gates
        # ([size, 4*size)). NOTE(review): this relies on each consumer of a
        # Slice receiving the next segment in creation order -- keep the
        # order of the statements below.
        fc_split = lbann.Slice(
            pre_activation,
            slice_points=str_list([0, self.size, 4 * self.size]),
            name=step_name + '_fc_slice',
            data_layout=layout)
        cell_update = lbann.Tanh(fc_split,
                                 name=step_name + '_cell_update',
                                 data_layout=layout)
        gates = lbann.Sigmoid(fc_split,
                              name=step_name + '_sigmoid',
                              data_layout=layout)
        gate_split = lbann.Slice(
            gates,
            slice_points=str_list(
                [0, self.size, 2 * self.size, 3 * self.size]),
            name=step_name + '_sigmoid_slice',
            data_layout=layout)
        forget_gate = lbann.Identity(gate_split,
                                     name=step_name + '_forget_gate',
                                     data_layout=layout)
        input_gate = lbann.Identity(gate_split,
                                    name=step_name + '_input_gate',
                                    data_layout=layout)
        output_gate = lbann.Identity(gate_split,
                                     name=step_name + '_output_gate',
                                     data_layout=layout)

        # Cell state: forget part of the old cell, add the gated update
        kept = lbann.Multiply(forget_gate, prev_cell,
                              name=step_name + '_cell_forget',
                              data_layout=layout)
        added = lbann.Multiply(input_gate, cell_update,
                               name=step_name + '_cell_input',
                               data_layout=layout)
        cell = lbann.Add(kept, added,
                         name=step_name + '_cell',
                         data_layout=layout)

        # Output: gated tanh of the new cell state
        cell_activation = lbann.Tanh(cell,
                                     name=step_name + '_cell_activation',
                                     data_layout=layout)
        output = lbann.Multiply(output_gate, cell_activation,
                                name=step_name,
                                data_layout=layout)

        # Return output and state
        return output, (output, cell)
Exemplo n.º 20
0
def construct_model(run_args):
    """Construct the LBANN model used to evaluate a pretrained ATOM WAE.

    Builds the layer graph around ``molwae.MolWAE``, adds the three
    sigmoid binary-cross-entropy terms for the discriminator outputs,
    L2 weight regularization, metrics and callbacks, and wraps everything
    in an ``lbann.Model``.

    Args:
        run_args: Parsed run arguments. This function reads
            ``dump_model_dir``, ``pad_index``, ``sequence_length``,
            ``embedding_dim``, ``num_embeddings``, ``z_dim``,
            ``dump_outputs_dir``, ``dump_outputs_interval`` and
            ``num_epochs``.

    Returns:
        lbann.Model: Model built from the collected weights, layers,
        objective function, metrics and callbacks.

    Raises:
        AssertionError: If ``dump_model_dir`` is falsy, or if any of
            ``pad_index``, ``sequence_length``, ``embedding_dim`` or
            ``num_embeddings`` is None.

    """
    import lbann

    print("Dump model dir ", run_args.dump_model_dir)
    assert run_args.dump_model_dir, "evaluate script asssumes a pretrained WAE model"
    pad_index = run_args.pad_index
    assert pad_index is not None

    sequence_length = run_args.sequence_length
    assert sequence_length is not None

    print("sequence length is {}".format(sequence_length))
    # NOTE(review): data_layout is assigned but not used below.
    data_layout = "data_parallel"
    # Layer graph
    input_ = lbann.Identity(lbann.Input(name='inp', data_field='samples'),
                            name='inp1')
    # Terms of the objective function, collected in order below.
    wae_loss = []
    input_feature_dims = sequence_length

    embedding_size = run_args.embedding_dim
    dictionary_size = run_args.num_embeddings
    assert embedding_size is not None
    assert dictionary_size is not None

    # Outputs are saved only when an output directory was given.
    save_output = True if run_args.dump_outputs_dir else False

    print("save output? ", save_output, "out dir ", run_args.dump_outputs_dir)
    # Standard-normal sample handed to the model as its second input.
    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims=run_args.z_dim)

    waemodel = molwae.MolWAE(input_feature_dims, dictionary_size,
                             embedding_size, pad_index, run_args.z_dim,
                             save_output)
    recon, d1_real, d1_fake, d_adv, arg_max = waemodel(input_, z)

    # Constant targets for the binary cross-entropy terms.
    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')

    # Real output is scored against 1, fake against 0, adversarial against 1.
    d1_real_bce = lbann.SigmoidBinaryCrossEntropy([d1_real, one],
                                                  name='d1_real_bce')
    d1_fake_bce = lbann.SigmoidBinaryCrossEntropy([d1_fake, zero],
                                                  name='d1_fake_bce')
    d_adv_bce = lbann.SigmoidBinaryCrossEntropy([d_adv, one], name='d_adv_bce')

    wae_loss.append(recon)

    layers = list(lbann.traverse_layer_graph(input_))
    # Setup objective function
    weights = set()
    # Layer names used as source/destination of the replace-weights callback.
    src_layers = []
    dst_layers = []
    for l in layers:
        # Weighted "disc0" layers whose name also contains "instance1"
        # are the sources.
        if (l.weights and "disc0" in l.name and "instance1" in l.name):
            src_layers.append(l.name)
        # Freeze weights of "disc1" layers by giving them no optimizer;
        # these layers are the destinations.
        if (l.weights and "disc1" in l.name):
            dst_layers.append(l.name)
            for idx in range(len(l.weights)):
                l.weights[idx].optimizer = lbann.NoOptimizer()
        weights.update(l.weights)
    # Regularize only the weights that are still being optimized.
    l2_weights = [
        w for w in weights if not isinstance(w.optimizer, lbann.NoOptimizer)
    ]
    l2_reg = lbann.L2WeightRegularization(weights=l2_weights, scale=1e-4)

    wae_loss.append(d1_real_bce)
    wae_loss.append(d_adv_bce)
    wae_loss.append(d1_fake_bce)
    wae_loss.append(l2_reg)
    print("LEN wae loss ", len(wae_loss))

    obj = lbann.ObjectiveFunction(wae_loss)

    # Initialize check metric callback
    metrics = [
        lbann.Metric(d_adv_bce, name='adv_loss'),
        lbann.Metric(recon, name='recon')
    ]

    callbacks = [
        lbann.CallbackPrint(),
        #lbann.CallbackStepLearningRate(step=10, amt=0.5),
        lbann.CallbackTimer()
    ]

    # Replace-weights callback from the disc0 source layers to the frozen
    # disc1 destination layers (batch_interval=2).
    callbacks.append(
        lbann.CallbackReplaceWeights(source_layers=list2str(src_layers),
                                     destination_layers=list2str(dst_layers),
                                     batch_interval=2))

    #Dump output (activation) for post processing
    if (run_args.dump_outputs_dir):
        # NOTE(review): pred_tensor is created after `layers` was collected;
        # presumably lbann.Model re-traverses the graph -- confirm.
        pred_tensor = lbann.Concatenation(arg_max, name='pred_tensor')
        callbacks.append(
            lbann.CallbackDumpOutputs(
                batch_interval=run_args.dump_outputs_interval,
                execution_modes='test',
                directory=run_args.dump_outputs_dir,
                layers=f'inp pred_tensor {waemodel.q_mu.name}'))
    # Construct model
    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       metrics=metrics,
                       callbacks=callbacks)
Exemplo n.º 21
0
def construct_model(run_args):
    """Build the LBANN model that decodes (tokens, z) samples.

    Every input sample is sliced into ``sequence_length`` token entries
    followed by ``run_args.z_dim`` latent entries. The tokens are embedded
    and handed, together with the latent part, to
    ``MolWAE.forward_decoder``. The reconstruction loss returned by
    ``MolWAE.compute_loss`` is the only objective term and the only metric.
    A dump-outputs callback writes the raw input concatenated with the
    decoder's ``arg_max`` output during ``test`` execution.

    Args:
        run_args: Parsed options. Reads ``pad_index``, ``sequence_length``,
            ``z_dim``, ``embedding_dim``, ``num_embeddings``,
            ``dump_outputs_dir``, ``dump_outputs_interval`` and
            ``num_epochs``.

    Returns:
        lbann.Model: Model with the layer graph, weights, objective,
        metric and callbacks assembled here.

    Raises:
        AssertionError: If ``pad_index``, ``sequence_length``,
            ``embedding_dim`` or ``num_embeddings`` is ``None``.
    """
    import lbann

    pad_index = run_args.pad_index
    assert pad_index is not None

    sequence_length = run_args.sequence_length
    assert sequence_length is not None, 'should be training seq len + bos + eos'

    print("sequence length is {}, which is training sequence len + bos + eos".format(sequence_length))

    # Layer graph.
    # NOTE: the input is assumed to come from the encoder script, i.e. the
    # input SMILES tokens concatenated with z.
    input_ = lbann.Input(data_field='samples', name='inp_data')
    split_points = [0, sequence_length, sequence_length + run_args.z_dim]
    inp_slice = lbann.Slice(input_,
                            axis=0,
                            slice_points=str_list(split_points),
                            name='inp_slice')
    # First child of the slice receives the tokens, second child the latent.
    inp_smile = lbann.Identity(inp_slice, name='inp_smile')
    z = lbann.Identity(inp_slice, name='z')
    input_feature_dims = sequence_length

    embedding_size = run_args.embedding_dim
    dictionary_size = run_args.num_embeddings
    assert embedding_size is not None
    assert dictionary_size is not None

    save_output = bool(run_args.dump_outputs_dir)

    print("save output? ", save_output, "out dir ",  run_args.dump_outputs_dir)
    # For random sampling, replace z with:
    #   lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims=str(run_args.z_dim))
    tokens = lbann.Identity(
        lbann.Slice(inp_smile, slice_points=str_list([0, input_feature_dims])))
    waemodel = molwae.MolWAE(input_feature_dims,
                             dictionary_size,
                             embedding_size,
                             pad_index,
                             run_args.z_dim,
                             save_output=save_output)
    token_embeddings = lbann.Embedding(
        tokens,
        num_embeddings=waemodel.dictionary_size,
        embedding_dim=waemodel.embedding_size,
        name='emb',
        weights=waemodel.emb_weights,
    )

    pred, arg_max = waemodel.forward_decoder(token_embeddings, z)
    recon = waemodel.compute_loss(tokens, pred)
    loss_terms = [recon]

    # Collect every layer reachable from the input and the weights they own.
    layers = list(lbann.traverse_layer_graph(input_))
    weights = set()
    for layer in layers:
        weights.update(layer.weights)

    # L2 weight regularization is intentionally not part of the objective.
    print("LEN wae loss ", len(loss_terms))
    obj = lbann.ObjectiveFunction(loss_terms)

    metrics = [lbann.Metric(recon, name='recon')]

    # Dump output (activation) for post processing.
    pred_tensor = lbann.Concatenation(arg_max, name='pred_tensor')
    conc_out = lbann.Concatenation([input_, pred_tensor], name='conc_out')
    dump_outputs = lbann.CallbackDumpOutputs(
        batch_interval=run_args.dump_outputs_interval,
        execution_modes='test',
        directory=run_args.dump_outputs_dir,
        layers=f'{conc_out.name}',
    )
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer(), dump_outputs]

    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       metrics=metrics,
                       callbacks=callbacks)
Exemplo n.º 22
0
    def forward(self, queries, keys, values, mask=None):
        """Apply multi-head attention with heads spread over sub-grids.

        Unlike a single-grid attention module, the inputs here are
        sequences of layers: entry ``i`` of ``queries``/``keys``/``values``
        is projected with entry ``i`` of ``self.query_weights`` /
        ``self.key_weights`` / ``self.value_weights``. The partial
        projections are combined with ``Cross_Grid_Sum_Slice``, each branch
        slices out its ``num_heads / BRANCHES`` heads, scaled dot-product
        attention is computed per head, and the per-branch results are
        projected with ``self.output_weights`` and combined with a second
        ``Cross_Grid_Sum_Slice``.

        Args:
            queries: Iterable of query layers (one per weight entry).
            keys: Iterable of key layers.
            values: Iterable of value layers.
            mask (optional): Additive attention mask. If the (i,j) entry
                is very negative (e.g. -1e9), then the ith query does not
                attend to the jth key/value pair. When
                ``self.ENABLE_SUBGRAPH`` is set the mask is indexed as
                ``mask[tag]`` with ``tag`` starting at 1 for the first
                branch; otherwise it is added as a single layer.

        Returns:
            list: ``self.BRANCHES`` Identity layers, each a child of the
            final cross-grid sum.

        Raises:
            ValueError: If ``self.ENABLE_SUBGRAPH`` is set and
                ``self.num_heads`` is not divisible by ``self.BRANCHES``.

        """
        ENABLE_SUBGRAPH = self.ENABLE_SUBGRAPH
        BRANCHES = self.BRANCHES
        if ENABLE_SUBGRAPH and self.num_heads % BRANCHES != 0:
            raise ValueError('Num heads should be divisible by BRANCHES')
        self.instance += 1
        name = f'{self.name}_instance{self.instance}'

        heads_per_branch = int(self.num_heads / BRANCHES)

        # Slice points that cut one branch's share into individual heads
        slice_points = str_list(
            self.head_dim * i for i in range(heads_per_branch + 1))

        def project(inputs, weights, kind):
            """Project `inputs`, sum across grids, slice heads per branch."""
            # Strong scaling in the channel-wise fully-connected layers
            partial = [
                lbann.ChannelwiseFullyConnected(
                    tensor,
                    weights=weights[count],
                    output_channel_dims=[self.inner_dim],
                    name=f'{name}_subgrid{count}_{kind}_fc',
                ) for count, tensor in enumerate(inputs)
            ]
            grid_sum_slice = lbann.Cross_Grid_Sum_Slice(partial)
            # One child of the cross-grid sum per branch
            per_branch = [
                lbann.Identity(grid_sum_slice) for _ in range(BRANCHES)
            ]
            return [
                lbann.Slice(
                    branch_input,
                    axis=1,
                    slice_points=slice_points,
                    name=f'{name}_subgrid{branch}_{kind}_slice',
                ) for branch, branch_input in enumerate(per_branch)
            ]

        # Order matters: queries, then keys, then values
        queries_fc = project(queries, self.query_weights, 'queries')
        keys_fc = project(keys, self.key_weights, 'keys')
        values_fc = project(values, self.value_weights, 'values')

        # One Identity per head, i.e. one child per output of each Slice
        queries_slice = []
        keys_slice = []
        values_slice = []
        for branch in range(BRANCHES):
            for _ in range(heads_per_branch):
                queries_slice.append(lbann.Identity(queries_fc[branch]))
                keys_slice.append(lbann.Identity(keys_fc[branch]))
                values_slice.append(lbann.Identity(values_fc[branch]))

        # Compute scaled dot-product attention for each head
        scale = str(1 / math.sqrt(self.head_dim))

        # Heads grouped by branch so they can be combined locally
        temp_attentions = []
        tag = 0
        for head in range(self.num_heads):
            head_name = f'{name}_myattention_head{head}'

            # First head of a new branch: open a new group, advance the tag
            if head % heads_per_branch == 0:
                temp_attentions.append([])
                tag += 1

            # Attention inputs
            q = lbann.Identity(queries_slice[head])
            k = lbann.Identity(keys_slice[head])
            v = lbann.Identity(values_slice[head])

            # Multiply queries and keys
            # Note: num_queries x num_keys
            y = lbann.MatMul(
                q,
                k,
                transpose_b=True,
                name=f'{head_name}_matmul',
            )
            y = lbann.WeightedSum(
                y,
                scaling_factors=scale,
                name=f'{head_name}_scale',
            )

            if ENABLE_SUBGRAPH:
                if mask is not None:
                    y = lbann.Sum([y, mask[tag]], name=f'{head_name}_mask')
            elif mask:
                y = lbann.Sum([y, mask], name=f'{head_name}_mask')
            y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

            # Attention output
            # Note: num_queries x head_dim
            y = lbann.MatMul(y, v, name=head_name)

            temp_attentions[-1].append(y)

        attentions = []
        for count, temp_attention in enumerate(temp_attentions):

            if BRANCHES == self.num_heads:
                # No need to concat the heads at subgrid level
                # if number of subgrids is equal to number of heads
                attention_single_subgrid = temp_attention[0]
            else:
                attention_single_subgrid = lbann.Concatenation(
                    temp_attention,
                    axis=1,
                    name=f'{name}_subgrid_heads_concat{count}',
                    parallel_strategy={
                        'sub_branch_tag': 0,
                        'enable_subgraph': False
                    })

            attention_single_subgrid = lbann.ChannelwiseFullyConnected(
                attention_single_subgrid,
                weights=self.output_weights[count],
                output_channel_dims=[self.embed_dim],
                name=f'{name}_cfc_{count}',
            )

            attentions.append(attention_single_subgrid)

        # Strong scaling: sum the per-branch outputs across grids and hand
        # one child of the sum back per branch
        grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)

        return [lbann.Identity(grid_sum_slice) for _ in range(BRANCHES)]
Exemplo n.º 23
0
def construct_model(run_args):
    """Build the LBANN model used to dump encoder latents.

    The input tokens are embedded and run through
    ``MolWAE.forward_encoder``. The objective is a mean absolute error
    between the encoded latent and a Gaussian sample; the variable name
    ``fake_loss`` suggests it is a placeholder objective (TODO confirm).
    A dump-outputs callback writes the input concatenated with the latent
    during ``test`` execution.

    Args:
        run_args: Parsed options. Reads ``dump_model_dir``, ``pad_index``,
            ``sequence_length``, ``embedding_dim``, ``num_embeddings``,
            ``dump_outputs_dir``, ``dump_outputs_interval`` and
            ``num_epochs``.

    Returns:
        lbann.Model: Model with the layer graph, weights, objective and
        callbacks assembled here (no metrics).

    Raises:
        AssertionError: If ``dump_model_dir`` is falsy, or ``pad_index``,
            ``sequence_length``, ``embedding_dim`` or ``num_embeddings``
            is ``None``.
    """
    import lbann

    print("Dump model dir ", run_args.dump_model_dir)
    assert run_args.dump_model_dir, "evaluate script asssumes a pretrained WAE model"
    pad_index = run_args.pad_index
    assert pad_index is not None

    sequence_length = run_args.sequence_length
    assert sequence_length is not None

    print("sequence length is {}".format(sequence_length))

    # Layer graph
    raw_input = lbann.Input(name='inp', target_mode="N/A")
    input_ = lbann.Identity(raw_input, name='inp1')
    input_feature_dims = sequence_length

    embedding_size = run_args.embedding_dim
    dictionary_size = run_args.num_embeddings
    assert embedding_size is not None
    assert dictionary_size is not None

    save_output = False

    print("save output? ", save_output, "out dir ", run_args.dump_outputs_dir)
    # Gaussian sample the encoder output is compared against
    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="128")

    tokens = lbann.Identity(
        lbann.Slice(input_, slice_points=str_list([0, input_feature_dims])))
    waemodel = molwae.MolWAE(input_feature_dims, dictionary_size,
                             embedding_size, pad_index, save_output)
    token_embeddings = lbann.Embedding(
        tokens,
        num_embeddings=waemodel.dictionary_size,
        embedding_dim=waemodel.embedding_size,
        name='emb',
        weights=waemodel.emb_weights,
    )

    latentz = waemodel.forward_encoder(token_embeddings)

    fake_loss = lbann.MeanAbsoluteError(latentz, z)

    # Collect every layer reachable from the input and the weights they own.
    layers = list(lbann.traverse_layer_graph(input_))
    weights = set()
    for layer in layers:
        weights.update(layer.weights)

    obj = lbann.ObjectiveFunction(fake_loss)

    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

    # Dump output (activation) for post processing
    conc_out = lbann.Concatenation([input_, latentz], name='conc_out')
    dump_outputs = lbann.CallbackDumpOutputs(
        batch_interval=run_args.dump_outputs_interval,
        execution_modes='test',
        directory=run_args.dump_outputs_dir,
        layers=f'{conc_out.name}',
    )
    callbacks.append(dump_outputs)

    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       callbacks=callbacks)
Exemplo n.º 24
0
    def compute_loss(self, x, y):
        """Build the reconstruction loss between targets `x` and scores `y`.

        The layer graph mirrors this PyTorch reference::

            F.cross_entropy(
                y[:, :-1].contiguous().view(-1, y.size(-1)),
                x[:, 1:].contiguous().view(-1),
                ignore_index=self.pad)

        Positions whose target equals ``self.label_to_ignore`` contribute
        nothing, and the total is divided by the number of kept positions
        (at least 1).

        Args:
            x: Layer of target token indices; its first entry is dropped.
            y: Layer of per-position scores over the dictionary; its last
                position is dropped.

        Returns:
            lbann.Layer: Loss layer, reshaped to a single entry.
        """
        seq_len = self.input_feature_dims
        vocab = self.dictionary_size

        # y[:, :-1]
        scores = lbann.Slice(
            y,
            axis=0,
            slice_points=str_list([0, seq_len - 1]),
        )
        scores = lbann.Identity(scores)

        # x[:, 1:]
        targets = lbann.Slice(
            x,
            slice_points=str_list([1, seq_len]),
        )
        targets = lbann.Identity(targets)

        # Masks marking ignored / kept target positions
        ignored_label = self.constant(self.label_to_ignore, hint_layer=targets)
        ignore_mask = lbann.Equal(targets, ignored_label)
        keep_mask = lbann.LogicalNot(ignore_mask)

        # Number of kept positions, clamped to at least one
        length = lbann.Reduction(keep_mask, mode='sum')
        one = self.constant(1, [1])
        length = lbann.Max(length, one)

        # Replace ignored indices by -1 so they one-hot encode to zero
        # vectors
        kept_part = lbann.Multiply(keep_mask, targets)
        minus_one = self.constant(-1, hint_layer=targets)
        ignored_part = lbann.Multiply(ignore_mask, minus_one)
        targets = lbann.Add(kept_part, ignored_part)

        # One-hot encode each target position and stack the rows.
        # The three passes are kept separate so layers are created in the
        # same order as before.
        targets = lbann.Slice(targets,
                              slice_points=str_list(range(seq_len)))
        per_position = [lbann.Identity(targets) for _ in range(seq_len - 1)]
        per_position = [
            lbann.OneHot(entry, size=vocab) for entry in per_position
        ]
        per_position = [
            lbann.Reshape(entry, dims=str_list([1, vocab]))
            for entry in per_position
        ]
        one_hot = lbann.Concatenation(per_position, axis=0)

        # Shift the scores before exponentiating.
        # Note: Ideally we'd shift y by y.max(-1) for numerical stability
        zero = self.constant(0, hint_layer=scores)
        positive_part = lbann.Max(scores, zero)
        shift_matrix = self.constant(
            1 / math.sqrt(vocab),
            [vocab, vocab],
        )
        shifts = lbann.MatMul(positive_part, shift_matrix)
        scores = lbann.Subtract(scores, shifts)

        # Log-sum-exp per position, summed over the kept positions
        exp_scores = lbann.Exp(scores)
        ones_column = self.constant(1, [vocab, 1])
        log_norm = lbann.Log(lbann.MatMul(exp_scores, ones_column))
        keep_row = lbann.Reshape(keep_mask, dims=str_list([1, -1]))
        log_norm = lbann.MatMul(keep_row, log_norm)

        # Dot product of the flattened scores and one-hot targets
        flat_scores = lbann.Reshape(scores, dims=str_list([1, -1]))
        flat_one_hot = lbann.Reshape(one_hot, dims=str_list([1, -1]))
        target_scores = lbann.MatMul(
            flat_scores,
            flat_one_hot,
            transpose_b=True,
        )

        recon_loss = lbann.Subtract(log_norm, target_scores)
        recon_loss = lbann.Reshape(recon_loss, dims=str_list([1]))
        return lbann.Divide(recon_loss, length)
Exemplo n.º 25
0
    def forward(self, queries, keys, values, mask=None):
        """Apply multi-head attention, optionally tagging heads by branch.

        The input and output tensors are interpreted as sequences of
        vectors, where the first tensor dimension is the sequence
        dimension.

        When ``self.ENABLE_SUBGRAPH`` is set, the per-head Identity layers
        carry a ``parallel_strategy`` whose ``sub_branch_tag`` advances
        every ``num_heads / BRANCHES`` heads (the first group gets tag 1),
        and the mask is indexed with that tag.

        Args:
            queries (lbann.Layer): Sequence of query vectors.
            keys (lbann.Layer): Sequence of key vectors.
            values (lbann.Layer): Sequence of value vectors.
            mask (optional): Additive attention mask. If the (i,j) entry
                is very negative (e.g. -1e9), then the ith query does not
                attend to the jth key/value pair. Indexed as ``mask[tag]``
                when ``self.ENABLE_SUBGRAPH`` is set, otherwise used as a
                single layer.

        Returns:
            lbann.Layer: Sequence of output vectors. The sequence
                length is the same as `queries`.

        Raises:
            ValueError: If ``self.ENABLE_SUBGRAPH`` is set and
                ``self.num_heads`` is not divisible by ``self.BRANCHES``.

        """
        ENABLE_SUBGRAPH = self.ENABLE_SUBGRAPH
        BRANCHES = self.BRANCHES
        if ENABLE_SUBGRAPH and self.num_heads % BRANCHES != 0:
            raise ValueError('Num heads should be divisible by BRANCHES')
        self.instance += 1
        name = f'{self.name}_instance{self.instance}'

        def strategy(tag):
            """Return a fresh parallel-strategy dict for branch `tag`."""
            return {
                'sub_branch_tag': tag,
                'enable_subgraph': ENABLE_SUBGRAPH
            }

        # Apply fully-connected layers to input sequences
        queries_fc = lbann.ChannelwiseFullyConnected(
            queries,
            weights=self.query_weights,
            output_channel_dims=[self.inner_dim],
            name=f'{name}_queries_fc',
        )
        keys_fc = lbann.ChannelwiseFullyConnected(
            keys,
            weights=self.key_weights,
            output_channel_dims=[self.inner_dim],
            name=f'{name}_keys_fc',
        )
        values_fc = lbann.ChannelwiseFullyConnected(
            values,
            weights=self.value_weights,
            output_channel_dims=[self.inner_dim],
            name=f'{name}_values_fc',
        )

        # Slice embedding vectors for each head
        slice_points = str_list(self.head_dim * i
                                for i in range(self.num_heads + 1))
        queries_slice = lbann.Slice(queries_fc,
                                    axis=1,
                                    slice_points=slice_points,
                                    name=f'{name}_queries_slice',
                                    parallel_strategy=strategy(0))
        keys_slice = lbann.Slice(keys_fc,
                                 axis=1,
                                 slice_points=slice_points,
                                 name=f'{name}_keys_slice',
                                 parallel_strategy=strategy(0))
        values_slice = lbann.Slice(values_fc,
                                   axis=1,
                                   slice_points=slice_points,
                                   name=f'{name}_values_slice',
                                   parallel_strategy=strategy(0))

        # Only needed (and only computed) when heads are grouped by branch
        heads_per_branch = (int(self.num_heads / BRANCHES)
                            if ENABLE_SUBGRAPH else None)
        scale = str(1 / math.sqrt(self.head_dim))

        # Compute scaled dot-product attention for each head
        attentions = []
        tag = 0
        for head in range(self.num_heads):
            head_name = f'{name}_myattention_head{head}'

            # Attention inputs
            if ENABLE_SUBGRAPH:
                # Advance the tag at the first head of every branch
                if head % heads_per_branch == 0:
                    tag += 1

                q = lbann.Identity(queries_slice,
                                   parallel_strategy=strategy(tag))
                k = lbann.Identity(keys_slice,
                                   parallel_strategy=strategy(tag))
                v = lbann.Identity(values_slice,
                                   parallel_strategy=strategy(tag))
            else:
                q = lbann.Identity(queries_slice)
                k = lbann.Identity(keys_slice)
                v = lbann.Identity(values_slice)

            # Multiply queries and keys
            # Note: num_queries x num_keys
            y = lbann.MatMul(
                q,
                k,
                transpose_b=True,
                name=f'{head_name}_matmul',
            )
            y = lbann.WeightedSum(
                y,
                scaling_factors=scale,
                name=f'{head_name}_scale',
            )

            if ENABLE_SUBGRAPH:
                if mask is not None:
                    y = lbann.Sum([y, mask[tag]], name=f'{head_name}_mask')
            elif mask:
                y = lbann.Sum([y, mask], name=f'{head_name}_mask')
            y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

            # Attention output
            # Note: num_queries x head_dim
            attentions.append(lbann.MatMul(y, v, name=head_name))

        # Concatenate heads and apply fully-connected layer. The parallel
        # strategy is only passed when sub-graph parallelism is enabled.
        concat_kwargs = ({
            'parallel_strategy': strategy(0)
        } if ENABLE_SUBGRAPH else {})
        attentions = lbann.Concatenation(
            attentions,
            axis=1,
            name=f'{name}_heads_concat',
            **concat_kwargs,
        )

        outputs_fc = lbann.ChannelwiseFullyConnected(
            attentions,
            weights=self.output_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}',
        )
        return outputs_fc
Exemplo n.º 26
0
    def forward(self, queries, keys, values, mask=None):
        """Run multi-head scaled dot-product attention.

        Inputs and output are treated as sequences of vectors whose first
        tensor dimension is the sequence dimension.

        Args:
            queries (lbann.Layer): Sequence of query vectors.
            keys (lbann.Layer): Sequence of key vectors.
            values (lbann.Layer): Sequence of value vectors.
            mask (lbann.Layer, optional): Additive attention mask. A very
                negative (i,j) entry (e.g. -1e9) stops the ith query from
                attending to the jth key/value pair.

        Returns:
            lbann.Layer: Sequence of output vectors with the same sequence
                length as `queries`.

        """
        self.instance += 1
        name = f'{self.name}_instance{self.instance}'

        # Project the three inputs (queries, keys, values, in that order)
        specs = (
            ('queries', queries, self.query_weights),
            ('keys', keys, self.key_weights),
            ('values', values, self.value_weights),
        )
        projected = [
            lbann.ChannelwiseFullyConnected(
                tensor,
                weights=proj_weights,
                output_channel_dims=[self.embed_dim],
                name=f'{name}_{label}_fc',
            ) for label, tensor, proj_weights in specs
        ]

        # Cut each projection into per-head pieces along axis 1
        head_bounds = str_list(self.head_dim * i
                               for i in range(self.num_heads + 1))
        queries_slice, keys_slice, values_slice = [
            lbann.Slice(
                fc,
                axis=1,
                slice_points=head_bounds,
                name=f'{name}_{label}_slice',
            ) for (label, _, _), fc in zip(specs, projected)
        ]

        # Scaled dot-product attention, one head at a time
        head_outputs = []
        for head in range(self.num_heads):
            head_name = f'{name}_head{head}'

            # Each Identity picks up the next output of its Slice
            q = lbann.Identity(queries_slice)
            k = lbann.Identity(keys_slice)
            v = lbann.Identity(values_slice)

            # Scores, num_queries x num_keys
            scores = lbann.MatMul(
                q,
                k,
                transpose_b=True,
                name=f'{head_name}_matmul',
            )
            scores = lbann.WeightedSum(
                scores,
                scaling_factors=str(1 / math.sqrt(self.head_dim)),
                name=f'{head_name}_scale',
            )
            if mask:
                scores = lbann.Add(scores, mask, name=f'{head_name}_mask')
            probs = lbann.ChannelwiseSoftmax(scores,
                                             name=f'{head_name}_softmax')

            # Head output, num_queries x head_dim
            head_outputs.append(lbann.MatMul(probs, v, name=head_name))

        # Join the heads and apply the output projection
        joined = lbann.Concatenation(head_outputs,
                                     axis=1,
                                     name=f'{name}_heads_concat')
        return lbann.ChannelwiseFullyConnected(
            joined,
            weights=self.output_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}',
        )
Exemplo n.º 27
0
def construct_model():
    """Build the MACC surrogate evaluation model.

    See https://arxiv.org/pdf/1912.08113.pdf for the model architecture
    and other details.

    Each input sample is split into an output part (``args.ydim`` entries)
    and a parameter part (``args.xdim`` entries). A WAE, an inverse model
    and a forward model are chained in several ways to form inverse,
    forward and cycle-consistency errors. Two weighted sums of those errors
    make up the objective, and a concatenation of the parameters with four
    of the errors is dumped as ``npy`` during ``test`` execution.

    Reads the module-level ``args`` (``ydim``, ``xdim``, ``zdim``,
    ``wae_mcf``, ``useCNN``, ``surrogate_mcf``, ``lamda_cyc``,
    ``dump_outputs``).

    Returns:
        lbann.Model: Single-epoch model with the layer graph, weights,
        objective, metrics and callbacks assembled here.
    """
    import lbann

    # Layer graph
    input_layer = lbann.Input(data_field='samples', name='inp_data')
    # data is 64*64*4 images + 15 scalar + 5 param
    split_points = [0, args.ydim, args.ydim + args.xdim]
    inp_slice = lbann.Slice(input_layer,
                            axis=0,
                            slice_points=str_list(split_points),
                            name='inp_slice')
    gt_y = lbann.Identity(inp_slice, name='gt_y')
    gt_x = lbann.Identity(inp_slice, name='gt_x')

    # These three layers are not referenced again in this function; they
    # are still created so the set of constructed layers is unchanged.
    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')
    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="20")

    # WAE is expected to be pretrained and frozen (per the original note)
    wae = macc_models.MACCWAE(args.zdim,
                              args.ydim,
                              cf=args.wae_mcf,
                              use_CNN=args.useCNN)
    inv = macc_models.MACCInverse(args.xdim, cf=args.surrogate_mcf)
    fwd = macc_models.MACCForward(args.zdim, cf=args.surrogate_mcf)

    # y -> enc(y)
    y_pred_fwd = wae.encoder(gt_y)

    # y -> enc(y) -> inv -> x' -> fwd -> dec -> y'
    param_pred_ = wae.encoder(gt_y)
    input_fake = inv(param_pred_)
    output_cyc = fwd(input_fake)
    y_image_re2 = wae.decoder(output_cyc)

    # **** Train cycleGAN input params <--> latent space of (images, scalars) ****
    # x -> fwd -> dec -> y'
    output_fake = fwd(gt_x)
    y_image_re = wae.decoder(output_fake)

    # y -> enc -> dec
    y_out = wae.decoder(y_pred_fwd)

    # x -> fwd -> dec -> enc -> inv -> x'
    param_pred2_ = wae.encoder(y_image_re)
    input_cyc = inv(param_pred2_)

    # (x, inv(enc(y))): (encoder+)inverse loss
    L_l2_x = lbann.MeanSquaredError(input_fake, gt_x)
    # param, x cycle loss, from latent space
    L_cyc_x = lbann.MeanSquaredError(input_cyc, gt_x)

    # pred error into latent space (enc(y), fw(x))
    L_l2_y = lbann.MeanSquaredError(output_fake, y_pred_fwd)
    # pred error into latent space (enc(y), fw(inv(enc(y))))
    L_cyc_y = lbann.MeanSquaredError(output_cyc, y_pred_fwd)

    # TODO: slice here to separate scalar from image
    # (y, dec(fw(x))): forward model to decoder, no latent space
    img_sca_loss = lbann.MeanSquaredError(y_image_re, gt_y)
    # (y, dec(fw(inv(enc(y))))): y->enc_z->x'->fw_z->y'
    dec_fw_inv_enc_y = lbann.MeanSquaredError(y_image_re2, gt_y)
    # (y, dec(enc(y)))
    wae_loss = lbann.MeanSquaredError(y_out, gt_y)

    # L_cyc = L_cyc_y + L_cyc_x
    L_cyc = lbann.Add(L_cyc_y, L_cyc_x)

    cycle_scaling = f'1 {args.lamda_cyc}'
    # loss_gen0 = L_l2_y + lamda_cyc*L_cyc
    loss_gen0 = lbann.WeightedSum([L_l2_y, L_cyc],
                                  scaling_factors=cycle_scaling)
    # loss_gen1 = L_l2_x + lamda_cyc*L_cyc_y
    loss_gen1 = lbann.WeightedSum([L_l2_x, L_cyc_y],
                                  scaling_factors=cycle_scaling)

    conc_out = lbann.Concatenation(
        [gt_x, wae_loss, img_sca_loss, dec_fw_inv_enc_y, L_l2_x],
        name='x_errors')

    # Collect every layer reachable from the input and the weights they own.
    layers = list(lbann.traverse_layer_graph(input_layer))
    weights = set()
    for layer in layers:
        weights.update(layer.weights)

    # Setup objective function
    obj = lbann.ObjectiveFunction([loss_gen0, loss_gen1])

    # Reported metrics, in display order
    metrics = [
        lbann.Metric(img_sca_loss, name='img_re1'),
        lbann.Metric(dec_fw_inv_enc_y, name='img_re2'),
        lbann.Metric(wae_loss, name='wae_loss'),
        lbann.Metric(L_l2_x, name='inverse loss'),
        lbann.Metric(L_cyc_y, name='output cycle loss'),
        lbann.Metric(L_cyc_x, name='param cycle loss'),
    ]

    dump_outputs = lbann.CallbackDumpOutputs(layers=f'{conc_out.name}',
                                             execution_modes='test',
                                             directory=args.dump_outputs,
                                             batch_interval=1,
                                             format='npy')
    callbacks = [lbann.CallbackPrint(), dump_outputs, lbann.CallbackTimer()]

    # Construct model
    num_epochs = 1
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       serialize_io=True,
                       metrics=metrics,
                       objective_function=obj,
                       callbacks=callbacks)
Exemplo n.º 28
0
def make_model(
    num_epochs,
    embed_dim,
    num_heads,
    label_smoothing,
):
    """Build the LBANN transformer model that reconstructs the decoder input.

    Relies on the module-level ``vocab_size``, ``sequence_length`` and
    ``pad_index`` values.

    Args:
        num_epochs: Forwarded as the first argument of ``lbann.Model``.
        embed_dim: Embedding dimension; also used as the transformer
            hidden size.
        num_heads: Forwarded to ``lbann.models.Transformer``.
        label_smoothing: When greater than zero, every one-hot label is
            blended with a uniform ``1 / vocab_size`` label using the
            weights ``1 - label_smoothing`` and ``label_smoothing``.

    Returns:
        An ``lbann.Model`` whose objective is the per-token cross entropy,
        scaled so that it is averaged over the non-pad label tokens.
    """
    # Number of decoder positions, which is also the number of label tokens.
    num_targets = sequence_length - 1

    # Embedding weights (Glorot initialization)
    glorot_var = 2 / (embed_dim + vocab_size)
    embedding_weights = lbann.Weights(
        name='embeddings',
        initializer=lbann.NormalInitializer(
            standard_deviation=math.sqrt(glorot_var)),
    )

    # Input is two sequences of token IDs
    input_ = lbann.Input(data_field='samples')

    # The decoder input is shifted right, so the embedding of the very last
    # token is never needed and is left out of the slice.
    token_slice = lbann.Slice(
        input_,
        axis=0,
        slice_points=str_list([0, 2 * sequence_length - 1]),
    )
    tokens = lbann.Identity(token_slice)
    embedded = lbann.Embedding(
        tokens,
        weights=embedding_weights,
        num_embeddings=vocab_size,
        embedding_dim=embed_dim,
        padding_idx=pad_index,
    )
    # Scale embeddings by sqrt(embed_dim)
    scaled = lbann.WeightedSum(
        embedded,
        scaling_factors=str(math.sqrt(embed_dim)),
    )
    enc_dec_slice = lbann.Slice(
        scaled,
        axis=0,
        slice_points=str_list([0, sequence_length, 2 * sequence_length - 1]),
    )
    # NOTE(review): the two Identity layers presumably pick up successive
    # slices in creation order - keep the encoder one first.
    encoder_input = lbann.Identity(enc_dec_slice)
    decoder_input = lbann.Identity(enc_dec_slice)

    # Apply transformer model
    transformer = lbann.models.Transformer(
        hidden_size=embed_dim,
        num_heads=num_heads,
        name='transformer',
    )
    decoded = transformer(
        encoder_input,
        sequence_length,
        decoder_input,
        num_targets,
    )

    # Project back onto the vocabulary with the (transposed) embedding
    # weights, then split into one prediction per decoder position.
    logits = lbann.ChannelwiseFullyConnected(
        decoded,
        weights=embedding_weights,
        output_channel_dims=[vocab_size],
        bias=False,
        transpose=True,
    )
    probs = lbann.ChannelwiseSoftmax(logits)
    probs_slice = lbann.Slice(
        probs,
        axis=0,
        slice_points=str_list(range(sequence_length)),
    )
    preds = [lbann.Identity(probs_slice) for _ in range(num_targets)]

    # Count number of non-pad tokens among the labels
    label_slice = lbann.Slice(
        input_,
        slice_points=str_list([sequence_length + 1, 2 * sequence_length]),
    )
    all_labels = lbann.Identity(label_slice)
    pads = lbann.Constant(value=pad_index, num_neurons=str(num_targets))
    is_not_pad = lbann.NotEqual(all_labels, pads)
    num_not_pad = lbann.Reduction(is_not_pad, mode='sum')

    # One label token per decoder position
    per_token_slice = lbann.Slice(
        all_labels,
        slice_points=str_list(range(sequence_length)),
    )
    label_tokens = [lbann.Identity(per_token_slice) for _ in range(num_targets)]

    # Cross entropy loss with optional label smoothing
    use_smoothing = label_smoothing > 0
    if use_smoothing:
        uniform_label = lbann.Constant(
            value=1 / vocab_size,
            num_neurons=str_list([1, vocab_size]),
        )
        smoothing_factors = str_list([1 - label_smoothing, label_smoothing])
    token_losses = []
    for pred, token in zip(preds, label_tokens):
        target = lbann.OneHot(token, size=vocab_size)
        target = lbann.Reshape(target, dims=str_list([1, vocab_size]))
        if use_smoothing:
            target = lbann.WeightedSum(
                target,
                uniform_label,
                scaling_factors=smoothing_factors,
            )
        token_losses.append(lbann.CrossEntropy(pred, target))
    loss = lbann.Concatenation(token_losses)

    # Average cross entropy over non-pad tokens: pad positions get weight 0,
    # the others 1 / num_not_pad.
    tiled_count = lbann.Tessellate(num_not_pad, hint_layer=is_not_pad)
    loss_scales = lbann.Divide(is_not_pad, tiled_count)
    loss = lbann.Multiply(loss, loss_scales)
    loss = lbann.Reduction(loss, mode='sum')

    # Construct model
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
    return lbann.Model(
        num_epochs,
        layers=lbann.traverse_layer_graph(input_),
        objective_function=loss,
        metrics=[],
        callbacks=callbacks,
    )
Exemplo n.º 29
0
def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular VAE. Every input sample is sliced at
    the fixed points ``0 102 230``: the first slice is embedded and serves
    as the reconstruction target, the second slice is handed to
    ``MolWAE.forward_decoder`` as ``z``. The objective and the only metric
    are the reconstruction loss returned by ``MolWAE.compute_loss``.

    A ``CallbackDumpOutputs`` is always registered for the ``conc_out``
    layer (raw input concatenated with the decoder arg-max output) in
    ``test`` mode.

    Args:
        run_args: Argument namespace. Reads ``pad_index``,
            ``embedding_dim``, ``num_embeddings`` (all must be non-None),
            ``dump_outputs_dir``, ``dump_outputs_interval`` and
            ``num_epochs``.

    Returns:
        The constructed ``lbann.Model``.
    """
    import lbann

    pad_index = run_args.pad_index
    assert pad_index is not None

    # The sequence length is fixed here; run_args.sequence_length is not
    # consulted.
    sequence_length = 102
    print("sequence length is {}".format(sequence_length))

    # Layer graph
    # NOTE(review): 230 - 102 = 128 is presumably the latent width - confirm
    # against the data reader before changing sequence_length.
    input_ = lbann.Input(target_mode='N/A', name='inp_data')
    inp_slice = lbann.Slice(input_,
                            axis=0,
                            slice_points="0 102 230",
                            name='inp_slice')
    # Creation order is kept as in the original: 'inp_smile' first, 'z' second.
    inp_smile = lbann.Identity(inp_slice, name='inp_smile')
    z = lbann.Identity(inp_slice, name='z')
    input_feature_dims = sequence_length

    embedding_size = run_args.embedding_dim
    dictionary_size = run_args.num_embeddings
    assert embedding_size is not None
    assert dictionary_size is not None

    save_output = bool(run_args.dump_outputs_dir)
    print("save output? ", save_output, "out dir ", run_args.dump_outputs_dir)

    x = lbann.Slice(inp_smile, slice_points=str_list([0, input_feature_dims]))
    x = lbann.Identity(x)
    waemodel = molwae.MolWAE(input_feature_dims,
                             dictionary_size,
                             embedding_size,
                             pad_index,
                             save_output)
    x_emb = lbann.Embedding(
        x,
        num_embeddings=waemodel.dictionary_size,
        embedding_dim=waemodel.embedding_size,
        name='emb',
        weights=waemodel.emb_weights,
    )

    # Decode directly from the z slice of the input (no encoder pass).
    pred, arg_max = waemodel.forward_decoder(x_emb, z)
    recon = waemodel.compute_loss(x, pred)

    # Setup objective function
    vae_loss = [recon]
    print("LEN vae loss ", len(vae_loss))
    obj = lbann.ObjectiveFunction(vae_loss)

    # Initialize check metric callback
    metrics = [lbann.Metric(recon, name='recon')]

    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

    # Dump output (activation) for post processing
    pred_tensor = lbann.Concatenation(arg_max, name='pred_tensor')
    conc_out = lbann.Concatenation([input_, pred_tensor], name='conc_out')
    callbacks.append(
        lbann.CallbackDumpOutputs(
            batch_interval=run_args.dump_outputs_interval,
            execution_modes='test',
            directory=run_args.dump_outputs_dir,
            layers=conc_out.name,
        ))

    # Collect layers and weights only now that every layer exists, so the
    # layer referenced by the dump callback is part of the model's list.
    layers = list(lbann.traverse_layer_graph(input_))
    weights = {w for layer in layers for w in layer.weights}

    # Construct model
    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       metrics=metrics,
                       callbacks=callbacks)
Exemplo n.º 30
0
def setup(num_patches=3,
          mini_batch_size=512,
          num_epochs=75,
          learning_rate=0.005,
          bn_statistics_group_size=2,
          fc_data_layout='model_parallel',
          warmup=True,
          checkpoint_interval=None):
    """Build the Siamese patch-classification experiment.

    Each data sample is sliced into ``num_patches`` patches followed by a
    label vector. The same ``modules.ResNet`` instance is applied to every
    patch, the results are concatenated and passed through three
    fully-connected modules ending in a softmax.

    Args:
        num_patches: Number of patches per sample; also passed to
            ``patch_generator.num_labels`` and ``make_data_reader``.
        mini_batch_size: Only used to scale the warm-up target learning
            rate (``learning_rate * mini_batch_size / 128``).
        num_epochs: Forwarded to ``lbann.Model``.
        learning_rate: Learning rate of the SGD optimizer and base of the
            warm-up target.
        bn_statistics_group_size: Forwarded to the ResNet and the
            ``FcBnRelu`` modules.
        fc_data_layout: Data layout of the three classification modules.
        warmup: If truthy, add a 5-epoch linear learning-rate warm-up.
        checkpoint_interval: If truthy, checkpoint into ``'ckpt'`` every
            ``checkpoint_interval`` epochs.

    Returns:
        Tuple ``(model, data_reader, opt)``.
    """
    # Data dimensions
    patch_dims = patch_generator.patch_dims
    num_labels = patch_generator.num_labels(num_patches)

    # Extract tensors from data sample: num_patches equally sized patches
    # followed by the labels.
    input_ = lbann.Input()
    patch_size = functools.reduce(operator.mul, patch_dims)
    slice_points = [i * patch_size for i in range(num_patches + 1)]
    slice_points.append(slice_points[-1] + num_labels)
    sample = lbann.Slice(input_, slice_points=str_list(slice_points))
    # Creation order is kept as in the original: patches first, labels last.
    patches = [
        lbann.Reshape(sample, dims=str_list(patch_dims))
        for _ in range(num_patches)
    ]
    labels = lbann.Identity(sample)

    # Siamese network
    head_cnn = modules.ResNet(
        bn_statistics_group_size=bn_statistics_group_size)
    heads = [head_cnn(patch) for patch in patches]
    heads_concat = lbann.Concatenation(heads)

    # Classification network
    class_fc1 = modules.FcBnRelu(
        4096,
        statistics_group_size=bn_statistics_group_size,
        name='siamese_class_fc1',
        data_layout=fc_data_layout)
    class_fc2 = modules.FcBnRelu(
        4096,
        statistics_group_size=bn_statistics_group_size,
        name='siamese_class_fc2',
        data_layout=fc_data_layout)
    class_fc3 = lbann.modules.FullyConnectedModule(num_labels,
                                                   activation=lbann.Softmax,
                                                   name='siamese_class_fc3',
                                                   data_layout=fc_data_layout)
    x = class_fc1(heads_concat)
    x = class_fc2(x)
    probs = class_fc3(x)

    # Setup objective function: cross entropy plus L2 regularization on the
    # weights of convolution and fully-connected layers only.
    cross_entropy = lbann.CrossEntropy([probs, labels])
    l2_reg_weights = set()
    for layer in lbann.traverse_layer_graph(input_):
        if isinstance(layer, (lbann.Convolution, lbann.FullyConnected)):
            l2_reg_weights.update(layer.weights)
    l2_reg = lbann.L2WeightRegularization(weights=l2_reg_weights, scale=0.0002)
    obj = lbann.ObjectiveFunction([cross_entropy, l2_reg])

    # Setup model
    metrics = [
        lbann.Metric(lbann.CategoricalAccuracy([probs, labels]),
                     name='accuracy',
                     unit='%')
    ]
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
    if checkpoint_interval:
        # Honor the requested interval instead of a fixed 5 epochs.
        callbacks.append(
            lbann.CallbackCheckpoint(checkpoint_dir='ckpt',
                                     checkpoint_epochs=checkpoint_interval))

    # Learning rate schedules
    if warmup:
        callbacks.append(
            lbann.CallbackLinearGrowthLearningRate(target=learning_rate *
                                                   mini_batch_size / 128,
                                                   num_epochs=5))
    callbacks.append(
        lbann.CallbackDropFixedLearningRate(drop_epoch=list(range(0, 100, 15)),
                                            amt=0.25))

    # Construct model. The graph is traversed again here because the
    # accuracy layer was created after the regularization pass above.
    model = lbann.Model(num_epochs,
                        layers=lbann.traverse_layer_graph(input_),
                        objective_function=obj,
                        metrics=metrics,
                        callbacks=callbacks)

    # Setup optimizer
    opt = lbann.SGD(learn_rate=learning_rate, momentum=0.9)

    # Setup data reader
    data_reader = make_data_reader(num_patches)

    # Return experiment objects
    return model, data_reader, opt