示例#1
0
    def forward(self, x):
        """Perform one LSTM step.

        The output and cell state saved by the previous call
        (``self.last_output`` and ``self.last_cell``) are combined with
        ``x`` to compute this step's output. Both attributes are then
        overwritten with this step's output and cell state, so successive
        calls chain the recurrence.

        With ``u`` the first ``size`` entries of the fully-connected output
        and ``f``/``i``/``o`` the sigmoid of the next three ``size``-wide
        segments::

            cell   = f * last_cell + i * tanh(u)
            output = o * tanh(cell)

        Args:
            x (Layer): Input for this step.

        Returns:
            Layer: This step's output, named ``<name>_step<step>``.

        """
        # Increment the step counter so every layer created below gets a
        # unique, step-specific name.
        self.step += 1
        name = '{0}_step{1}'.format(self.name, self.step)

        # Apply linearity to the concatenation [x, previous output]
        input_concat = lbann.Concatenation([x, self.last_output],
                                           name=name + '_input',
                                           data_layout=self.data_layout)
        fc = self.fc(input_concat)

        # Get gates and cell update.
        # NOTE: the fc output is split at [0, size, 4*size]. The layers that
        # take a Slice as parent appear to receive its segments in the order
        # they are constructed: Tanh gets [0, size) (cell update) and Sigmoid
        # gets [size, 4*size) (all three gates). The same applies to the
        # second Slice and the forget/input/output Identity layers. Do not
        # reorder these statements.
        slice = lbann.Slice(fc,
                            slice_points=_str_list([0, self.size, 4*self.size]),
                            name=name + '_fc_slice',
                            data_layout=self.data_layout)
        cell_update = lbann.Tanh(slice,
                                 name=name + '_cell_update',
                                 data_layout=self.data_layout)
        sigmoid = lbann.Sigmoid(slice,
                                name=name + '_sigmoid',
                                data_layout=self.data_layout)
        # Split the 3*size sigmoid output into three size-wide gates.
        slice = lbann.Slice(sigmoid,
                            slice_points=_str_list([0, self.size, 2*self.size, 3*self.size]),
                            name=name + '_sigmoid_slice',
                            data_layout=self.data_layout)
        f = lbann.Identity(slice, name=name + '_forget_gate',
                           data_layout=self.data_layout)
        i = lbann.Identity(slice, name=name + '_input_gate',
                           data_layout=self.data_layout)
        o = lbann.Identity(slice, name=name + '_output_gate',
                           data_layout=self.data_layout)

        # Cell state: f * last_cell + i * cell_update
        cell_forget = lbann.Multiply([f, self.last_cell],
                                     name=name + '_cell_forget',
                                     data_layout=self.data_layout)
        cell_input = lbann.Multiply([i, cell_update],
                                    name=name + '_cell_input',
                                    data_layout=self.data_layout)
        cell = lbann.Add([cell_forget, cell_input], name=name + '_cell',
                         data_layout=self.data_layout)

        # Output: o * tanh(cell)
        cell_act = lbann.Tanh(cell, name=name + '_cell_activation',
                              data_layout=self.data_layout)
        output = lbann.Multiply([o, cell_act], name=name,
                                data_layout=self.data_layout)

        # Update state (consumed by the next call) and return output
        self.last_cell = cell
        self.last_output = output
        return output
