def forward(self, x):
    """Run one LSTM time step using the state cached on ``self``.

    The previous step's output and cell state are read from
    ``self.last_output`` and ``self.last_cell``. After this step's layers
    are built, both attributes are overwritten with the new values, so
    consecutive calls chain automatically.

    Args:
        x (Layer): Input for this time step.

    Returns:
        Layer: LSTM output for this step (also stored in
        ``self.last_output``).

    """
    self.step += 1
    prefix = '{0}_step{1}'.format(self.name, self.step)
    layout = self.data_layout
    size = self.size

    # One fully-connected layer over [input, previous output]
    joined = lbann.Concatenation([x, self.last_output],
                                 name=prefix + '_input',
                                 data_layout=layout)
    projected = self.fc(joined)

    # Split the FC output: the first `size` entries feed the cell update,
    # the remaining 3*size entries feed the gates.
    # NOTE: creation order matters here; the Slice outputs are presumably
    # handed to its children in the order they are created (Tanh first,
    # then Sigmoid), so do not reorder these statements.
    fc_parts = lbann.Slice(projected,
                           slice_points=_str_list([0, size, 4*size]),
                           name=prefix + '_fc_slice',
                           data_layout=layout)
    cell_update = lbann.Tanh(fc_parts,
                             name=prefix + '_cell_update',
                             data_layout=layout)
    gate_act = lbann.Sigmoid(fc_parts,
                             name=prefix + '_sigmoid',
                             data_layout=layout)
    gate_parts = lbann.Slice(gate_act,
                             slice_points=_str_list([0, size, 2*size, 3*size]),
                             name=prefix + '_sigmoid_slice',
                             data_layout=layout)
    # Forget, input and output gates, created in slice order
    forget_gate, input_gate, output_gate = (
        lbann.Identity(gate_parts, name=prefix + suffix, data_layout=layout)
        for suffix in ('_forget_gate', '_input_gate', '_output_gate')
    )

    # c_t = f * c_{t-1} + i * tanh(update)
    new_cell = lbann.Add(
        [
            lbann.Multiply([forget_gate, self.last_cell],
                           name=prefix + '_cell_forget',
                           data_layout=layout),
            lbann.Multiply([input_gate, cell_update],
                           name=prefix + '_cell_input',
                           data_layout=layout),
        ],
        name=prefix + '_cell',
        data_layout=layout,
    )

    # h_t = o * tanh(c_t)
    squashed_cell = lbann.Tanh(new_cell,
                               name=prefix + '_cell_activation',
                               data_layout=layout)
    out = lbann.Multiply([output_gate, squashed_cell],
                         name=prefix,
                         data_layout=layout)

    # Cache state for the next call
    self.last_cell = new_cell
    self.last_output = out
    return out
def forward_generator(self, z, mcr):
    '''Build the generator graph from latent noise ``z``.

    The noise goes through ``self.g_fc1`` + BatchNorm + ReLU and is
    reshaped to a channel-first 512x8x8 map. Every module in
    ``self.g_convT`` is then applied (each followed by BatchNorm + ReLU),
    and finally ``self.g_convT3``.

    Args:
        z (Layer): Latent noise input.
        mcr (bool): If true (multi-channel rescaling), append a second
            channel tanh(inv_transform(img) / linear_scaler) and return a
            2x128x128 image; otherwise return a 1x128x128 image.

    Returns:
        Layer: Generated image, reshaped to '2 128 128' or '1 128 128'.
    '''
    def bn_relu(layer):
        # BatchNorm (decay 0.9, eps 1e-5) followed by ReLU
        return lbann.Relu(
            lbann.BatchNormalization(layer, decay=0.9, scale_init=1.0,
                                     epsilon=1e-5))

    hidden = bn_relu(self.g_fc1(z))
    hidden = lbann.Reshape(hidden, dims='512 8 8')  # channel first
    for conv_t in self.g_convT:
        hidden = bn_relu(conv_t(hidden))
    generated = self.g_convT3(hidden)

    if not mcr:
        return lbann.Reshape(generated, dims='1 128 128')

    # Multi-channel rescaling: add an extra channel to the output image.
    # NOTE(review): if self.inv_transform already applies its own Tanh and
    # 1/linear_scaler rescaling, this rescales twice -- confirm which
    # inv_transform variant this class uses.
    scale = 1 / self.linear_scaler
    second_channel = lbann.Tanh(
        lbann.WeightedSum(self.inv_transform(generated),
                          scaling_factors=str(scale)))
    stacked = lbann.Concatenation(generated, second_channel, axis=0)
    return lbann.Reshape(stacked, dims='2 128 128')
