예제 #1
0
파일: vae.py 프로젝트: szaman19/lbann
    def forward_encoder(self, x_emb):
        """Encoder step, emulating z ~ E(x) = q_E(z|x).

        Runs the encoder RNN, keeps a single slice of its output, projects
        it to ``mu`` and ``logvar``, and draws ``z`` with the
        reparameterization ``z = mu + exp(logvar / 2) * eps``.

        :param x_emb: (n_batch, len(x), d_z) of floats, embeddings for input sentence x
        :return: (n_batch, d_z) of floats, sample of latent vector z
        :return: float, kl term component of loss
        """

        # PyTorch reference: _, h = self.encoder_rnn(x, None)
        rnn_out = self.encoder_rnn(x_emb, None)

        # Keep only the range [input_feature_dims - 1, input_feature_dims)
        # along axis 0 of the RNN output.
        # NOTE(review): presumably this selects the final step -- confirm.
        keep_points = str_list(
            [self.input_feature_dims - 1, self.input_feature_dims])
        rnn_out = lbann.Slice(rnn_out, slice_points=keep_points, axis=0)
        h = lbann.Identity(rnn_out)

        mu = self.q_mu(h)
        logvar = self.q_logvar(h)

        # Set datatype of previous layers
        # Note: Depth-first search from mu and logvar back towards x_emb.
        # Slice/Reshape/Tessellate layers (exact type match) keep their
        # datatype; x_emb itself is never visited.
        untouched_types = (lbann.Slice, lbann.Reshape, lbann.Tessellate)
        pending = [mu, logvar]
        seen = set(pending)
        while pending:
            layer = pending.pop()
            if type(layer) not in untouched_types:
                layer.datatype = self.datatype
            for parent in layer.parents:
                if parent in seen or parent is x_emb:
                    continue
                pending.append(parent)
                seen.add(parent)

        # PyTorch reference: eps = torch.randn_like(mu)
        eps = lbann.Gaussian(mean=0, stdev=1, hint_layer=mu)

        # z = mu + (logvar / 2).exp() * eps
        half_logvar = lbann.WeightedSum(logvar, scaling_factors='0.5')
        stdev = lbann.Exp(half_logvar)
        scaled_noise = lbann.Multiply([stdev, eps])
        z = lbann.Add([mu, scaled_noise])

        # kl_loss = 0.5 * (logvar.exp() + mu ** 2 - 1 - logvar).sum(1).mean()
        kl_terms = lbann.WeightedSum(
            lbann.Exp(logvar),
            lbann.Square(mu),
            self.constant(1, hint_layer=mu),
            logvar,
            scaling_factors='0.5 0.5 -0.5 -0.5',
        )
        kl_loss = lbann.Reduction(kl_terms, mode='sum')

        return z, kl_loss
예제 #2
0
 def inv_transform(self, y):
     """Return tanh((1 / linear_scaler) * datascale * (1 + y) / (1 - y)).

     The inner ratio is built with SafeDivide, scaled by ``self.datascale``,
     then rescaled by ``1 / self.linear_scaler`` and passed through Tanh.
     """
     one_plus_y = lbann.Add(lbann.Constant(value=1.0, hint_layer=y),
                            lbann.Identity(y))
     one_minus_y = lbann.Subtract(lbann.Constant(value=1.0, hint_layer=y),
                                  lbann.Identity(y))
     ratio = lbann.SafeDivide(one_plus_y, one_minus_y)
     unscaled = lbann.WeightedSum(ratio,
                                  scaling_factors=str(self.datascale))

     inverse_linear_scaler = 1 / self.linear_scaler
     rescaled = lbann.WeightedSum(
         unscaled, scaling_factors=str(inverse_linear_scaler))
     return lbann.Tanh(rescaled)
예제 #3
0
def negative_samples_loss(embeddings, negative_samples_embeddings):
    """Build the negative-sampling loss term.

    Computes ``-average(log_sigmoid(-embeddings @ negatives^T))`` as a
    chain of LBANN layers and returns the final layer.
    """
    similarity = lbann.MatMul(
        embeddings,
        negative_samples_embeddings,
        transpose_b=True,
    )
    # Negate the scores so that low similarity to negatives is rewarded.
    negated = lbann.WeightedSum(similarity, scaling_factors='-1')
    log_prob = lbann.LogSigmoid(negated)
    mean_log_prob = lbann.Reduction(log_prob, mode='average')
    # Flip the sign to turn the averaged log-probability into a loss.
    return lbann.WeightedSum(mean_log_prob, scaling_factors='-1')
예제 #4
0
    def forward(self, img, z, mcr):
        '''
        Wire up the discriminators and the generator for one pass.

        Steps:
        - Modify image if using mcr (append a rescaled second channel)
        - D1 + imgs -> d1_real
        - G + noise -> gen_img
        - D1 + stop_gradient(gen_img) -> d1_fake
        - Adv (D2) + gen_img -> d_adv
        Returns (d1_real, d1_fake, d_adv, gen_img, img), where img is the
        reshaped (and possibly two-channel) real image.
        '''

        print('MCR in forward', mcr)
        if not mcr:
            img = lbann.Reshape(img, dims='1 128 128')
        else:
            # Multi-channel rescaling: add an extra channel for real images.
            # Generated images are rescaled inside the generator.
            inverse_scale = 1 / self.linear_scaler
            original_space = self.inv_transform(lbann.Identity(img))
            ch2 = lbann.Tanh(
                lbann.WeightedSum(original_space,
                                  scaling_factors=str(inverse_scale)))
            stacked = lbann.Concatenation(lbann.Identity(img), ch2, axis=0)
            img = lbann.Reshape(stacked, dims='2 128 128')

        # instance 1 of D1: real images
        d1_real = self.forward_discriminator1(img)
        gen_img = self.forward_generator(z, mcr=mcr)

        # instance 2 of D1: generated images, detached from the generator
        detached_gen_img = lbann.StopGradient(gen_img)
        d1_fake = self.forward_discriminator1(detached_gen_img)

        # instance 3 (D2, adversarial): needs to be frozen.
        # The D1 instances share weights; the D1 weights are copied to D2
        # (through the replace-weight callback) and D2 is frozen.
        d_adv = self.forward_discriminator2(gen_img)

        return d1_real, d1_fake, d_adv, gen_img, img
예제 #5
0
    def forward_generator(self, z, mcr):
        '''
        Build the Generator.

        A fully-connected stage and every layer in self.g_convT are each
        followed by batch normalization and ReLU; self.g_convT3 produces
        the image. With mcr set, a tanh-rescaled second channel is
        concatenated to the output.
        '''

        def bn_relu(layer):
            # Shared BatchNormalization + ReLU post-processing.
            normalized = lbann.BatchNormalization(layer,
                                                  decay=0.9,
                                                  scale_init=1.0,
                                                  epsilon=1e-5)
            return lbann.Relu(normalized)

        x = bn_relu(self.g_fc1(z))
        x = lbann.Reshape(x, dims='512 8 8')  # channel first

        for deconv in self.g_convT:
            x = bn_relu(deconv(x))

        img = self.g_convT3(x)

        if not mcr:
            return lbann.Reshape(img, dims='1 128 128')

        # Multi-channel rescaling: add an extra channel to the output image.
        inverse_scale = 1 / self.linear_scaler
        original_space = self.inv_transform(img)
        ch2 = lbann.Tanh(
            lbann.WeightedSum(original_space,
                              scaling_factors=str(inverse_scale)))
        stacked = lbann.Concatenation(img, ch2, axis=0)
        return lbann.Reshape(stacked, dims='2 128 128')
예제 #6
0
def random_projection(indices, num_projections, projection_dim):
    """Build hash-derived projection vectors for the given indices.

    Each index is expanded into ``projection_dim`` distinct integers,
    hashed with UniformHash, mapped through ErfInv, instance-normalized
    and finally scaled by ``1 / projection_dim``.
    """
    grid_dims = utils.str_list([num_projections, projection_dim])

    # Expand input indices to get an index for each vector entry
    # Note: proj_indices(i) = index*projection_dim + i
    scaled_indices = lbann.WeightedSum(
        indices,
        scaling_factors=utils.str_list(projection_dim),
    )
    # Constant (non-trainable) vector 0, 1, ..., projection_dim - 1.
    iota_weights = lbann.Weights(
        initializer=lbann.ValueInitializer(
            values=utils.str_list(range(projection_dim))),
        optimizer=lbann.NoOptimizer(),
    )
    iota = lbann.WeightsLayer(
        dims=utils.str_list(projection_dim),
        weights=iota_weights,
    )
    # Tile both operands to (num_projections, projection_dim) and add.
    index_grid = lbann.Tessellate(
        lbann.Reshape(scaled_indices,
                      dims=utils.str_list([num_projections, 1])),
        dims=grid_dims,
    )
    iota_grid = lbann.Tessellate(
        lbann.Reshape(iota, dims=utils.str_list([1, projection_dim])),
        dims=grid_dims,
    )
    proj_indices = lbann.Sum(index_grid, iota_grid)

    # Apply hash function and convert to Gaussian distribution
    hashed = lbann.UniformHash(proj_indices)
    ones = lbann.Constant(value=1, num_neurons=grid_dims)
    # eps shrinks the affine map 2*(1-eps)*h - (1-eps) applied before ErfInv.
    # NOTE(review): presumably this keeps the ErfInv input strictly inside
    # (-1, 1) -- confirm the output range of UniformHash.
    eps = 0.001
    centered = lbann.WeightedSum(
        hashed,
        ones,
        scaling_factors=utils.str_list([2 * (1 - eps), -(1 - eps)]),
    )
    gaussian = lbann.ErfInv(centered)
    normalized = lbann.InstanceNorm(gaussian)
    return lbann.WeightedSum(
        normalized,
        scaling_factors=utils.str_list(1 / projection_dim),
    )