示例#2
0
    def forward_generator(self, z, mcr):
        '''
        Build the Generator.

        ``z`` is processed in this order:

        - ``self.g_fc1``, then BatchNorm + ReLU;
        - reshape to ``512 8 8`` (channel first);
        - every layer in ``self.g_convT``, each followed by BatchNorm + ReLU;
        - ``self.g_convT3``.

        Args:
            z (Layer): Generator input (presumably the noise/latent vector).
            mcr (bool): If true, apply multi-channel rescaling. An extra
                channel computed from ``self.inv_transform(img)`` is then
                concatenated to the generated image.

        Returns:
            Layer: The generated image, reshaped to ``2 128 128`` when
            ``mcr`` is true and to ``1 128 128`` otherwise.
        '''
        # BatchNormalization settings shared by every generator stage.
        bn_kwargs = dict(decay=0.9, scale_init=1.0, epsilon=1e-5)

        x = lbann.Relu(lbann.BatchNormalization(self.g_fc1(z), **bn_kwargs))
        x = lbann.Reshape(x, dims='512 8 8')  # channel first

        # Transposed-convolution stages, each followed by BN + ReLU.
        for lyr in self.g_convT:
            x = lbann.Relu(lbann.BatchNormalization(lyr(x), **bn_kwargs))

        img = self.g_convT3(x)

        if mcr:  # Multi-channel rescaling: add extra channel to output image
            linear_scale = 1 / self.linear_scaler
            # NOTE(review): if the inv_transform in use already applies the
            # 1/linear_scaler scaling and Tanh itself, this applies them a
            # second time. Confirm which inv_transform this class has.
            ch2 = lbann.Tanh(
                lbann.WeightedSum(self.inv_transform(img),
                                  scaling_factors=str(linear_scale)))
            y = lbann.Concatenation(img, ch2, axis=0)
            img = lbann.Reshape(y, dims='2 128 128')
        else:
            img = lbann.Reshape(img, dims='1 128 128')

        return img
示例#3
0
    def forward(self, img, z, mcr):
        '''
        Assemble the GAN graph for one pass.

        Order of operations:
        - prepare the real images (with mcr, add the rescaled extra channel)
        - discriminator 1 on the real images                -> d1_real
        - generator on the noise ``z``                      -> gen_img
        - discriminator 1 on gen_img behind a StopGradient  -> d1_fake
        - adversarial discriminator 2 on gen_img            -> d_adv

        Returns:
            tuple: ``(d1_real, d1_fake, d_adv, gen_img, img)``, where ``img``
            is the reshaped (and, with mcr, two-channel) real image.
        '''
        print('MCR in forward', mcr)
        if not mcr:
            real = lbann.Reshape(img, dims='1 128 128')
        else:
            # Multi-channel rescaling for real images. Generated images get
            # their extra channel inside the generator instead.
            scale = 1 / self.linear_scaler
            transformed = self.inv_transform(lbann.Identity(img))
            extra_channel = lbann.Tanh(
                lbann.WeightedSum(transformed, scaling_factors=str(scale)))
            stacked = lbann.Concatenation(lbann.Identity(img), extra_channel,
                                          axis=0)
            real = lbann.Reshape(stacked, dims='2 128 128')

        d1_real = self.forward_discriminator1(real)  # instance 1
        gen_img = self.forward_generator(z, mcr=mcr)

        # Instance 2: the generated image is wrapped in StopGradient before
        # going into discriminator 1.
        d1_fake = self.forward_discriminator1(lbann.StopGradient(gen_img))
        # Instance 3 (needs to be frozen). The d1 instances share weights;
        # d1's weights are copied into d_adv through a replace-weights
        # callback and frozen.
        d_adv = self.forward_discriminator2(gen_img)

        return d1_real, d1_fake, d_adv, gen_img, real
示例#4
0
 def inv_transform(self, y):
     """Build the rescaled second channel derived from ``y``.

     The steps are:

     - form ``(1 + y)`` and ``(1 - y)`` from constant-one layers
       (hinted on ``y``);
     - divide them with SafeDivide;
     - scale the ratio by ``self.datascale``;
     - scale by ``1 / self.linear_scaler``;
     - apply Tanh.

     Args:
         y (Layer): Input layer.

     Returns:
         Layer: The final Tanh layer.
     """
     numerator = lbann.Add(lbann.Constant(value=1.0, hint_layer=y),
                           lbann.Identity(y))
     denominator = lbann.Subtract(lbann.Constant(value=1.0, hint_layer=y),
                                  lbann.Identity(y))
     ratio = lbann.SafeDivide(numerator, denominator)
     unscaled = lbann.WeightedSum(ratio, scaling_factors=str(self.datascale))
     factor = 1 / self.linear_scaler
     return lbann.Tanh(lbann.WeightedSum(unscaled,
                                         scaling_factors=str(factor)))