def forward(self, img, z, mcr):
    '''Build the full GAN graph for one batch.

    Steps:
        - If ``mcr`` (multi-channel rescaling), add an extra channel to the
          real images; generated images are rescaled inside the generator.
        - D1 + real images -> d1_real
        - G + noise -> gen_img
        - D1 + gen_img (through StopGradient) -> d1_fake
        - Adversarial discriminator D2 + gen_img -> d_adv

    Args:
        img (Layer): Real image batch.
        z (Layer): Latent noise for the generator.
        mcr (bool): Whether to use multi-channel rescaling.

    Returns:
        tuple: ``(d1_real, d1_fake, d_adv, gen_img, img)``, where ``img`` is
        the real image after reshaping to '2 128 128' (mcr) or '1 128 128'.
    '''
    # Debug print of `mcr` removed: it fired on every graph construction.
    if mcr:
        # Multi-channel rescaling: add an extra channel to the real images.
        linear_scale = 1 / self.linear_scaler
        ch2 = lbann.Tanh(
            lbann.WeightedSum(self.inv_transform(lbann.Identity(img)),
                              scaling_factors=str(linear_scale)))
        y = lbann.Concatenation(lbann.Identity(img), ch2, axis=0)
        img = lbann.Reshape(y, dims='2 128 128')
    else:
        img = lbann.Reshape(img, dims='1 128 128')

    d1_real = self.forward_discriminator1(img)  # instance 1
    gen_img = self.forward_generator(z, mcr=mcr)
    # StopGradient presumably keeps D1's fake-image loss from
    # back-propagating into the generator.
    d1_fake = self.forward_discriminator1(
        lbann.StopGradient(gen_img))  # instance 2
    # Instance 3 (needs to be frozen): the D1 instances share weights, and
    # D1's weights are copied to d_adv through a replace-weights callback.
    d_adv = self.forward_discriminator2(gen_img)
    return d1_real, d1_fake, d_adv, gen_img, img
def inv_transform(self, y):
    """Build tanh((datascale * (1 + y) / (1 - y)) / linear_scaler).

    The ratio uses ``lbann.SafeDivide``, presumably to guard the y == 1
    case. The result is scaled by ``self.datascale``, then by
    ``1 / self.linear_scaler``, and squashed with Tanh.

    Args:
        y (Layer): Layer to transform.

    Returns:
        Layer: The tanh-squashed, rescaled inverse transform of ``y``.
    """
    one_plus_y = lbann.Add(lbann.Constant(value=1.0, hint_layer=y),
                           lbann.Identity(y))
    one_minus_y = lbann.Subtract(lbann.Constant(value=1.0, hint_layer=y),
                                 lbann.Identity(y))
    unscaled = lbann.WeightedSum(lbann.SafeDivide(one_plus_y, one_minus_y),
                                 scaling_factors=str(self.datascale))
    rescaled = lbann.WeightedSum(unscaled,
                                 scaling_factors=str(1/self.linear_scaler))
    return lbann.Tanh(rescaled)
def Gelu_approx(x):
    """Tanh approximation of GELU built from LBANN layers.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``,
    which may be more performant than an exact GELU.
    Based on: https://paperswithcode.com/method/gelu

    Args:
        x (Layer): Input layer.

    Returns:
        Layer: Approximate GELU of ``x``.
    """
    cube = lbann.Multiply(lbann.Multiply(lbann.Identity(x), x), x)
    poly = lbann.Add(x, lbann.Scale(cube, constant=0.044715))
    tanh_term = lbann.Tanh(
        lbann.Scale(poly, constant=math.sqrt(2 / math.pi)))
    one_plus_tanh = lbann.AddConstant(tanh_term, constant=1)
    return lbann.Scale(lbann.Multiply(x, one_plus_tanh), constant=0.5)