예제 #7
0
def f_invtransform(y, scale=4.0):  ### Transform to original space
    '''
    The inverse of the transformation function that scales the data
    before training: scale * (1 + y) / (1 - y), built from LBANN layers.
    '''
    numerator = lbann.Add(lbann.Constant(value=1.0, hint_layer=y),
                          lbann.Identity(y))
    denominator = lbann.Subtract(lbann.Constant(value=1.0, hint_layer=y),
                                 lbann.Identity(y))
    ratio = lbann.SafeDivide(numerator, denominator)
    return lbann.WeightedSum(ratio, scaling_factors=str(scale))
예제 #8
0
 def forward(self, inputs):
     """Build the loss layer from a (predictions, labels) pair.

     The returned layer is ``Reduction(-label * log(1 - pred) - log(pred))``.

     Args:
         inputs: Sequence of exactly two layers: predictions, then labels.

     Returns:
         The Reduction layer holding the loss.

     Raises:
         ValueError: If ``inputs`` does not contain exactly two entries.
     """
     if len(inputs) != 2:
         raise ValueError('expected two inputs: predictions and labels')
     pred = inputs[0]
     label = inputs[1]
     # Use the `lbann` namespace like every other layer in this method;
     # the previous `p.Constant` referenced a name nothing else here uses.
     ones = lbann.Constant(hint_layer=pred, value=1.0)
     term1 = lbann.Multiply(
         [label, lbann.Log(lbann.Subtract([ones, pred]))])
     term2 = lbann.Log(pred)
     # NOTE(review): term2 is not weighted by the label, which differs from
     # the textbook binary cross-entropy -- confirm this is intended.
     full = lbann.WeightedSum([term1, term2], scaling_factors='-1.0 -1.0')
     return lbann.Reduction(full)
예제 #9
0
    def inv_transform(self, y):  ### Original transformation
        '''
        The inverse of the transformation function that scales the data
        before training: self.datascale * (1 + y) / (1 - y).
        '''
        one = lbann.Constant(value=1.0, hint_layer=y)
        numerator = lbann.Add(one, lbann.Identity(y))

        one = lbann.Constant(value=1.0, hint_layer=y)
        denominator = lbann.Subtract(one, lbann.Identity(y))

        quotient = lbann.SafeDivide(numerator, denominator)
        return lbann.WeightedSum(quotient,
                                 scaling_factors=str(self.datascale))
예제 #10
0
 def forward(self, inputs):
     """Disabled loss: always raises NotImplementedError.

     The code after the raise is kept as the intended implementation,
     ``Reduction(pred - label * log(pred) + lgamma(label + 1))``, which
     has the form of a Poisson negative log-likelihood. It is unreachable
     until a log-gamma layer is available.
     """
     raise NotImplementedError  # Requires log-gamma function
     if len(inputs) != 2:
         raise ValueError('expected two inputs: predictions and labels')
     pred, label = inputs[0], inputs[1]
     ones = lbann.Constant(hint_layer=pred, value=1.0)
     log_likelihood_term = lbann.Multiply([label, lbann.Log(pred)])
     log_factorial_term = lbann.LogGamma(lbann.Add([label, ones]))
     terms = [pred, log_likelihood_term, log_factorial_term]
     full = lbann.WeightedSum(terms, scaling_factors='1.0 -1.0 1.0')
     return lbann.Reduction(full)
예제 #11
0
    def forward(
        self,
        motif_indices,
        motif_size,
        walk_indices,
        walk_length,
    ):
        """Build the combined generator/discriminator loss.

        Returns the tuple (loss, real_disc_prob, fake_disc_prob, gen_prob).
        """

        # Apply generator
        generated = self.generator(walk_length, walk_indices, motif_size)
        fake_motif_indices, gen_prob, gen_log_prob = generated

        # Get discriminator embeddings in log-space.
        # Real and fake motifs are looked up in one pass and split again;
        # the first Identity child of the Slice is treated as the real
        # half and the second as the fake half.
        stacked_indices = lbann.Concatenation(motif_indices,
                                              fake_motif_indices)
        stacked_log_embeddings = self.discriminator.get_log_embeddings(
            stacked_indices)
        split_log_embeddings = lbann.Slice(
            stacked_log_embeddings,
            slice_points=str_list([0, motif_size, 2 * motif_size]),
        )
        real_log_embeddings = lbann.Identity(split_log_embeddings)
        fake_log_embeddings = lbann.Identity(split_log_embeddings)

        # Apply discriminator
        real_disc_prob, real_disc_log_not_prob = self.discriminator(
            motif_size, real_log_embeddings)
        fake_disc_prob, fake_disc_log_not_prob = self.discriminator(
            motif_size, fake_log_embeddings)

        # Loss function
        # L_disc = - log(D(real)) - log(1-D(fake))
        # L_gen = - log(G) * stop_gradient(log(1-D(fake)))
        # NOTE(review): the generator term below carries no explicit
        # negation -- confirm the sign convention against the formula.
        # The clamp keeps the probability away from zero before the Log.
        clamped_real_prob = lbann.Clamp(real_disc_prob, min=1e-37, max=1)
        real_disc_log_prob = lbann.Log(clamped_real_prob)
        disc_loss = lbann.WeightedSum(
            real_disc_log_prob,
            fake_disc_log_not_prob,
            scaling_factors=str_list([-1, -1]),
        )
        detached_fake_term = lbann.StopGradient(fake_disc_log_not_prob)
        gen_loss = lbann.Multiply(gen_log_prob, detached_fake_term)

        return (
            lbann.Add(disc_loss, gen_loss),
            real_disc_prob,
            fake_disc_prob,
            gen_prob,
        )
예제 #12
0
 def inv_transform(self, y):
     '''
     The inverse of the transformation function that scales the data
     before training: self.datascale * (1 + y) / (1 - y).
     '''
     one_plus_y = lbann.Add(lbann.Constant(value=1.0, hint_layer=y),
                            lbann.Identity(y))
     one_minus_y = lbann.Subtract(lbann.Constant(value=1.0, hint_layer=y),
                                  lbann.Identity(y))
     ratio = lbann.SafeDivide(one_plus_y, one_minus_y)
     # A variant that additionally scaled by 1/self.linear_scaler and applied
     # Tanh used to be returned here; it is intentionally left disabled.
     return lbann.WeightedSum(ratio, scaling_factors=str(self.datascale))
예제 #13
0
 def forward(self, inputs):
     """Disabled loss: always raises NotImplementedError.

     The code after the raise is the intended implementation and is
     unreachable until a log-gamma layer is available. With ``alpha`` the
     predictions and ``n`` the labels it builds

         -lgamma(sum(alpha)) + lgamma(sum(n) + sum(alpha))
         - sum(lgamma(alpha + n)) + sum(lgamma(alpha))

     NOTE(review): presumably a Dirichlet-multinomial negative
     log-likelihood -- confirm before enabling.

     Args:
         inputs: Sequence of exactly two layers: predictions, then labels.

     Raises:
         NotImplementedError: Always, for now.
     """
     raise NotImplementedError  # Requires log-gamma function
     if len(inputs) != 2:
         raise ValueError('expected two inputs: predictions and labels')
     pred = inputs[0]
     label = inputs[1]
     count = lbann.Reduction(label)
     alpha_sum = lbann.Reduction(pred)
     lgamma_alpha_sum = lbann.Reduction(lbann.LogGamma(pred))
     lgamma_alpha_level_count_sum = lbann.Reduction(
         lbann.LogGamma(lbann.Add([pred, label])))
     # Fixed: the third term previously referenced the undefined name
     # `lgamma_alpha_level_count`, which would raise NameError once enabled.
     return lbann.WeightedSum([
         lbann.LogGamma(alpha_sum),
         lbann.LogGamma(lbann.Sum([count, alpha_sum])),
         lgamma_alpha_level_count_sum, lgamma_alpha_sum
     ],
                              scaling_factors='-1.0 1.0 -1.0 1.0')
예제 #14
0
def mean_squared_error(
    data_dim,
    sequence_length,
    source_sequence,
    target_sequence,
    scale_decay=0.8,
):
    """Build a distance-weighted squared-error term between two sequences.

    The pairwise inner products of source and target vectors are combined
    with fixed weights that decay exponentially with the distance between
    sequence positions (same-position pairs get weight zero), then joined
    with the squared L2 norms of both sequences.
    """

    # Compute inner product between source and target vectors
    # Note: Inner products are computed for each (x,y) pair and a
    # weighted sum is computed. The scaling factors decay exponentially
    # as x and y get further apart in the sequence.
    pair_prods = lbann.MatMul(
        source_sequence,
        target_sequence,
        transpose_b=True,
    )
    scale_dims = (sequence_length, sequence_length)
    scale_values = np.zeros(scale_dims)
    prefactor = (1 - scale_decay) / (2 * scale_decay)
    for row in range(sequence_length):
        for col in range(sequence_length):
            if row == col:
                continue
            scale_values[row, col] = (prefactor *
                                      scale_decay**np.abs(col - row))

    # Store the weights as a fixed (non-optimized) weights layer.
    scale_weights = lbann.Weights(
        initializer=lbann.ValueInitializer(
            values=utils.str_list(np.nditer(scale_values))),
        optimizer=lbann.NoOptimizer(),
    )
    scale_layer = lbann.WeightsLayer(dims=utils.str_list(scale_dims),
                                     weights=scale_weights)

    # Flatten both matrices and take their dot product to get one scalar.
    weighted_prod = lbann.MatMul(
        lbann.Reshape(pair_prods, dims='1 -1'),
        lbann.Reshape(scale_layer, dims='1 -1'),
        transpose_b=True,
    )
    weighted_prod = lbann.Reshape(weighted_prod, dims='1')

    # MSE(x,y) = ( norm(x)^2 + norm(y)^2 - 2*prod(x,y) ) / dim(x)
    scale = 1 / (data_dim * sequence_length)
    source_norm2 = lbann.L2Norm2(source_sequence)
    target_norm2 = lbann.L2Norm2(target_sequence)
    return lbann.WeightedSum(
        source_norm2,
        target_norm2,
        weighted_prod,
        scaling_factors=utils.str_list([scale, scale, -2 * scale]),
    )