示例#5
0
def Gelu_approx(x):
    """Build the tanh approximation of GELU out of LBANN layers.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``,
    which may be more performant than the exact GELU.
    Based on: https://paperswithcode.com/method/gelu

    Args:
        x (Layer): Input layer.

    Returns:
        Layer: Layer computing the approximate GELU of ``x``.
    """
    cubic_coef = 0.044715
    tanh_scale = math.sqrt(2 / math.pi)

    cube = lbann.Multiply(lbann.Multiply(lbann.Identity(x), x), x)
    inner = lbann.Add(x, lbann.Scale(cube, constant=cubic_coef))
    scaled_inner = lbann.Scale(inner, constant=tanh_scale)
    one_plus_tanh = lbann.AddConstant(lbann.Tanh(scaled_inner), constant=1)
    gated = lbann.Multiply(x, one_plus_tanh)
    return lbann.Scale(gated, constant=0.5)
示例#6
0
    def forward_generator(self, z, mcr):
        '''
        Build the Generator.

        ``z`` is processed in this order:

        - ``self.g_fc1``, then BatchNorm + ReLU;
        - reshape to ``512 8 8`` (channel first);
        - ``self.g_convT[0]``, ``[1]`` and ``[2]``, each followed by
          BatchNorm + ReLU;
        - ``self.g_convT3``.

        Args:
            z (Layer): Generator input (presumably the noise/latent vector).
            mcr (bool): If true, apply multi-channel rescaling. An extra
                channel computed from ``self.inv_transform(img)`` is then
                concatenated to the generated image.

        Returns:
            Layer: The generated image, reshaped to ``2 128 128`` when
            ``mcr`` is true and to ``1 128 128`` otherwise.
        '''
        # BatchNormalization settings shared by every generator stage.
        bn_kwargs = dict(decay=0.9, scale_init=1.0, epsilon=1e-5)

        x = lbann.Relu(lbann.BatchNormalization(self.g_fc1(z), **bn_kwargs))
        dims = '512 8 8'

        # NOTE: debug output; consider switching to logging or removing.
        print("dims", dims)
        x = lbann.Reshape(x, dims=dims)  # channel first

        # The first three transposed convolutions, each followed by BN + ReLU.
        # If g_convT holds more than three layers, the rest are not used here.
        for idx in range(3):
            x = lbann.Relu(
                lbann.BatchNormalization(self.g_convT[idx](x), **bn_kwargs))
        img = self.g_convT3(x)

        if mcr:  # Multi-channel rescaling: add extra channel to output image
            linear_scale = 1 / self.linear_scaler
            ch2 = lbann.Tanh(
                lbann.WeightedSum(self.inv_transform(img),
                                  scaling_factors=str(linear_scale)))
            y = lbann.Concatenation(img, ch2, axis=0)
            img = lbann.Reshape(y, dims='2 128 128')
        else:
            img = lbann.Reshape(img, dims='1 128 128')

        # NOTE: debug output; consider switching to logging or removing.
        print('Gen Img in GAN', img.__dict__)
        return img
示例#7
0
 def forward(self, hidden_states):
     """Pool ``hidden_states`` by taking the hidden state of the first token.

     The steps are:

     - slice ``[0, 1)`` along axis 1 out of ``hidden_states``;
     - pass it through a ``PytorchLinear`` layer named ``<name>.dense``,
       whose weights come from ``_load_pretrained_weights`` for
       ``<name>.dense.weight`` / ``<name>.dense.bias``;
     - apply Tanh.

     Args:
         hidden_states (Layer): Encoder output. Assumes axis 1 indexes
             tokens (TODO confirm).

     Returns:
         Layer: Tanh layer named ``<name>.activation``.
     """
     first_token = lbann.Slice(hidden_states,
                               axis=1,
                               slice_points=str_list([0, 1]))
     in_shape = (self.input_shape[0], self.input_shape[-1])
     dense_weights = _load_pretrained_weights(
         ".".join((self.name, "dense.weight")),
         ".".join((self.name, "dense.bias")),
         load_weights=self.load_weights,
     )
     dense_out = lbann.modules.PytorchLinear(
         first_token,
         in_shape,
         self.hidden_size,
         weights=dense_weights,
         name=".".join((self.name, "dense")),
     )
     return lbann.Tanh(dense_out, name=".".join((self.name, "activation")))