def forward_generator(self, z, mcr):
    '''Build the generator graph from latent noise ``z``.

    The noise goes through ``self.g_fc1`` + BatchNorm + ReLU and is
    reshaped to a channel-first 512x8x8 map. It then passes through the
    first three modules of ``self.g_convT`` (each followed by BatchNorm +
    ReLU) and finally ``self.g_convT3``.

    Args:
        z (Layer): Latent noise input.
        mcr (bool): If true (multi-channel rescaling), append a second
            channel tanh(inv_transform(img) / linear_scaler) and return a
            2x128x128 image; otherwise return a 1x128x128 image.

    Returns:
        Layer: Generated image, reshaped to '2 128 128' or '1 128 128'.
    '''
    def _bn_relu(layer):
        # BatchNorm (decay 0.9, eps 1e-5) followed by ReLU
        return lbann.Relu(
            lbann.BatchNormalization(layer, decay=0.9, scale_init=1.0,
                                     epsilon=1e-5))

    x = _bn_relu(self.g_fc1(z))
    x = lbann.Reshape(x, dims='512 8 8')  # channel first
    # This variant deliberately uses exactly the first three modules.
    x = _bn_relu(self.g_convT[0](x))
    x = _bn_relu(self.g_convT[1](x))
    x = _bn_relu(self.g_convT[2](x))
    img = self.g_convT3(x)

    if mcr:
        # Multi-channel rescaling: add an extra channel to the output image.
        # NOTE(review): if self.inv_transform already applies its own Tanh
        # and 1/linear_scaler rescaling, this rescales twice -- confirm.
        linear_scale = 1 / self.linear_scaler
        ch2 = lbann.Tanh(
            lbann.WeightedSum(self.inv_transform(img),
                              scaling_factors=str(linear_scale)))
        y = lbann.Concatenation(img, ch2, axis=0)
        img = lbann.Reshape(y, dims='2 128 128')
    else:
        img = lbann.Reshape(img, dims='1 128 128')
    return img
def forward(self, hidden_states):
    """Pool hidden states with a dense layer and a tanh activation.

    The model is "pooled" by taking the hidden state of the first token:
    the slice [0, 1) of ``hidden_states`` along axis 1 goes through a
    PytorchLinear layer named ``<name>.dense`` (with optionally loaded
    pretrained ``dense.weight``/``dense.bias``) and then a Tanh named
    ``<name>.activation``.

    NOTE(review): axis 1 is assumed to be the token axis, and the linear
    layer is told its input shape is (input_shape[0], input_shape[-1]);
    confirm both against how ``self.input_shape`` is defined.

    Args:
        hidden_states (Layer): Encoder hidden states.

    Returns:
        Layer: Pooled output.
    """
    first_token = lbann.Slice(hidden_states, axis=1,
                              slice_points=str_list([0, 1]))
    dense_weights = _load_pretrained_weights(
        ".".join((self.name, "dense.weight")),
        ".".join((self.name, "dense.bias")),
        load_weights=self.load_weights,
    )
    dense_out = lbann.modules.PytorchLinear(
        first_token,
        (self.input_shape[0], self.input_shape[-1]),
        self.hidden_size,
        weights=dense_weights,
        name=".".join((self.name, "dense")),
    )
    return lbann.Tanh(dense_out, name=".".join((self.name, "activation")))