예제 #15
0
    def inv_transform(self, y):  ### Original transformation
        '''
        The inverse of the transformation function that scales the data
        before training, i.e. self.datascale * (1 + y) / (1 - y).
        '''
        top = lbann.Add(
            lbann.Constant(value=1.0, hint_layer=y),
            lbann.Identity(y),
        )
        bottom = lbann.Subtract(
            lbann.Constant(value=1.0, hint_layer=y),
            lbann.Identity(y),
        )
        fraction = lbann.SafeDivide(top, bottom)
        scale_factor = str(self.datascale)
        return lbann.WeightedSum(fraction, scaling_factors=scale_factor)

#      def inv_transform(self, y):### New transformation : log-linear


#         threshold = lbann.Constant(value=0.5, hint_layer=y)
#         is_above_threshold = lbann.Greater(y, threshold)
#         is_below_threshold = lbann.LogicalNot(is_above_threshold)

#         below = lbann.SafeDivide(
#             lbann.Subtract(y, lbann.Constant(value=1, hint_layer=y)),
#             lbann.Constant(value=0.03, hint_layer=y),
#         )
#         above = lbann.Exp(lbann.SafeDivide(
#             lbann.Subtract(
#                 y,
#                 lbann.Constant(value=0.5-0.5/math.log(300)*math.log(50), hint_layer=y)),
#             lbann.Constant(value=0.5/math.log(300), hint_layer=y),
#         ))
#         return lbann.Add(
#             lbann.Multiply(is_above_threshold, above),
#             lbann.Multiply(is_below_threshold, below),
#         )

# def f_invtransform_new(y):
#     if y<=0.5:
#         a=0.03;b=-1.0
#         return (y-b)/a
#     elif y>0.5:
#         a=0.5/np.log(300)
#         b=0.5-a*np.log(50)
#         return np.exp((y-b)/a)
예제 #16
0
    def forward_generator(self, z, mcr):
        '''
        Build the Generator.

        The fully-connected stage and the first three entries of
        self.g_convT are each followed by batch normalization and ReLU;
        self.g_convT3 produces the image. With mcr set, a tanh-rescaled
        second channel is concatenated to the output.
        '''
        bn_kwargs = dict(decay=0.9, scale_init=1.0, epsilon=1e-5)

        x = lbann.Relu(lbann.BatchNormalization(self.g_fc1(z), **bn_kwargs))
        dims = '512 8 8'

        print("dims", dims)
        x = lbann.Reshape(x, dims=dims)  # channel first
        for stage in range(3):
            x = lbann.Relu(
                lbann.BatchNormalization(self.g_convT[stage](x),
                                         **bn_kwargs))
        img = self.g_convT3(x)

        if not mcr:
            img = lbann.Reshape(img, dims='1 128 128')
        else:
            # Multi-channel rescaling: add an extra channel to the output
            # image, ch2 = tanh(inv_transform(img) / linear_scaler).
            inverse_scale = 1 / self.linear_scaler
            rescaled = lbann.WeightedSum(self.inv_transform(img),
                                         scaling_factors=str(inverse_scale))
            ch2 = lbann.Tanh(rescaled)
            stacked = lbann.Concatenation(img, ch2, axis=0)
            img = lbann.Reshape(stacked, dims='2 128 128')

        print('Gen Img in GAN', img.__dict__)
        return img
예제 #17
0
    def forward(self, queries, keys, values, mask=None):
        """Apply multi-head attention.

        The input and output tensors are interpreted as sequences of
        vectors, where the first tensor dimension is the sequence
        dimension.

        NOTE(review): although documented as single layers below, this
        implementation iterates over `queries`, `keys` and `values` and
        indexes the per-branch weights with the loop counter, so each
        appears to be a sequence of layers (one per sub-grid) -- confirm
        against callers.

        Args:
            queries (lbann.Layer): Sequence of query vectors.
            keys (lbann.Layer): Sequence of key vectors.
            values (lbann.Layer): Sequence of value vectors.
            mask (lbann.Layer, optional): Additive attention mask. If
                the (i,j) entry is very negative (e.g. -1e9), then the
                ith query does not attend to the jth key/value pair.
                When ENABLE_SUBGRAPH is set, `mask` is indexed per branch.

        Returns:
            list: `self.BRANCHES` Identity layers, each a child of the
                final Cross_Grid_Sum_Slice layer.

        Raises:
            ValueError: If ENABLE_SUBGRAPH is set and `num_heads` is not
                divisible by `BRANCHES`.

        """
        ENABLE_SUBGRAPH = self.ENABLE_SUBGRAPH
        BRANCHES = self.BRANCHES
        if (ENABLE_SUBGRAPH):
            if (self.num_heads % BRANCHES != 0):
                raise ValueError('Num heads should be divisible by BRANCHES')
        # Each call gets a unique name prefix for its layers.
        self.instance += 1
        name = f'{self.name}_instance{self.instance}'

        # Apply fully-connected layers to input sequences
        queries_fc = []
        keys_fc = []
        values_fc = []

        # Slice embedding vectors for each head
        # Note: num_heads / BRANCHES heads per branch, hence that many
        # slices of width head_dim (one more slice point than slices).
        slice_points = str_list(
            self.head_dim * i
            for i in range(int(self.num_heads / self.BRANCHES) + 1))

        #Queries strong scaling in CFC
        # One channelwise FC per input, combined by Cross_Grid_Sum_Slice.
        attentions = []
        for count, query in enumerate(queries):
            temp = lbann.ChannelwiseFullyConnected(
                query,
                weights=self.query_weights[count],
                output_channel_dims=[self.inner_dim],
                name=f'{name}_subgrid{count}_queries_fc',
            )
            attentions.append(temp)

        grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)

        attentions = []

        # One Identity child of the combined layer per branch.
        for head in range(self.BRANCHES):
            attentions.append(lbann.Identity(grid_sum_slice))

        # Split each branch's output into per-head slices along axis 1.
        for head in range(self.BRANCHES):
            temp = lbann.Slice(
                attentions[head],
                axis=1,
                slice_points=slice_points,
                name=f'{name}_subgrid{head}_queries_slice',
            )

            queries_fc.append(temp)

        #keys strong scaling in CFC
        # Same pipeline as for the queries, using the key weights.

        attentions = []
        for count, key in enumerate(keys):
            temp = lbann.ChannelwiseFullyConnected(
                key,
                weights=self.key_weights[count],
                output_channel_dims=[self.inner_dim],
                name=f'{name}_subgrid{count}_keys_fc',
            )

            attentions.append(temp)

        grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)

        attentions = []

        for head in range(self.BRANCHES):
            attentions.append(lbann.Identity(grid_sum_slice))

        for head in range(self.BRANCHES):

            temp = lbann.Slice(
                attentions[head],
                axis=1,
                slice_points=slice_points,
                name=f'{name}_subgrid{head}_keys_slice',
            )

            keys_fc.append(temp)

        #Values strong scaling in CFC
        # Same pipeline as for the queries, using the value weights.
        attentions = []

        for count, value in enumerate(values):
            temp = lbann.ChannelwiseFullyConnected(
                value,
                weights=self.value_weights[count],
                output_channel_dims=[self.inner_dim],
                name=f'{name}_subgrid{count}_values_fc',
            )
            attentions.append(temp)

        grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)

        attentions = []

        for head in range(self.BRANCHES):
            attentions.append(lbann.Identity(grid_sum_slice))

        for head in range(self.BRANCHES):
            temp = lbann.Slice(
                attentions[head],
                axis=1,
                slice_points=slice_points,
                name=f'{name}_subgrid{head}_values_slice',
            )
            values_fc.append(temp)

        # Flatten the per-branch Slice layers into one entry per head:
        # each Slice gets num_heads / BRANCHES Identity children, appended
        # in branch order.
        queries_slice = []
        keys_slice = []
        values_slice = []

        for branch in range(self.BRANCHES):
            querie_slice = queries_fc[branch]
            key_slice = keys_fc[branch]
            value_slice = values_fc[branch]

            for head in range(int(self.num_heads / self.BRANCHES)):
                queries_slice.append(lbann.Identity(querie_slice))
                keys_slice.append(lbann.Identity(key_slice))
                values_slice.append(lbann.Identity(value_slice))

        # Compute scaled dot-product attention for each head
        attentions = []

        #variable to combine heads locally in sub-grids
        temp_attentions = []
        # NOTE(review): `tag` is incremented before its first use, so
        # `mask[tag]` below is indexed from 1 up to BRANCHES -- confirm
        # that `mask` is laid out with 1-based branch entries.
        tag = 0
        for head in range(self.num_heads):
            head_name = f'{name}_myattention_head{head}'

            # Attention inputs
            # Start a new per-branch group at the first head of each branch.
            if (head % int(self.num_heads / BRANCHES) == 0):
                temp_attentions.append([])
                tag += 1

            q = lbann.Identity(queries_slice[head])
            k = lbann.Identity(keys_slice[head])
            v = lbann.Identity(values_slice[head])

            # Multiply queries and keys
            # Note: num_queries x num_keys
            y = lbann.MatMul(
                q,
                k,
                transpose_b=True,
                name=f'{head_name}_matmul',
            )
            # Scale the scores by 1 / sqrt(head_dim).
            y = lbann.WeightedSum(
                y,
                scaling_factors=str(1 / math.sqrt(self.head_dim)),
                name=f'{head_name}_scale',
            )

            # Add the mask: per-branch entry in sub-graph mode, otherwise
            # the single mask layer.
            if (ENABLE_SUBGRAPH):
                if mask != None:
                    y = lbann.Sum([y, mask[tag]], name=f'{head_name}_mask')
            else:
                if mask:
                    y = lbann.Sum([y, mask], name=f'{head_name}_mask')
            y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

            # Attention output
            # Note: num_queries x head_dim
            y = lbann.MatMul(y, v, name=head_name)
            # attentions.append(lbann.MatMul(y, v, name=head_name))

            temp_attentions[-1].append(y)

        # Combine the heads of each branch and apply that branch's output
        # projection.
        for count, temp_attention in enumerate(temp_attentions):

            if (self.BRANCHES == self.num_heads):
                # No need to concat the heads at subgrid level
                # if number of subgrids is equal to number of heads
                attention_single_subgrid = temp_attentions[count][0]
            else:
                attention_single_subgrid = lbann.Concatenation(
                    temp_attention,
                    axis=1,
                    name=f'{name}_subgrid_heads_concat{count}',
                    parallel_strategy={
                        'sub_branch_tag': 0,
                        'enable_subgraph': False
                    })

            attention_single_subgrid = lbann.ChannelwiseFullyConnected(
                attention_single_subgrid,
                weights=self.output_weights[count],
                output_channel_dims=[self.embed_dim],
                name=f'{name}_cfc_{count}',
            )

            attentions.append(attention_single_subgrid)

        #Strong scaling
        # Combine the per-branch outputs, then hand one Identity child of
        # the combined layer back per branch.

        grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)

        attentions = []

        for head in range(self.BRANCHES):
            attentions.append(lbann.Identity(grid_sum_slice))

        return attentions