示例#8
0
def gen_layers(latent_dim, number_of_atoms):
    '''Generate the model for the 3D convolutional autoencoder.

    The input is reshaped to ``number_of_atoms x 32 x 32 x 32``. The encoder
    reduces it to ``latent_dim x 1 x 1 x 1``, and the decoder maps it back to
    ``number_of_atoms x 32 x 32 x 32``. The reconstruction is compared with
    the reshaped input by mean squared error.

    Args:
        latent_dim (int): Channels in the latent code. Encoder stage ``i``
            (0..3) uses ``latent_dim // 2**(3 - i)`` output channels, so
            values below 8 give a zero-channel first convolution.
        number_of_atoms (int): Channels of the input and reconstructed
            volumes.

    Returns:
        tuple: ``(layers, img_loss, metrics)``:

        - the layer graph returned by ``lbann.traverse_layer_graph``
          starting from the input layer;
        - the mean-squared-error reconstruction loss layer;
        - a list holding one ``recon_error`` metric.
    '''

    input_ = lbann.Input(target_mode="reconstruction")
    tensors = lbann.Identity(input_)

    # Input tensor shape is (number_of_atoms)x32x32x32. The channel count
    # must equal the decoder's output channels (number_of_atoms) for the
    # reconstruction loss, so it is taken from the argument instead of being
    # hard-coded.
    tensors = lbann.Reshape(tensors,
                            dims="{0} 32 32 32".format(number_of_atoms),
                            name="Sample")

    # Encoder: four convolutions (kernel 4, stride 2, pad 1) halve each
    # spatial dim, 32 -> 16 -> 8 -> 4 -> 2.
    x = lbann.Identity(tensors)
    for i in range(4):
        out_channels = latent_dim // (2**(3 - i))

        x = lbann.Convolution(x,
                              num_dims=3,
                              num_output_channels=out_channels,
                              num_groups=1,
                              conv_dims_i=4,
                              conv_strides_i=2,
                              conv_dilations_i=1,
                              conv_pads_i=1,
                              has_bias=True,
                              name="Conv_{0}".format(i))

        x = lbann.BatchNormalization(x, name="Batch_NORM_{0}".format(i + 1))
        x = lbann.LeakyRelu(x, name="Conv_{0}_Activation".format(i + 1))

    # Shape: (latent_dim)x2x2x2
    encoded = lbann.Convolution(x,
                                num_dims=3,
                                num_output_channels=latent_dim,
                                num_groups=1,
                                conv_dims_i=2,
                                conv_strides_i=2,
                                conv_dilations_i=1,
                                conv_pads_i=0,
                                has_bias=True,
                                name="encoded")

    # Shape: (latent_dim)x1x1x1

    # Decoder: the first deconvolution (kernel 4, stride 2, no padding) maps
    # 1 -> 4. Three more (kernel 4, stride 2, pad 1) double each spatial dim:
    # 4 -> 8 -> 16 -> 32.
    x = lbann.Deconvolution(encoded,
                            num_dims=3,
                            num_output_channels=number_of_atoms * 16,
                            num_groups=1,
                            conv_dims_i=4,
                            conv_pads_i=0,
                            conv_strides_i=2,
                            conv_dilations_i=1,
                            has_bias=True,
                            name="Deconv_1")
    x = lbann.BatchNormalization(x, name="BN_D1")
    x = lbann.Tanh(x, name="Deconv_1_Activation")

    for i in range(3):
        out_channels = number_of_atoms * (2**(2 - i))
        x = lbann.Deconvolution(x,
                                num_dims=3,
                                num_output_channels=out_channels,
                                num_groups=1,
                                conv_dims_i=4,
                                conv_pads_i=1,
                                conv_strides_i=2,
                                conv_dilations_i=1,
                                has_bias=True,
                                name="Deconv_{0}".format(i + 2))
        x = lbann.BatchNormalization(x, name="BN_D{0}".format(i + 2))

        # Save the last activation layer for the separately named "decoded"
        # layer below, because its outputs are to be dumped.
        if i != 2:
            x = lbann.Tanh(x, name="Deconv_{0}_Activation".format(i + 2))

    decoded = lbann.Tanh(x, name="decoded")

    img_loss = lbann.MeanSquaredError([decoded, tensors])

    metrics = [lbann.Metric(img_loss, name='recon_error')]
    # ----------------------------------
    # Set up DAG
    # ----------------------------------

    layers = lbann.traverse_layer_graph(input_)  # Generate model DAG
    return layers, img_loss, metrics