def gen_layers(latent_dim, number_of_atoms):
    '''Build the layer graph for the 3D convolutional autoencoder.

    The input is reshaped to ``(number_of_atoms, 32, 32, 32)``. It is
    encoded by four stride-2 3D convolutions and a final convolution down
    to ``(latent_dim, 1, 1, 1)``, then decoded by four 3D deconvolutions
    back to ``(number_of_atoms, 32, 32, 32)``. The reconstruction is scored
    with mean squared error against the reshaped input.

    Args:
        latent_dim (int): Channels of the latent code. The encoder's
            convolutions use latent_dim // 8, // 4, // 2 and // 1 channels.
        number_of_atoms (int): Number of input channels (presumably one
            32^3 grid per atom type). The decoder also outputs this many
            channels.

    Returns:
        tuple: ``(layers, img_loss, metrics)``. ``layers`` is the layer
        graph returned by ``lbann.traverse_layer_graph`` starting from the
        input layer. ``img_loss`` is the MSE reconstruction loss layer.
        ``metrics`` is a one-element list holding the ``recon_error``
        metric.
    '''
    input_ = lbann.Input(target_mode="reconstruction")
    tensors = lbann.Identity(input_)
    # Input tensor shape is (number_of_atoms)x32x32x32. The channel count
    # must be number_of_atoms (it was previously hard-coded to 11): the
    # decoder ends with number_of_atoms channels and is compared to this
    # tensor by MSE.
    tensors = lbann.Reshape(tensors,
                            dims="{0} 32 32 32".format(number_of_atoms),
                            name="Sample")

    # ----------------------------------
    # Encoder: 32^3 -> 16^3 -> 8^3 -> 4^3 -> 2^3 spatially
    # ----------------------------------
    # NOTE(review): layer names mix 0-based (Conv_i) and 1-based
    # (Batch_NORM_/Conv_..._Activation) indices; kept as-is in case names
    # are referenced elsewhere.
    x = lbann.Identity(tensors)
    for i in range(4):
        out_channels = latent_dim // (2**(3 - i))
        x = lbann.Convolution(x,
                              num_dims=3,
                              num_output_channels=out_channels,
                              num_groups=1,
                              conv_dims_i=4,
                              conv_strides_i=2,
                              conv_dilations_i=1,
                              conv_pads_i=1,
                              has_bias=True,
                              name="Conv_{0}".format(i))
        x = lbann.BatchNormalization(x, name="Batch_NORM_{0}".format(i + 1))
        x = lbann.LeakyRelu(x, name="Conv_{0}_Activation".format(i + 1))

    # Shape: (latent_dim)x2x2x2
    encoded = lbann.Convolution(x,
                                num_dims=3,
                                num_output_channels=latent_dim,
                                num_groups=1,
                                conv_dims_i=2,
                                conv_strides_i=2,
                                conv_dilations_i=1,
                                conv_pads_i=0,
                                has_bias=True,
                                name="encoded")
    # Shape: (latent_dim)x1x1x1

    # ----------------------------------
    # Decoder: 1^3 -> 4^3 -> 8^3 -> 16^3 -> 32^3 spatially
    # ----------------------------------
    x = lbann.Deconvolution(encoded,
                            num_dims=3,
                            num_output_channels=number_of_atoms * 16,
                            num_groups=1,
                            conv_dims_i=4,
                            conv_pads_i=0,
                            conv_strides_i=2,
                            conv_dilations_i=1,
                            has_bias=True,
                            name="Deconv_1")
    x = lbann.BatchNormalization(x, name="BN_D1")
    x = lbann.Tanh(x, name="Deconv_1_Activation")

    for i in range(3):
        # Channels shrink number_of_atoms * 4 -> * 2 -> * 1
        out_channels = number_of_atoms * (2**(2 - i))
        x = lbann.Deconvolution(x,
                                num_dims=3,
                                num_output_channels=out_channels,
                                num_groups=1,
                                conv_dims_i=4,
                                conv_pads_i=1,
                                conv_strides_i=2,
                                conv_dilations_i=1,
                                has_bias=True,
                                name="Deconv_{0}".format(i + 2))
        x = lbann.BatchNormalization(x, name="BN_D{0}".format(i + 2))
        # Skip the activation on the last iteration: the final Tanh below
        # is named "decoded" because we want to dump the outputs.
        if i != 2:
            x = lbann.Tanh(x, name="Deconv_{0}_Activation".format(i + 2))

    decoded = lbann.Tanh(x, name="decoded")

    img_loss = lbann.MeanSquaredError([decoded, tensors])

    metrics = [lbann.Metric(img_loss, name='recon_error')]
    # ----------------------------------
    # Set up DAG
    # ----------------------------------

    layers = lbann.traverse_layer_graph(input_)  # Generate Model DAG
    return layers, img_loss, metrics