예제 #18
0
파일: main.py 프로젝트: szaman19/lbann
    weights=encoder_embeddings_weights,
    num_embeddings=num_graph_nodes,
    embedding_dim=args.latent_dim,
)

# Skip-Gram with negative sampling
# Score the decoder embeddings against the encoder embeddings with a
# matrix product (encoder side transposed).
preds = lbann.MatMul(decoder_embeddings, encoder_embeddings, transpose_b=True)
# Split the scores along axis 0 into [0, num_negative_samples) and
# [num_negative_samples, num_negative_samples + 1).
# The first Identity child below is used as the negative-sample scores and
# the second as the single positive score, so creation order matters.
preds_slice = lbann.Slice(
    preds,
    axis=0,
    slice_points=f'0 {num_negative_samples} {num_negative_samples+1}')
preds_negative = lbann.Identity(preds_slice)
preds_positive = lbann.Identity(preds_slice)
# Positive term: sum of log_sigmoid(score).
obj_positive = lbann.LogSigmoid(preds_positive)
obj_positive = lbann.Reduction(obj_positive, mode='sum')
# Negative term: sum of log_sigmoid(-score).
obj_negative = lbann.WeightedSum(preds_negative, scaling_factors='-1')
obj_negative = lbann.LogSigmoid(obj_negative)
obj_negative = lbann.Reduction(obj_negative, mode='sum')
# Objective terms: both are negated; the negative term is additionally
# averaged over the number of negative samples.
obj = [
    lbann.LayerTerm(obj_positive, scale=-1),
    lbann.LayerTerm(obj_negative, scale=-1/num_negative_samples),
]

# ----------------------------------
# Create data reader
# ----------------------------------

# Protobuf reader message with one entry named 'python' in the 'train' role.
reader = lbann.reader_pb2.DataReader()
_reader = reader.reader.add()
_reader.name = 'python'
_reader.role = 'train'
예제 #19
0
    def forward(self, x, prev_state):
        """Apply one GRU step channelwise.

        Computes
            r = sigmoid(W_ir x + W_hr h)
            z = sigmoid(W_iz x + W_hz h)
            n = tanh(W_in x + r * (W_hn h))
            h' = (1 - z) * n + z * h
        where the three input and three hidden projections come from
        slicing the outputs of ``self.ih_fc`` and ``self.hh_fc``.

        Args:
            x (Layer): Input (shape: (num_channels, *))
            prev_state (Layer): State from previous GRU step (shape: (num_channels, size))
        Returns:
            (Layer, Layer): The output (out) and state (hn). The state can be passed directly into the next GRU step
        """

        self.step += 1
        name = f"{self.name}_step{self.step}"

        layout = self.data_layout
        state_dims = str_list([self.num_channels, self.size])
        gate_points = str_list(
            [0, self.size, 2 * self.size, 3 * self.size])

        prev_state = lbann.Reshape(prev_state,
                                   dims=state_dims,
                                   name=name + "_prev_state_reshape")

        def split_gates(projection, suffixes):
            # Slice a fused projection along axis 1 into three parts and
            # reshape each to (num_channels, size). The Identity children
            # are created in suffix order.
            sliced = lbann.Slice(projection,
                                 axis=1,
                                 slice_points=gate_points)
            return [
                lbann.Reshape(lbann.Identity(sliced),
                              dims=state_dims,
                              name=name + suffix)
                for suffix in suffixes
            ]

        fc1 = self.ih_fc(x)
        fc2 = self.hh_fc(prev_state)

        Wir_x, Wiz_z, Win_x = split_gates(fc1,
                                          ('_Wir_x', '_Wiz_z', '_Win_x'))
        Whr_x, Whz_z, Whn_x = split_gates(fc2,
                                          ('_Whr_x', '_Whz_z', '_Whn_x'))

        # Reset gate
        rt = lbann.Sigmoid(
            lbann.Add(Wir_x, Whr_x, data_layout=layout),
            name=name + '_reset_gate',
            data_layout=layout,
        )

        # Update gate
        zt = lbann.Sigmoid(
            lbann.Add(Wiz_z, Whz_z, data_layout=layout),
            name=name + '_update_gate',
            data_layout=layout,
        )

        # Candidate state
        reset_hidden = lbann.Multiply(rt, Whn_x, data_layout=layout)
        nt = lbann.Tanh(
            lbann.Add(Win_x, reset_hidden, data_layout=layout),
            name=name + '_new_gate',
            data_layout=layout,
        )

        # New state: (1 - zt) * nt + zt * prev_state
        one_minus_zt = lbann.WeightedSum(self.ones,
                                         zt,
                                         scaling_factors='1 -1',
                                         data_layout=layout)
        candidate_part = lbann.Multiply(one_minus_zt, nt, data_layout=layout)
        carried_part = lbann.Multiply(zt, prev_state, data_layout=layout)
        ht = lbann.Add(candidate_part,
                       carried_part,
                       name=name + '_output',
                       data_layout=layout)

        ht = lbann.Reshape(ht, dims=state_dims)

        return ht, ht
예제 #20
0
파일: vae_mnist.py 프로젝트: benson31/lbann
                          has_bias=True)

# Fully-connected head on `encode3` with 30 outputs, used below as the
# log standard deviation of the latent distribution.
logsd = lbann.FullyConnected(encode3,
                             name="logsd",
                             num_neurons=30,
                             has_bias=True)

# KL divergence
# sd = exp(logsd)
sd = lbann.Exp(logsd, name="sd")

# var = sd^2
var = lbann.Square(sd, name="var")

# meansq = mu^2 (`mu` is defined above this section)
meansq = lbann.Square(mu, name="meansq")

# 0.5 * mu^2 + 0.5 * var - logsd, i.e. the per-element KL term plus 0.5
kldiv_plus_half = lbann.WeightedSum([meansq, var, logsd],
                                    name="kldiv_plus_half",
                                    scaling_factors='0.5 0.5 -1')

# NOTE(review): a reciprocal square root is applied here rather than
# subtracting the 0.5 offset -- confirm this is the intended KL term.
kldiv_full = lbann.Rsqrt(kldiv_plus_half, name="kldiv_full")

# Sum the per-element values into the KL term.
kldiv = lbann.Reduction(kldiv_full, name="kldiv", mode="sum")

# Generate sample
# Reparameterization: sample = mu + sd * noise, with noise ~ N(0, 1)
# drawn using `mu` as the hint layer.
noise = lbann.Gaussian(name="noise", mean=0, stdev=1, hint_layer=mu)

sdnoise = lbann.Hadamard([noise, sd], name="sdnoise")

sample = lbann.Add([mu, sdnoise], name="sample")