示例#9
0
    def forward(self, x, prev_state):
        """Apply LSTM step.

        With ``u`` the first ``size`` entries of ``self.fc([x, prev_output])``
        and ``f``/``i``/``o`` the sigmoid of the next three ``size``-wide
        segments::

            cell   = f * prev_cell + i * tanh(u)
            output = o * tanh(cell)

        Args:
            x (Layer): Input.
            prev_state (tuple with two `Layer`s): State from previous
                LSTM step. Comprised of LSTM output and cell state.

        Returns:
            (Layer, (Layer, Layer)): The output and state (the output
                and cell state). The state can be passed directly into
                the next LSTM step.

        """
        # Increment the step counter so every layer created below gets a
        # unique, step-specific name.
        self.step += 1
        name = '{0}_step{1}'.format(self.name, self.step)

        # Get output and cell state from previous step
        prev_output, prev_cell = prev_state

        # Apply linearity to the concatenation [x, previous output]
        input_concat = lbann.Concatenation(x, prev_output,
                                           name=name + '_input',
                                           data_layout=self.data_layout)
        fc = self.fc(input_concat)

        # Get gates and cell update.
        # NOTE: the fc output is split at [0, size, 4*size]. The layers that
        # take a Slice as parent appear to receive its segments in the order
        # they are constructed: Tanh gets [0, size) (cell update) and Sigmoid
        # gets [size, 4*size) (all three gates). The same applies to the
        # second Slice and the forget/input/output Identity layers. Do not
        # reorder these statements.
        slice = lbann.Slice(fc,
                            slice_points=str_list([0, self.size, 4*self.size]),
                            name=name + '_fc_slice',
                            data_layout=self.data_layout)
        cell_update = lbann.Tanh(slice,
                                 name=name + '_cell_update',
                                 data_layout=self.data_layout)
        sigmoid = lbann.Sigmoid(slice,
                                name=name + '_sigmoid',
                                data_layout=self.data_layout)
        # Split the 3*size sigmoid output into three size-wide gates.
        slice = lbann.Slice(sigmoid,
                            slice_points=str_list([0, self.size, 2*self.size, 3*self.size]),
                            name=name + '_sigmoid_slice',
                            data_layout=self.data_layout)
        f = lbann.Identity(slice, name=name + '_forget_gate',
                           data_layout=self.data_layout)
        i = lbann.Identity(slice, name=name + '_input_gate',
                           data_layout=self.data_layout)
        o = lbann.Identity(slice, name=name + '_output_gate',
                           data_layout=self.data_layout)

        # Cell state: f * prev_cell + i * cell_update
        cell_forget = lbann.Multiply(f, prev_cell,
                                     name=name + '_cell_forget',
                                     data_layout=self.data_layout)
        cell_input = lbann.Multiply(i, cell_update,
                                    name=name + '_cell_input',
                                    data_layout=self.data_layout)
        cell = lbann.Add(cell_forget, cell_input, name=name + '_cell',
                         data_layout=self.data_layout)

        # Output: o * tanh(cell)
        cell_act = lbann.Tanh(cell, name=name + '_cell_activation',
                              data_layout=self.data_layout)
        output = lbann.Multiply(o, cell_act, name=name,
                                data_layout=self.data_layout)

        # Return output and state
        return output, (output, cell)