def forward(self, x, prev_state):
    """Apply one LSTM step.

    Args:
        x (Layer): Input.
        prev_state (tuple with two `Layer`s): State from the previous LSTM
            step, made up of the LSTM output and the cell state.

    Returns:
        (Layer, (Layer, Layer)): The output and the state (output and
        cell state). The state can be passed directly into the next LSTM
        step.

    """
    self.step += 1
    prefix = '{0}_step{1}'.format(self.name, self.step)
    layout = self.data_layout
    size = self.size
    last_output, last_cell = prev_state

    # One fully-connected layer over [input, previous output]
    joined = lbann.Concatenation(x, last_output,
                                 name=prefix + '_input',
                                 data_layout=layout)
    projected = self.fc(joined)

    # Split the FC output: the first `size` entries feed the cell update,
    # the remaining 3*size entries feed the gates.
    # NOTE: creation order matters; the Slice outputs are presumably handed
    # to its children in creation order, so do not reorder these layers.
    fc_parts = lbann.Slice(projected,
                           slice_points=str_list([0, size, 4*size]),
                           name=prefix + '_fc_slice',
                           data_layout=layout)
    cell_update = lbann.Tanh(fc_parts,
                             name=prefix + '_cell_update',
                             data_layout=layout)
    gate_act = lbann.Sigmoid(fc_parts,
                             name=prefix + '_sigmoid',
                             data_layout=layout)
    gate_parts = lbann.Slice(gate_act,
                             slice_points=str_list([0, size, 2*size, 3*size]),
                             name=prefix + '_sigmoid_slice',
                             data_layout=layout)
    forget_gate, input_gate, output_gate = (
        lbann.Identity(gate_parts, name=prefix + suffix, data_layout=layout)
        for suffix in ('_forget_gate', '_input_gate', '_output_gate')
    )

    # c_t = f * c_{t-1} + i * tanh(update)
    new_cell = lbann.Add(
        lbann.Multiply(forget_gate, last_cell,
                       name=prefix + '_cell_forget',
                       data_layout=layout),
        lbann.Multiply(input_gate, cell_update,
                       name=prefix + '_cell_input',
                       data_layout=layout),
        name=prefix + '_cell',
        data_layout=layout,
    )

    # h_t = o * tanh(c_t)
    squashed_cell = lbann.Tanh(new_cell,
                               name=prefix + '_cell_activation',
                               data_layout=layout)
    out = lbann.Multiply(output_gate, squashed_cell,
                         name=prefix,
                         data_layout=layout)

    return out, (out, new_cell)
def forward(self, x, prev_state):
    """Apply one GRU step.

    Computes, with ``h`` = ``prev_state``:

        r  = sigmoid(W_ir x + W_hr h)          (reset gate)
        z  = sigmoid(W_iz x + W_hz h)          (update gate)
        n  = tanh(W_in x + r * (W_hn h))       (new gate)
        h' = (ones - z) * n + z * h

    where the W_i* terms come from ``self.ih_fc(x)`` and the W_h* terms
    from ``self.hh_fc(prev_state)``, each split into three ``size``-wide
    chunks. ``self.ones`` is presumably a layer of ones.

    Args:
        x (Layer): Input.
        prev_state (Layer): State from the previous GRU step.

    Returns:
        (Layer, Layer): The output (out) and state (hn), which are the
        same layer. The state can be passed directly into the next GRU
        step.

    """
    self.step += 1
    prefix = '{0}_step{1}'.format(self.name, self.step)
    layout = self.data_layout
    size = self.size

    input_proj = self.ih_fc(x)
    hidden_proj = self.hh_fc(prev_state)

    def split3(proj, slice_suffix, part_suffixes):
        # Slice into [0, size), [size, 2*size), [2*size, 3*size) and wrap
        # each piece in a named Identity. The pieces are created in order,
        # which presumably determines which slice each one receives.
        pieces = lbann.Slice(proj,
                             slice_points=str_list([0, size, 2*size, 3*size]),
                             name=prefix + slice_suffix,
                             data_layout=layout)
        return tuple(
            lbann.Identity(pieces, name=prefix + suffix, data_layout=layout)
            for suffix in part_suffixes
        )

    x_r, x_z, x_n = split3(input_proj, '_fc1_slice',
                           ('_Wrx', '_Wzx', '_Wnx'))
    h_r, h_z, h_n = split3(hidden_proj, '_fc2_slice',
                           ('_Wrh', '_Wzh', '_Wnh'))

    reset = lbann.Sigmoid(lbann.Add(x_r, h_r, data_layout=layout),
                          name=prefix + '_reset_gate',
                          data_layout=layout)
    update = lbann.Sigmoid(lbann.Add(x_z, h_z, data_layout=layout),
                           name=prefix + '_update_gate',
                           data_layout=layout)
    gated_h_n = lbann.Multiply(reset, h_n, data_layout=layout)
    candidate = lbann.Tanh(lbann.Add(x_n, gated_h_n, data_layout=layout),
                           name=prefix + '_new_gate',
                           data_layout=layout)

    # h' = (ones - z) * n + z * h
    one_minus_update = lbann.WeightedSum(self.ones, update,
                                         scaling_factors='1 -1',
                                         data_layout=layout)
    fresh = lbann.Multiply(one_minus_update, candidate, data_layout=layout)
    kept = lbann.Multiply(update, prev_state, data_layout=layout)
    new_state = lbann.Add(fresh, kept,
                          name=prefix + '_output',
                          data_layout=layout)

    return new_state, new_state