# Decoder
decode4 = lbann.FullyConnected(sample,
예제 #21
0
    def forward(self, queries, keys, values, mask=None):
        """Apply multi-head attention.

        The input and output tensors are interpreted as sequences of
        vectors, where the first tensor dimension is the sequence
        dimension.

        Every call increments ``self.instance`` and uses it to build
        the layer names (``{self.name}_instance{k}...``).

        Args:
            queries (lbann.Layer): Sequence of query vectors.
            keys (lbann.Layer): Sequence of key vectors.
            values (lbann.Layer): Sequence of value vectors.
            mask (lbann.Layer, optional): Additive attention mask. If
                the (i,j) entry is very negative (e.g. -1e9), then the
                ith query does not attend to the jth key/value pair.
                The mask is added for every head whenever it is not
                ``None``.

        Returns:
            lbann.Layer: Sequence of output vectors. The sequence
                length is the same as `queries`.

        """
        self.instance += 1
        name = f'{self.name}_instance{self.instance}'

        # Apply fully-connected layers to input sequences
        queries_fc = lbann.ChannelwiseFullyConnected(
            queries,
            weights=self.query_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}_queries_fc',
        )
        keys_fc = lbann.ChannelwiseFullyConnected(
            keys,
            weights=self.key_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}_keys_fc',
        )
        values_fc = lbann.ChannelwiseFullyConnected(
            values,
            weights=self.value_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}_values_fc',
        )

        # Slice embedding vectors for each head
        # Note: slice points are 0, head_dim, 2*head_dim, ...,
        # num_heads*head_dim along axis 1.
        slice_points = str_list(self.head_dim * i
                                for i in range(self.num_heads + 1))
        queries_slice = lbann.Slice(
            queries_fc,
            axis=1,
            slice_points=slice_points,
            name=f'{name}_queries_slice',
        )
        keys_slice = lbann.Slice(
            keys_fc,
            axis=1,
            slice_points=slice_points,
            name=f'{name}_keys_slice',
        )
        values_slice = lbann.Slice(
            values_fc,
            axis=1,
            slice_points=slice_points,
            name=f'{name}_values_slice',
        )

        # The 1/sqrt(head_dim) scaling is the same for every head, so
        # format it once instead of once per loop iteration.
        scale = str(1 / math.sqrt(self.head_dim))

        # Compute scaled dot-product attention for each head
        attentions = []
        for head in range(self.num_heads):
            head_name = f'{name}_head{head}'

            # Attention inputs
            # Note: one Identity per head is attached to each Slice.
            q = lbann.Identity(queries_slice)
            k = lbann.Identity(keys_slice)
            v = lbann.Identity(values_slice)

            # Multiply queries and keys
            # Note: num_queries x num_keys
            y = lbann.MatMul(
                q,
                k,
                transpose_b=True,
                name=f'{head_name}_matmul',
            )
            y = lbann.WeightedSum(
                y,
                scaling_factors=scale,
                name=f'{head_name}_scale',
            )
            # Compare against None explicitly: a layer object should
            # not be tested for truthiness.
            if mask is not None:
                y = lbann.Add(y, mask, name=f'{head_name}_mask')
            y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

            # Attention output
            # Note: num_queries x head_dim
            attentions.append(lbann.MatMul(y, v, name=head_name))

        # Concatenate heads and apply fully-connected layer
        attentions = lbann.Concatenation(attentions,
                                         axis=1,
                                         name=f'{name}_heads_concat')
        outputs_fc = lbann.ChannelwiseFullyConnected(
            attentions,
            weights=self.output_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}',
        )
        return outputs_fc
예제 #22
0
def construct_macc_surrogate_model(xdim, ydim, zdim, wae_mcf, surrogate_mcf,
                                   lambda_cyc, useCNN, dump_models,
                                   pretrained_dir, ltfb_batch_interval,
                                   num_epochs):
    """Assemble the LBANN model used to train the MACC surrogate.

    Architecture and background: https://arxiv.org/pdf/1912.08113.pdf

    Args:
        xdim: Together with `ydim`, defines the slice points
            ``[0, ydim, ydim + xdim]`` used to split each sample; also
            passed to `MACCInverse`.
        ydim: See `xdim`; also passed to `MACCWAE`.
        zdim: Passed to `MACCWAE` and `MACCForward`.
        wae_mcf: `cf` argument of `MACCWAE`.
        surrogate_mcf: `cf` argument of `MACCInverse` and `MACCForward`.
        lambda_cyc: Scaling factor applied to the cycle term in both
            objective layers.
        useCNN: Forwarded to `MACCWAE` as `use_CNN`.
        dump_models: Directory given to `CallbackSaveModel`.
        pretrained_dir: Converted with `str` and given to
            `CallbackLoadModel`.
        ltfb_batch_interval: An LTFB callback is added only when this
            is greater than zero.
        num_epochs: First positional argument of `lbann.Model`.

    Returns:
        The constructed `lbann.Model`.

    """
    # ---- Layer graph -------------------------------------------------
    samples = lbann.Input(data_field='samples', name='inp_data')
    # Original note: data is 64*64*4 images + 15 scalar + 5 param
    split_points = str_list([0, ydim, ydim + xdim])
    sample_parts = lbann.Slice(samples,
                               axis=0,
                               slice_points=split_points,
                               name='inp_slice')
    gt_y = lbann.Identity(sample_parts, name='gt_y')
    gt_x = lbann.Identity(sample_parts, name='gt_x')

    # These three layers are constructed but not referenced again in
    # this function; they are kept so the sequence of layer
    # constructions is unchanged.
    _zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    _one = lbann.Constant(value=1.0, num_neurons='1', name='one')
    _noise = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="20")

    # Sub-networks (the WAE is the pretrained one, frozen further down)
    wae = macc_network_architectures.MACCWAE(zdim,
                                             ydim,
                                             cf=wae_mcf,
                                             use_CNN=useCNN)
    inv = macc_network_architectures.MACCInverse(xdim, cf=surrogate_mcf)
    fwd = macc_network_architectures.MACCForward(zdim, cf=surrogate_mcf)

    latent_y = wae.encoder(gt_y)

    latent_y_for_inv = wae.encoder(gt_y)
    x_from_latent = inv(latent_y_for_inv)

    latent_cycle = fwd(x_from_latent)
    # Result is not used below, but the call still adds decoder layers
    # to the graph.
    _y_from_cycle = wae.decoder(latent_cycle)

    # Train cycleGAN input params <--> latent space of (images, scalars)
    latent_from_x = fwd(gt_x)
    y_from_x = wae.decoder(latent_from_x)

    latent_roundtrip = wae.encoder(y_from_x)
    x_cycle = inv(latent_roundtrip)

    # ---- Losses ------------------------------------------------------
    inverse_loss = lbann.MeanSquaredError(x_from_latent, gt_x)
    param_cycle_loss = lbann.MeanSquaredError(x_cycle, gt_x)

    latent_loss = lbann.MeanSquaredError(latent_from_x, latent_y)
    output_cycle_loss = lbann.MeanSquaredError(latent_cycle, latent_y)

    # @todo slice here to separate scalar from image
    img_sca_loss = lbann.MeanSquaredError(y_from_x, gt_y)

    # cycle = output_cycle_loss + param_cycle_loss
    cycle_loss = lbann.Add(output_cycle_loss, param_cycle_loss)

    # loss_gen0 = latent_loss  + lambda_cyc * cycle_loss
    # loss_gen1 = inverse_loss + lambda_cyc * output_cycle_loss
    cycle_scaling = f'1 {lambda_cyc}'
    loss_gen0 = lbann.WeightedSum([latent_loss, cycle_loss],
                                  scaling_factors=cycle_scaling)
    loss_gen1 = lbann.WeightedSum([inverse_loss, output_cycle_loss],
                                  scaling_factors=cycle_scaling)

    # ---- Weights: collect all, freeze the pretrained ones -----------
    layers = list(lbann.traverse_layer_graph(samples))
    weights = set()
    frozen_tags = ["wae"]  # add macc?
    for layer in layers:
        for tag in frozen_tags:
            if layer.weights and tag in layer.name:
                for frozen_weight in layer.weights:
                    frozen_weight.optimizer = lbann.NoOptimizer()
        weights.update(layer.weights)

    # ---- Objective, metrics, callbacks ------------------------------
    l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)
    objective = lbann.ObjectiveFunction([loss_gen0, loss_gen1, l2_reg])

    metrics = [
        lbann.Metric(img_sca_loss, name='fw_loss'),
        lbann.Metric(inverse_loss, name='inverse loss'),
        lbann.Metric(output_cycle_loss, name='output cycle loss'),
        lbann.Metric(param_cycle_loss, name='param cycle loss'),
    ]

    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackSaveModel(dir=dump_models),
        lbann.CallbackLoadModel(dirs=str(pretrained_dir)),
        lbann.CallbackTimer(),
    ]
    if ltfb_batch_interval > 0:
        # LTFB is scored on the 'fw_loss' metric defined above.
        ltfb = lbann.CallbackLTFB(batch_interval=ltfb_batch_interval,
                                  metric='fw_loss',
                                  low_score_wins=True,
                                  exchange_hyperparameters=True)
        callbacks.append(ltfb)

    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       metrics=metrics,
                       objective_function=objective,
                       callbacks=callbacks)