示例#10
0
    def forward(self, x, prev_state):
        """Apply GRU step.

        Let ``a_r, a_z, a_n`` be the three ``size``-wide segments of
        ``self.ih_fc(x)`` and ``b_r, b_z, b_n`` those of
        ``self.hh_fc(prev_state)``. The step computes::

            rt = sigmoid(a_r + b_r)            # reset gate
            zt = sigmoid(a_z + b_z)            # update gate
            nt = tanh(a_n + rt * b_n)          # new gate
            ht = (ones - zt) * nt + zt * prev_state

        Args:
            x (Layer): Input.
            prev_state (Layer): State from previous GRU step.

        Returns:
            (Layer, Layer): The output (out) and state (hn). Both are the
                same layer; the state can be passed directly into the next
                GRU step.

        """
        # Increment the step counter so every layer created below gets a
        # unique, step-specific name.
        self.step += 1
        name = '{0}_step{1}'.format(self.name, self.step)


        fc1 = self.ih_fc(x)   #input_fc
        fc2 = self.hh_fc(prev_state)  #hidden_fc


        # Get gates and cell update.
        # NOTE: each projection is split into three size-wide segments. The
        # Identity layers below appear to receive the Slice segments in the
        # order they are constructed (reset, update, new). Do not reorder.
        fc1_slice = lbann.Slice(fc1,
                            slice_points=str_list([0, self.size, 2*self.size, 3*self.size]),
                            name=name + '_fc1_slice',
                            data_layout=self.data_layout)
        Wir_x = lbann.Identity(fc1_slice, name=name + '_Wrx',
                           data_layout=self.data_layout)
        Wiz_x = lbann.Identity(fc1_slice, name=name + '_Wzx',
                           data_layout=self.data_layout)
        Win_x = lbann.Identity(fc1_slice, name=name + '_Wnx',
                           data_layout=self.data_layout)

        fc2_slice = lbann.Slice(fc2,
                            slice_points=str_list([0, self.size, 2*self.size, 3*self.size]),
                            name=name + '_fc2_slice',
                            data_layout=self.data_layout)
        Whr_prev = lbann.Identity(fc2_slice, name=name + '_Wrh',
                           data_layout=self.data_layout)
        Whz_prev = lbann.Identity(fc2_slice, name=name + '_Wzh',
                           data_layout=self.data_layout)
        Whn_prev = lbann.Identity(fc2_slice, name=name + '_Wnh',
                           data_layout=self.data_layout)

        # Reset gate: sigmoid(Wir_x + Whr_prev)
        rt = \
            lbann.Sigmoid(
                lbann.Add(Wir_x, Whr_prev, data_layout=self.data_layout),
                name=name + '_reset_gate',
                data_layout=self.data_layout
            )

        # Update gate: sigmoid(Wiz_x + Whz_prev)
        zt = \
            lbann.Sigmoid(
                lbann.Add(Wiz_x, Whz_prev, data_layout=self.data_layout),
                name=name + '_update_gate',
                data_layout=self.data_layout,
            )

        # New gate: tanh(Win_x + rt * Whn_prev)
        nt = \
            lbann.Tanh(
                lbann.Add(
                    Win_x,
                    lbann.Multiply(rt, Whn_prev, data_layout=self.data_layout),
                    data_layout=self.data_layout,
                ),
                name=name + '_new_gate', data_layout=self.data_layout,
            )

        # Output: (ones - zt) * nt + zt * prev_state. WeightedSum with
        # factors '1 -1' forms ones - zt. Assumes self.ones is a layer of
        # ones shaped like zt (TODO confirm).
        ht = \
            lbann.Add(
                lbann.Multiply(
                    lbann.WeightedSum(
                        self.ones,
                        zt,
                        scaling_factors='1 -1', data_layout=self.data_layout
                    ),
                    nt,
                    data_layout=self.data_layout
                ),
                lbann.Multiply(zt, prev_state, data_layout=self.data_layout),
                name=name+ '_output', data_layout=self.data_layout,
            )

        # Return output (also used as the next state)
        return ht, ht