def forward(self, x, prev_state):
    """Apply one channelwise GRU step.

    Each of the ``num_channels`` rows is an independent GRU state of width
    ``size``. With ``h`` = ``prev_state`` (reshaped to
    (num_channels, size)):

        r  = sigmoid(W_ir x + W_hr h)          (reset gate)
        z  = sigmoid(W_iz x + W_hz h)          (update gate)
        n  = tanh(W_in x + r * (W_hn h))       (new gate)
        h' = (ones - z) * n + z * h

    The W_i* terms come from ``self.ih_fc(x)`` and the W_h* terms from
    ``self.hh_fc(h)``, each sliced along axis 1 into three ``size``-wide
    chunks.

    Args:
        x (Layer): Input (shape: (num_channels, *)).
        prev_state (Layer): State from the previous GRU step
            (shape: (num_channels, size)).

    Returns:
        (Layer, Layer): The output (out) and state (hn), both the same
        layer of shape (num_channels, size). The state can be passed
        directly into the next GRU step.
    """
    self.step += 1
    name = f"{self.name}_step{self.step}"
    state_dims = str_list([self.num_channels, self.size])

    prev_state = lbann.Reshape(prev_state,
                               dims=state_dims,
                               name=name + "_prev_state_reshape")

    fc1 = self.ih_fc(x)
    fc2 = self.hh_fc(prev_state)

    def _split_gates(fc, suffixes):
        # Slice `fc` along axis 1 into three `size`-wide chunks and reshape
        # each to (num_channels, size). Chunks are created in order
        # (reset, update, new), which presumably determines which slice
        # each child receives.
        fc_slice = lbann.Slice(
            fc,
            axis=1,
            slice_points=str_list(
                [0, self.size, 2 * self.size, 3 * self.size]))
        return [
            lbann.Reshape(lbann.Identity(fc_slice),
                          dims=state_dims,
                          name=name + suffix)
            for suffix in suffixes
        ]

    x_r, x_z, x_n = _split_gates(fc1, ('_Wir_x', '_Wiz_z', '_Win_x'))
    h_r, h_z, h_n = _split_gates(fc2, ('_Whr_x', '_Whz_z', '_Whn_x'))

    rt = lbann.Sigmoid(
        lbann.Add(x_r, h_r, data_layout=self.data_layout),
        name=name + '_reset_gate',
        data_layout=self.data_layout,
    )
    zt = lbann.Sigmoid(
        lbann.Add(x_z, h_z, data_layout=self.data_layout),
        name=name + '_update_gate',
        data_layout=self.data_layout,
    )
    nt = lbann.Tanh(
        lbann.Add(
            x_n,
            lbann.Multiply(rt, h_n, data_layout=self.data_layout),
            data_layout=self.data_layout,
        ),
        name=name + '_new_gate',
        data_layout=self.data_layout,
    )

    # h' = (ones - z) * n + z * h
    ht = lbann.Add(
        lbann.Multiply(
            lbann.WeightedSum(self.ones, zt,
                              scaling_factors='1 -1',
                              data_layout=self.data_layout),
            nt,
            data_layout=self.data_layout,
        ),
        lbann.Multiply(zt, prev_state, data_layout=self.data_layout),
        name=name + '_output',
        data_layout=self.data_layout,
    )
    ht = lbann.Reshape(ht, dims=state_dims)
    return ht, ht