예제 #23
0
def construct_jag_wae_model(ydim, zdim, mcf, useCNN, dump_models,
                            ltfb_batch_interval, num_epochs):
    """Assemble the LBANN model for the JAG Wasserstein autoencoder.

    Args:
        ydim: Defines the input slice points ``[0, ydim, ydim + 5]``;
            also passed to `MACCWAE`.
        zdim: Passed to `MACCWAE`.
        mcf: `cf` argument of `MACCWAE`.
        useCNN: Forwarded to `MACCWAE` as `use_CNN`.
        dump_models: Directory given to `CallbackSaveModel`.
        ltfb_batch_interval: An LTFB callback is added only when this
            is greater than zero.
        num_epochs: First positional argument of `lbann.Model`.

    Returns:
        The constructed `lbann.Model`.

    """
    # ---- Layer graph -------------------------------------------------
    samples = lbann.Input(data_field='samples', name='inp_data')
    # Original note: data is 64*64*4 images + 15 scalar + 5 param
    split_points = str_list([0, ydim, ydim + 5])
    sample_parts = lbann.Slice(samples,
                               axis=0,
                               slice_points=split_points,
                               name='inp_slice')
    gt_y = lbann.Identity(sample_parts, name='gt_y')
    # Not referenced again in this function.
    _gt_x = lbann.Identity(sample_parts, name='gt_x')

    # Constant targets for the binary cross-entropy terms
    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')

    # Noise input; its size is the literal "20" here, independent of
    # the `zdim` argument.
    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="20")

    wae = macc_network_architectures.MACCWAE(zdim,
                                             ydim,
                                             cf=mcf,
                                             use_CNN=useCNN)
    d1_real, d1_fake, d_adv, pred_y = wae(z, gt_y)

    # ---- Loss layers -------------------------------------------------
    d1_real_bce = lbann.SigmoidBinaryCrossEntropy([d1_real, one],
                                                  name='d1_real_bce')
    d1_fake_bce = lbann.SigmoidBinaryCrossEntropy([d1_fake, zero],
                                                  name='d1_fake_bce')
    d_adv_layer = lbann.SigmoidBinaryCrossEntropy([d_adv, one],
                                                  name='d_adv_bce')
    img_loss = lbann.MeanSquaredError([pred_y, gt_y])
    residual = lbann.WeightedSum([pred_y, gt_y], scaling_factors="1 -1")
    rec_error = lbann.L2Norm2(residual)

    # ---- Weights and weight-replacement bookkeeping -----------------
    layers = list(lbann.traverse_layer_graph(samples))
    weights = set()
    src_layers = []  # names containing both "disc0" and "instance1"
    dst_layers = []  # names containing "disc1"; their weights are frozen
    for layer in layers:
        if layer.weights:
            if "disc0" in layer.name and "instance1" in layer.name:
                src_layers.append(layer.name)
            if "disc1" in layer.name:
                dst_layers.append(layer.name)
                for frozen_weight in layer.weights:
                    frozen_weight.optimizer = lbann.NoOptimizer()
        weights.update(layer.weights)

    # ---- Objective ---------------------------------------------------
    l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)
    d_adv_term = lbann.LayerTerm(d_adv_layer, scale=0.01)
    objective = lbann.ObjectiveFunction([
        d1_real_bce,
        d1_fake_bce,
        d_adv_term,
        img_loss,
        rec_error,
        l2_reg,
    ])

    # ---- Metrics and callbacks --------------------------------------
    metrics = [lbann.Metric(img_loss, name='recon_error')]
    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackTimer(),
        lbann.CallbackPrintModelDescription(),
        lbann.CallbackSaveModel(dir=dump_models),
        lbann.CallbackReplaceWeights(source_layers=list2str(src_layers),
                                     destination_layers=list2str(dst_layers),
                                     batch_interval=2),
    ]
    if ltfb_batch_interval > 0:
        # LTFB is scored on the 'recon_error' metric defined above.
        ltfb = lbann.CallbackLTFB(batch_interval=ltfb_batch_interval,
                                  metric='recon_error',
                                  low_score_wins=True,
                                  exchange_hyperparameters=True)
        callbacks.append(ltfb)

    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       metrics=metrics,
                       objective_function=objective,
                       callbacks=callbacks)
예제 #24
0
def make_model(
    num_epochs,
    embed_dim,
    num_heads,
    label_smoothing,
):
    """Build an LBANN transformer model over a pair of token sequences.

    Uses `vocab_size`, `sequence_length` and `pad_index`, which are
    defined outside this function.

    Args:
        num_epochs: First positional argument of `lbann.Model`.
        embed_dim: Embedding size; also the transformer `hidden_size`.
        num_heads: Number of attention heads for the transformer.
        label_smoothing: When greater than zero, each one-hot label is
            mixed with a uniform label using weights
            ``1 - label_smoothing`` and ``label_smoothing``.

    Returns:
        The constructed `lbann.Model`, whose objective is the summed,
        scaled per-token cross entropy.

    """
    # Embedding weights (Glorot initialization)
    glorot_std = math.sqrt(2 / (embed_dim + vocab_size))
    embedding_weights = lbann.Weights(
        name='embeddings',
        initializer=lbann.NormalInitializer(standard_deviation=glorot_std),
    )

    # Input is two sequences of token IDs
    input_ = lbann.Input(data_field='samples')

    # Get sequences of embedding vectors
    # Note: Scale embeddings by sqrt(embed_dim).
    # Note: Decoder input is shifted right, so embedding for last
    # token isn't needed.
    token_slice = lbann.Slice(
        input_,
        axis=0,
        slice_points=str_list([0, 2 * sequence_length - 1]),
    )
    tokens = lbann.Identity(token_slice)
    raw_embeddings = lbann.Embedding(
        tokens,
        weights=embedding_weights,
        num_embeddings=vocab_size,
        embedding_dim=embed_dim,
        padding_idx=pad_index,
    )
    scaled_embeddings = lbann.WeightedSum(
        raw_embeddings,
        scaling_factors=str(math.sqrt(embed_dim)),
    )
    enc_dec_split = lbann.Slice(
        scaled_embeddings,
        axis=0,
        slice_points=str_list([0, sequence_length, 2 * sequence_length - 1]),
    )
    encoder_input = lbann.Identity(enc_dec_split)
    decoder_input = lbann.Identity(enc_dec_split)

    # Apply transformer model
    transformer = lbann.models.Transformer(
        hidden_size=embed_dim,
        num_heads=num_heads,
        name='transformer',
    )
    result = transformer(
        encoder_input,
        sequence_length,
        decoder_input,
        sequence_length - 1,
    )

    # Reconstruct decoder input
    # Note: the embedding weights are reused, transposed, as the
    # output projection.
    logits = lbann.ChannelwiseFullyConnected(
        result,
        weights=embedding_weights,
        output_channel_dims=[vocab_size],
        bias=False,
        transpose=True,
    )
    probs = lbann.ChannelwiseSoftmax(logits)
    probs_by_position = lbann.Slice(
        probs,
        axis=0,
        slice_points=str_list(range(sequence_length)),
    )
    preds = [
        lbann.Identity(probs_by_position)
        for _ in range(sequence_length - 1)
    ]

    # Count number of non-pad tokens
    label_slice = lbann.Slice(
        input_,
        slice_points=str_list([sequence_length + 1, 2 * sequence_length]),
    )
    label_tokens = lbann.Identity(label_slice)
    pads = lbann.Constant(value=pad_index,
                          num_neurons=str(sequence_length - 1))
    is_not_pad = lbann.NotEqual(label_tokens, pads)
    num_not_pad = lbann.Reduction(is_not_pad, mode='sum')

    # Cross entropy loss with label smoothing
    labels_by_position = lbann.Slice(
        label_tokens,
        slice_points=str_list(range(sequence_length)),
    )
    per_token_labels = [
        lbann.Identity(labels_by_position)
        for _ in range(sequence_length - 1)
    ]
    use_smoothing = label_smoothing > 0
    if use_smoothing:
        uniform_label = lbann.Constant(value=1 / vocab_size,
                                       num_neurons=str_list([1, vocab_size]))
    token_losses = []
    for token_label, pred in zip(per_token_labels, preds):
        one_hot = lbann.OneHot(token_label, size=vocab_size)
        target = lbann.Reshape(one_hot, dims=str_list([1, vocab_size]))
        if use_smoothing:
            target = lbann.WeightedSum(
                target,
                uniform_label,
                scaling_factors=str_list(
                    [1 - label_smoothing, label_smoothing]),
            )
        token_losses.append(lbann.CrossEntropy(pred, target))
    loss = lbann.Concatenation(token_losses)

    # Average cross entropy over non-pad tokens
    num_not_pad_tiled = lbann.Tessellate(num_not_pad, hint_layer=is_not_pad)
    loss_scales = lbann.Divide(is_not_pad, num_not_pad_tiled)
    loss = lbann.Multiply(loss, loss_scales)
    loss = lbann.Reduction(loss, mode='sum')

    # Construct model
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
    return lbann.Model(
        num_epochs,
        layers=lbann.traverse_layer_graph(input_),
        objective_function=loss,
        metrics=[],
        callbacks=callbacks,
    )