示例#11
0
    def forward(self, x, prev_state):
        """Apply one GRU step channelwise.

        Let ``a_r, a_z, a_n`` be the three ``size``-wide segments (along
        axis 1) of ``self.ih_fc(x)`` and ``b_r, b_z, b_n`` those of
        ``self.hh_fc(prev_state)``, each reshaped to
        ``(num_channels, size)``. The step computes::

            rt = sigmoid(a_r + b_r)            # reset gate
            zt = sigmoid(a_z + b_z)            # update gate
            nt = tanh(a_n + rt * b_n)          # new gate
            ht = (ones - zt) * nt + zt * prev_state

        Args:
            x (Layer): Input (shape: (num_channels, *)).
            prev_state (Layer): State from the previous GRU step
                (shape: (num_channels, size)).

        Returns:
            (Layer, Layer): The output (out) and state (hn). Both are the
                same layer; the state can be passed directly into the next
                GRU step.
        """

        # Increment the step counter so named layers below are unique.
        self.step += 1

        name = f"{self.name}_step{self.step}"

        # Lay the previous state out as (num_channels, size).
        prev_state = lbann.Reshape(prev_state,
                                   dims=str_list(
                                       [self.num_channels, self.size]),
                                   name=name + "_prev_state_reshape")

        fc1 = self.ih_fc(x)
        fc2 = self.hh_fc(prev_state)

        # Split each projection along axis 1 into three size-wide segments.
        # NOTE: the Identity layers below appear to receive the Slice
        # segments in the order they are constructed (reset, update, new).
        # Do not reorder them.
        fc1_slice = lbann.Slice(
            fc1,
            axis=1,
            slice_points=str_list([0, self.size, 2 * self.size,
                                   3 * self.size]))

        Wir_x = lbann.Reshape(lbann.Identity(fc1_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Wir_x')
        Wiz_z = lbann.Reshape(lbann.Identity(fc1_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Wiz_z')
        Win_x = lbann.Reshape(lbann.Identity(fc1_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Win_x')
        fc2_slice = lbann.Slice(
            fc2,
            axis=1,
            slice_points=str_list([0, self.size, 2 * self.size,
                                   3 * self.size]))

        Whr_x = lbann.Reshape(lbann.Identity(fc2_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Whr_x')
        Whz_z = lbann.Reshape(lbann.Identity(fc2_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Whz_z')
        Whn_x = lbann.Reshape(lbann.Identity(fc2_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Whn_x')

        # Reset gate
        rt = \
            lbann.Sigmoid(
                lbann.Add(Wir_x, Whr_x, data_layout=self.data_layout),
                name=name + '_reset_gate',
                data_layout=self.data_layout
            )

        # Update gate
        zt = \
            lbann.Sigmoid(
                lbann.Add(Wiz_z, Whz_z, data_layout=self.data_layout),
                name=name + '_update_gate',
                data_layout=self.data_layout,
            )

        # New gate: tanh(Win_x + rt * Whn_x)
        nt = \
            lbann.Tanh(
                lbann.Add(
                    Win_x,
                    lbann.Multiply(rt, Whn_x, data_layout=self.data_layout),
                    data_layout=self.data_layout,
                ),
                name=name + '_new_gate', data_layout=self.data_layout,
            )

        # Output: (ones - zt) * nt + zt * prev_state. WeightedSum with
        # factors '1 -1' forms ones - zt. Assumes self.ones is a layer of
        # ones shaped (num_channels, size) (TODO confirm).
        ht = \
            lbann.Add(
                lbann.Multiply(
                    lbann.WeightedSum(
                        self.ones,
                        zt,
                        scaling_factors='1 -1', data_layout=self.data_layout
                    ),
                    nt,
                    data_layout=self.data_layout
                ),
                lbann.Multiply(zt, prev_state, data_layout=self.data_layout),
                name=name+ '_output', data_layout=self.data_layout,
            )

        ht = lbann.Reshape(ht, dims=str_list([self.num_channels, self.size]))

        return ht, ht