예제 #25
0
def construct_model():
    """Assemble the LBANN model used to evaluate the MACC surrogate.

    Architecture and background: https://arxiv.org/pdf/1912.08113.pdf

    All settings are read from `args`, which is defined outside this
    function: `ydim`, `xdim`, `zdim`, `wae_mcf`, `surrogate_mcf`,
    `useCNN`, `lamda_cyc` and `dump_outputs`.

    Returns:
        An `lbann.Model` configured for one epoch with `serialize_io`
        enabled and a callback that dumps the 'x_errors' layer in
        'test' execution mode.

    """
    import lbann

    # ---- Layer graph -------------------------------------------------
    samples = lbann.Input(data_field='samples', name='inp_data')
    # Original note: data is 64*64*4 images + 15 scalar + 5 param
    split_points = str_list([0, args.ydim, args.ydim + args.xdim])
    sample_parts = lbann.Slice(samples,
                               axis=0,
                               slice_points=split_points,
                               name='inp_slice')
    gt_y = lbann.Identity(sample_parts, name='gt_y')
    gt_x = lbann.Identity(sample_parts, name='gt_x')

    # These three layers are constructed but not referenced again in
    # this function; they are kept so the sequence of layer
    # constructions is unchanged.
    _zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    _one = lbann.Constant(value=1.0, num_neurons='1', name='one')
    _noise = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="20")

    # Sub-networks (original note marks the WAE as pretrained/frozen)
    wae = macc_models.MACCWAE(args.zdim,
                              args.ydim,
                              cf=args.wae_mcf,
                              use_CNN=args.useCNN)
    inv = macc_models.MACCInverse(args.xdim, cf=args.surrogate_mcf)
    fwd = macc_models.MACCForward(args.zdim, cf=args.surrogate_mcf)

    enc_y = wae.encoder(gt_y)

    enc_y_for_inv = wae.encoder(gt_y)
    x_from_enc = inv(enc_y_for_inv)

    latent_cycle = fwd(x_from_enc)
    y_from_cycle = wae.decoder(latent_cycle)

    # Train cycleGAN input params <--> latent space of (images, scalars)
    latent_from_x = fwd(gt_x)
    y_from_x = wae.decoder(latent_from_x)

    y_autoencoded = wae.decoder(enc_y)

    enc_roundtrip = wae.encoder(y_from_x)
    x_cycle = inv(enc_roundtrip)

    # ---- Error layers ------------------------------------------------
    # (x, inv(enc(y))): (encoder+)inverse loss
    inverse_loss = lbann.MeanSquaredError(x_from_enc, gt_x)
    # param / x cycle loss, from latent space
    param_cycle_loss = lbann.MeanSquaredError(x_cycle, gt_x)

    # prediction error in latent space: (enc(y), fw(x))
    latent_loss = lbann.MeanSquaredError(latent_from_x, enc_y)
    # prediction error in latent space: (enc(y), fw(inv(enc(y))))
    output_cycle_loss = lbann.MeanSquaredError(latent_cycle, enc_y)

    # @todo slice here to separate scalar from image
    # (y, dec(fw(x))): forward model to decoder, no latent space
    img_sca_loss = lbann.MeanSquaredError(y_from_x, gt_y)
    # (y, dec(fw(inv(enc(y))))): y -> enc_z -> x' -> fw_z -> y'
    dec_fw_inv_enc_y = lbann.MeanSquaredError(y_from_cycle, gt_y)
    # (y, dec(enc(y)))
    wae_loss = lbann.MeanSquaredError(y_autoencoded, gt_y)

    # cycle = output_cycle_loss + param_cycle_loss
    cycle_loss = lbann.Add(output_cycle_loss, param_cycle_loss)

    # loss_gen0 = latent_loss  + lamda_cyc * cycle_loss
    # loss_gen1 = inverse_loss + lamda_cyc * output_cycle_loss
    loss_gen0 = lbann.WeightedSum([latent_loss, cycle_loss],
                                  scaling_factors=f'1 {args.lamda_cyc}')
    loss_gen1 = lbann.WeightedSum([inverse_loss, output_cycle_loss],
                                  scaling_factors=f'1 {args.lamda_cyc}')

    # Everything that gets dumped by the output callback below
    dumped = lbann.Concatenation(
        [gt_x, wae_loss, img_sca_loss, dec_fw_inv_enc_y, inverse_loss],
        name='x_errors')

    # ---- Collect layers and weights ---------------------------------
    layers = list(lbann.traverse_layer_graph(samples))
    weights = set()
    for layer in layers:
        weights.update(layer.weights)

    # ---- Objective, metrics, callbacks ------------------------------
    objective = lbann.ObjectiveFunction([loss_gen0, loss_gen1])

    metrics = [
        lbann.Metric(img_sca_loss, name='img_re1'),
        lbann.Metric(dec_fw_inv_enc_y, name='img_re2'),
        lbann.Metric(wae_loss, name='wae_loss'),
        lbann.Metric(inverse_loss, name='inverse loss'),
        lbann.Metric(output_cycle_loss, name='output cycle loss'),
        lbann.Metric(param_cycle_loss, name='param cycle loss'),
    ]

    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackDumpOutputs(layers=f'{dumped.name}',
                                  execution_modes='test',
                                  directory=args.dump_outputs,
                                  batch_interval=1,
                                  format='npy'),
        lbann.CallbackTimer(),
    ]

    # Construct model (a single epoch)
    epochs = 1
    return lbann.Model(epochs,
                       weights=weights,
                       layers=layers,
                       serialize_io=True,
                       metrics=metrics,
                       objective_function=objective,
                       callbacks=callbacks)
예제 #26
0
    def forward(self, x, prev_state):
        """Run a single GRU step.

        Each call increments ``self.step`` and uses it to build the
        layer names (``{self.name}_step{k}...``).

        Args:
            x (Layer): Input.
            prev_state: State from previous GRU step.

        Returns:
            (Layer, Layer): The output and the state. Both entries are
                the same layer, so the state can be passed directly
                into the next GRU step.

        """
        self.step += 1
        name = '{}_step{}'.format(self.name, self.step)
        layout = self.data_layout

        input_fc = self.ih_fc(x)
        hidden_fc = self.hh_fc(prev_state)

        # Both projections are cut at 0, size, 2*size, 3*size, and one
        # Identity per chunk is attached to each slice.
        chunk_points = str_list(
            [0, self.size, 2 * self.size, 3 * self.size])

        input_slice = lbann.Slice(input_fc,
                                  slice_points=chunk_points,
                                  name=name + '_fc1_slice',
                                  data_layout=layout)
        Wir_x, Wiz_x, Win_x = [
            lbann.Identity(input_slice, name=name + suffix,
                           data_layout=layout)
            for suffix in ('_Wrx', '_Wzx', '_Wnx')
        ]

        hidden_slice = lbann.Slice(hidden_fc,
                                   slice_points=chunk_points,
                                   name=name + '_fc2_slice',
                                   data_layout=layout)
        Whr_prev, Whz_prev, Whn_prev = [
            lbann.Identity(hidden_slice, name=name + suffix,
                           data_layout=layout)
            for suffix in ('_Wrh', '_Wzh', '_Wnh')
        ]

        # Reset gate: rt = sigmoid(Wir_x + Whr_prev)
        reset_sum = lbann.Add(Wir_x, Whr_prev, data_layout=layout)
        rt = lbann.Sigmoid(reset_sum,
                           name=name + '_reset_gate',
                           data_layout=layout)

        # Update gate: zt = sigmoid(Wiz_x + Whz_prev)
        update_sum = lbann.Add(Wiz_x, Whz_prev, data_layout=layout)
        zt = lbann.Sigmoid(update_sum,
                           name=name + '_update_gate',
                           data_layout=layout)

        # New gate: nt = tanh(Win_x + rt * Whn_prev)
        gated_hidden = lbann.Multiply(rt, Whn_prev, data_layout=layout)
        new_sum = lbann.Add(Win_x, gated_hidden, data_layout=layout)
        nt = lbann.Tanh(new_sum,
                        name=name + '_new_gate',
                        data_layout=layout)

        # Output: ht = (self.ones - zt) * nt + zt * prev_state
        complement = lbann.WeightedSum(self.ones,
                                       zt,
                                       scaling_factors='1 -1',
                                       data_layout=layout)
        new_part = lbann.Multiply(complement, nt, data_layout=layout)
        kept_part = lbann.Multiply(zt, prev_state, data_layout=layout)
        ht = lbann.Add(new_part,
                       kept_part,
                       name=name + '_output',
                       data_layout=layout)

        # Output and next state are the same layer
        return ht, ht
예제 #27
0
    def forward(self, queries, keys, values, mask=None):
        """Apply multi-head attention.

        The input and output tensors are interpreted as sequences of
        vectors, where the first tensor dimension is the sequence
        dimension.

        Every call increments ``self.instance`` and uses it to build
        the layer names (``{self.name}_instance{k}...``).

        When ``self.ENABLE_SUBGRAPH`` is set, heads are grouped into
        ``self.BRANCHES`` consecutive groups and the per-head input
        layers are tagged with their group's ``sub_branch_tag``. Tags
        start at 1; the slice and concatenation layers use tag 0.

        Args:
            queries (lbann.Layer): Sequence of query vectors.
            keys (lbann.Layer): Sequence of key vectors.
            values (lbann.Layer): Sequence of value vectors.
            mask (optional): Additive attention mask. If the (i,j)
                entry is very negative (e.g. -1e9), then the ith query
                does not attend to the jth key/value pair. Without
                sub-graph parallelism this is a single lbann.Layer;
                with it, `mask` is indexed by branch tag
                (``mask[tag]``, tags starting at 1). The mask is
                applied whenever it is not ``None``.

        Returns:
            lbann.Layer: Sequence of output vectors. The sequence
                length is the same as `queries`.

        Raises:
            ValueError: If sub-graph parallelism is enabled and
                ``self.num_heads`` is not divisible by
                ``self.BRANCHES``.

        """
        ENABLE_SUBGRAPH = self.ENABLE_SUBGRAPH
        BRANCHES = self.BRANCHES
        if ENABLE_SUBGRAPH and self.num_heads % BRANCHES != 0:
            raise ValueError('Num heads should be divisible by BRANCHES')
        self.instance += 1
        name = f'{self.name}_instance{self.instance}'

        # Apply fully-connected layers to input sequences
        queries_fc = lbann.ChannelwiseFullyConnected(
            queries,
            weights=self.query_weights,
            output_channel_dims=[self.inner_dim],
            name=f'{name}_queries_fc',
        )
        keys_fc = lbann.ChannelwiseFullyConnected(
            keys,
            weights=self.key_weights,
            output_channel_dims=[self.inner_dim],
            name=f'{name}_keys_fc',
        )
        values_fc = lbann.ChannelwiseFullyConnected(
            values,
            weights=self.value_weights,
            output_channel_dims=[self.inner_dim],
            name=f'{name}_values_fc',
        )

        # Slice embedding vectors for each head
        # Note: slice points are 0, head_dim, 2*head_dim, ...,
        # num_heads*head_dim along axis 1.
        slice_points = str_list(self.head_dim * i
                                for i in range(self.num_heads + 1))
        queries_slice = lbann.Slice(queries_fc,
                                    axis=1,
                                    slice_points=slice_points,
                                    name=f'{name}_queries_slice',
                                    parallel_strategy={
                                        'sub_branch_tag': 0,
                                        'enable_subgraph': ENABLE_SUBGRAPH
                                    })
        keys_slice = lbann.Slice(keys_fc,
                                 axis=1,
                                 slice_points=slice_points,
                                 name=f'{name}_keys_slice',
                                 parallel_strategy={
                                     'sub_branch_tag': 0,
                                     'enable_subgraph': ENABLE_SUBGRAPH
                                 })
        values_slice = lbann.Slice(values_fc,
                                   axis=1,
                                   slice_points=slice_points,
                                   name=f'{name}_values_slice',
                                   parallel_strategy={
                                       'sub_branch_tag': 0,
                                       'enable_subgraph': ENABLE_SUBGRAPH
                                   })

        # The 1/sqrt(head_dim) scaling is the same for every head, so
        # format it once instead of once per loop iteration.
        scale = str(1 / math.sqrt(self.head_dim))

        # Compute scaled dot-product attention for each head
        attentions = []
        tag = 0
        for head in range(self.num_heads):
            head_name = f'{name}_myattention_head{head}'

            # Attention inputs
            if ENABLE_SUBGRAPH:
                # Move to the next branch tag at the first head of each
                # group. Integer division is exact here because
                # divisibility was checked above.
                if head % (self.num_heads // BRANCHES) == 0:
                    tag += 1

                q = lbann.Identity(queries_slice,
                                   parallel_strategy={
                                       'sub_branch_tag': tag,
                                       'enable_subgraph': ENABLE_SUBGRAPH
                                   })
                k = lbann.Identity(keys_slice,
                                   parallel_strategy={
                                       'sub_branch_tag': tag,
                                       'enable_subgraph': ENABLE_SUBGRAPH
                                   })
                v = lbann.Identity(values_slice,
                                   parallel_strategy={
                                       'sub_branch_tag': tag,
                                       'enable_subgraph': ENABLE_SUBGRAPH
                                   })
            else:
                q = lbann.Identity(queries_slice)
                k = lbann.Identity(keys_slice)
                v = lbann.Identity(values_slice)

            # Multiply queries and keys
            # Note: num_queries x num_keys
            y = lbann.MatMul(
                q,
                k,
                transpose_b=True,
                name=f'{head_name}_matmul',
            )
            y = lbann.WeightedSum(
                y,
                scaling_factors=scale,
                name=f'{head_name}_scale',
            )

            # Use the same explicit None test in both modes; with
            # sub-graphs the mask for the current branch is selected.
            if mask is not None:
                head_mask = mask[tag] if ENABLE_SUBGRAPH else mask
                y = lbann.Sum([y, head_mask], name=f'{head_name}_mask')
            y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

            # Attention output
            # Note: num_queries x head_dim
            attentions.append(lbann.MatMul(y, v, name=head_name))

        # Concatenate heads and apply fully-connected layer
        if ENABLE_SUBGRAPH:
            attentions = lbann.Concatenation(attentions,
                                             axis=1,
                                             name=f'{name}_heads_concat',
                                             parallel_strategy={
                                                 'sub_branch_tag': 0,
                                                 'enable_subgraph':
                                                 ENABLE_SUBGRAPH
                                             })
        else:
            attentions = lbann.Concatenation(
                attentions,
                axis=1,
                name=f'{name}_heads_concat',
            )

        outputs_fc = lbann.ChannelwiseFullyConnected(
            attentions,
            weights=self.output_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}',
        )
        return outputs_fc
예제 #28
0
파일: main.py 프로젝트: benson31/lbann
# Two Identity layers attached to `embeddings_slice` (defined earlier in the
# script): one for the negative samples, one for the walk.
negative_samples_embeddings = lbann.Identity(embeddings_slice)
walk_embeddings = lbann.Identity(embeddings_slice)

# Skip-Gram objective function
# The positive-sample loss is given two separate Identity layers that both
# read from `walk_embeddings`.
positive_loss = model.skip_gram.positive_samples_loss(
    walk_length,
    lbann.Identity(walk_embeddings),
    lbann.Identity(walk_embeddings),
    scale_decay=0.8,
)
negative_loss = model.skip_gram.negative_samples_loss(
    walk_embeddings,
    negative_samples_embeddings,
)
# The objective uses the negative loss scaled by 2, while the metrics below
# report both losses unscaled.
obj.append(positive_loss)
obj.append(lbann.WeightedSum(negative_loss, scaling_factors='2'))
metrics.append(lbann.Metric(positive_loss, name='positive loss'))
metrics.append(lbann.Metric(negative_loss, name='negative loss'))

# Perform computation at double precision
# Applies to every layer reachable from `input_` and to each layer's weights.
for l in lbann.traverse_layer_graph(input_):
    l.datatype = lbann.DataType.DOUBLE
    for w in l.weights:
        w.datatype = lbann.DataType.DOUBLE

# ----------------------------------
# Run LBANN
# ----------------------------------

# Create optimizer
# SGD with the learning rate taken from the `args` object.
opt = lbann.SGD(learn_rate=args.learning_rate)
예제 #29
0
def construct_model():
    """Build the LBANN model for the JAG Wasserstein autoencoder (WAE).

    Assembles the layer graph, the objective function, one reconstruction
    metric and the callbacks, then wraps them in an ``lbann.Model`` that is
    created with ``num_epochs = 100``.

    Returns:
        The constructed ``lbann.Model``.
    """
    import lbann

    # Problem sizes handed to the WAE module
    y_dim = 16399  # image + scalar shape
    z_dim = 20  # latent space dimension

    # ------------------------------------------------------------------
    # Layer graph
    # ------------------------------------------------------------------
    # Each sample is 64*64*4 images + 15 scalars + 5 params
    input_layer = lbann.Input(target_mode='N/A', name='inp_data')
    sample_slice = lbann.Slice(
        input_layer,
        axis=0,
        slice_points="0 16399 16404",
        name='inp_slice',
    )
    # Creation order matters: gt_y is the first child of the slice and
    # gt_x the second.
    gt_y = lbann.Identity(sample_slice, name='gt_y')
    gt_x = lbann.Identity(sample_slice, name='gt_x')  # params, not used below

    # Constant targets for the binary cross entropy terms
    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')

    # Latent sample drawn from a standard Gaussian, fed to the WAE module
    latent = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="20")
    wae = jag_models.WAE(z_dim, y_dim)
    d1_real, d1_fake, d_adv, pred_y = wae(latent, gt_y)

    # Discriminator / adversarial losses
    d1_real_bce = lbann.SigmoidBinaryCrossEntropy(
        [d1_real, one], name='d1_real_bce')
    d1_fake_bce = lbann.SigmoidBinaryCrossEntropy(
        [d1_fake, zero], name='d1_fake_bce')
    d_adv_bce = lbann.SigmoidBinaryCrossEntropy(
        [d_adv, one], name='d_adv_bce')

    # Reconstruction losses: MSE plus squared L2 norm of (pred_y - gt_y)
    img_loss = lbann.MeanSquaredError([pred_y, gt_y])
    residual = lbann.WeightedSum([pred_y, gt_y], scaling_factors="1 -1")
    rec_error = lbann.L2Norm2(residual)

    layers = list(lbann.traverse_layer_graph(input_layer))

    # ------------------------------------------------------------------
    # Objective function
    # ------------------------------------------------------------------
    weights = set()
    src_layers = []
    dst_layers = []
    for layer in layers:
        if layer.weights:
            layer_name = layer.name
            if "disc0" in layer_name and "instance1" in layer_name:
                src_layers.append(layer_name)
            if "disc1" in layer_name:
                # Freeze the weights of every layer named "...disc1..."
                dst_layers.append(layer_name)
                for frozen in layer.weights:
                    frozen.optimizer = lbann.NoOptimizer()
        weights.update(layer.weights)

    l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)
    d_adv_term = lbann.LayerTerm(d_adv_bce, scale=0.01)
    obj = lbann.ObjectiveFunction([
        d1_real_bce,
        d1_fake_bce,
        d_adv_term,
        img_loss,
        rec_error,
        l2_reg,
    ])

    # Metric reported during training
    metrics = [lbann.Metric(img_loss, name='recon_error')]

    # The replace-weights callback is given the "disc0"/"instance1" layers as
    # sources and the frozen "disc1" layers as destinations.
    replace_weights = lbann.CallbackReplaceWeights(
        source_layers=list2str(src_layers),
        destination_layers=list2str(dst_layers),
        batch_interval=2,
    )
    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackTimer(),
        replace_weights,
    ]

    # ------------------------------------------------------------------
    # Construct model
    # ------------------------------------------------------------------
    num_epochs = 100
    return lbann.Model(
        num_epochs,
        weights=weights,
        layers=layers,
        metrics=metrics,
        objective_function=obj,
        callbacks=callbacks,
    )