def forward(self, x):
    """Perform LSTM step.

    State from previous steps is used to compute output. The cell state
    and output produced by this step are stored on ``self``
    (``self.last_cell`` / ``self.last_output``) so that the next call
    can read them.

    Args:
        x (Layer): Input for this step.

    Returns:
        Layer: Output of this step (also stored as ``self.last_output``).

    """
    self.step += 1
    name = '{0}_step{1}'.format(self.name, self.step)

    # Apply linearity to the concatenation [input, previous output]
    input_concat = lbann.Concatenation([x, self.last_output],
                                       name=name + '_input',
                                       data_layout=self.data_layout)
    fc = self.fc(input_concat)

    # Get gates and cell update.
    # slice_points [0, size, 4*size] split `fc` into a size-wide piece and
    # a 3*size-wide piece.
    # NOTE(review): consumers of an LBANN Slice appear to receive its
    # pieces in the order they are created (Tanh first, then Sigmoid;
    # forget, input, output gates below), so do not reorder these layers.
    fc_slice = lbann.Slice(fc,
                           slice_points=_str_list([0, self.size, 4*self.size]),
                           name=name + '_fc_slice',
                           data_layout=self.data_layout)
    cell_update = lbann.Tanh(fc_slice,
                             name=name + '_cell_update',
                             data_layout=self.data_layout)
    sigmoid = lbann.Sigmoid(fc_slice,
                            name=name + '_sigmoid',
                            data_layout=self.data_layout)
    gate_slice = lbann.Slice(sigmoid,
                             slice_points=_str_list([0, self.size,
                                                     2*self.size,
                                                     3*self.size]),
                             name=name + '_sigmoid_slice',
                             data_layout=self.data_layout)
    f = lbann.Identity(gate_slice, name=name + '_forget_gate',
                       data_layout=self.data_layout)
    i = lbann.Identity(gate_slice, name=name + '_input_gate',
                       data_layout=self.data_layout)
    o = lbann.Identity(gate_slice, name=name + '_output_gate',
                       data_layout=self.data_layout)

    # Cell state: forget * previous cell + input * cell update
    cell_forget = lbann.Multiply([f, self.last_cell],
                                 name=name + '_cell_forget',
                                 data_layout=self.data_layout)
    cell_input = lbann.Multiply([i, cell_update],
                                name=name + '_cell_input',
                                data_layout=self.data_layout)
    cell = lbann.Add([cell_forget, cell_input],
                     name=name + '_cell',
                     data_layout=self.data_layout)

    # Output: output gate * tanh(cell)
    cell_act = lbann.Tanh(cell,
                          name=name + '_cell_activation',
                          data_layout=self.data_layout)
    output = lbann.Multiply([o, cell_act],
                            name=name,
                            data_layout=self.data_layout)

    # Update state and return output
    self.last_cell = cell
    self.last_output = output
    return output
def Silu(x):
    """Build a SiLU activation, ``x * sigmoid(x)``, on top of layer ``x``."""
    gate = lbann.Sigmoid(x)
    return lbann.Multiply(x, gate)
# Tail of the decoder plus reconstruction losses.
# NOTE(review): this is a fragment; `decode3`, `encode1`, `image` and
# `input_` are defined before this span and are not visible here.
decode3neuron = lbann.Relu(decode3, name="decode3neuron")
# hint_layer presumably makes the FC output size match `encode1` -- confirm.
decode2 = lbann.FullyConnected(decode3neuron, name="decode2",
                               has_bias=True, hint_layer=encode1)
decode2neuron = lbann.Relu(decode2, name="decode2neuron")
# hint_layer presumably makes the output size match the input `image`.
decode1 = lbann.FullyConnected(decode2neuron, name="decode1",
                               has_bias=True, hint_layer=image)
# Reconstruction error
reconstruction = lbann.Sigmoid(decode1, name="reconstruction")
# The pre-sigmoid `decode1` (not `reconstruction`) is fed to the BCE layer;
# presumably SigmoidBinaryCrossEntropy applies the sigmoid itself.
bin_cross_entropy = lbann.SigmoidBinaryCrossEntropy([decode1, image],
                                                    name="bin_cross_entropy")
# Summed BCE; not used within this visible span (presumably used below).
bin_cross_entropy_sum = lbann.Reduction(bin_cross_entropy,
                                        name="bin_cross_entropy_sum",
                                        mode="sum")
# MSE is computed on the post-sigmoid reconstruction.
mean_squared_error = lbann.MeanSquaredError([reconstruction, image],
                                            name="mean_squared_error")
# Collect every layer reachable from the input layer.
layer_list = list(lbann.traverse_layer_graph(input_))
# Set up objective function
layer_term1 = lbann.LayerTerm(bin_cross_entropy)
# Decoder and objective for a model-parallel autoencoder with dropout.
# NOTE(review): fragment; `encode1` and `image` are defined before this span.
relu1 = lbann.Relu(encode1, name="relu1", data_layout="model_parallel")
# keep_prob is presumably the probability of keeping each entry.
dropout1 = lbann.Dropout(relu1, name="dropout1",
                         data_layout="model_parallel", keep_prob=0.8)
decode1 = lbann.FullyConnected(dropout1, name="decode1",
                               data_layout="model_parallel",
                               hint_layer=image, has_bias=True)
reconstruction = lbann.Sigmoid(decode1, name="reconstruction",
                               data_layout="model_parallel")
# NOTE(review): dropout is applied to the reconstruction itself, and the
# loss below uses this dropped-out tensor; confirm this is intended.
dropout2 = lbann.Dropout(reconstruction, name="dropout2",
                         data_layout="model_parallel", keep_prob=0.8)
# Reconstruction
mean_squared_error = lbann.MeanSquaredError([dropout2, image],
                                            name="mean_squared_error")
# The MSE is both the sole objective term and the reported metric.
layer_term = lbann.LayerTerm(mean_squared_error)
obj = lbann.ObjectiveFunction(layer_term)
metrics = [lbann.Metric(mean_squared_error, name=mean_squared_error.name)]
def forward(self, x, prev_state):
    """Apply LSTM step.

    Args:
        x (Layer): Input.
        prev_state (tuple with two `Layer`s): State from previous
            LSTM step. Comprised of LSTM output and cell state.

    Returns:
        (Layer, (Layer, Layer)): The output and state (the output
            and cell state). The state can be passed directly into
            the next LSTM step.

    """
    self.step += 1
    name = '{0}_step{1}'.format(self.name, self.step)

    # Get output and cell state from previous step
    prev_output, prev_cell = prev_state

    # Apply linearity to the concatenation [input, previous output]
    input_concat = lbann.Concatenation(x, prev_output,
                                       name=name + '_input',
                                       data_layout=self.data_layout)
    fc = self.fc(input_concat)

    # Get gates and cell update.
    # slice_points [0, size, 4*size] split `fc` into a size-wide piece and
    # a 3*size-wide piece.
    # NOTE(review): consumers of an LBANN Slice appear to receive its
    # pieces in the order they are created (Tanh first, then Sigmoid;
    # forget, input, output gates below), so do not reorder these layers.
    fc_slice = lbann.Slice(fc,
                           slice_points=str_list([0, self.size, 4*self.size]),
                           name=name + '_fc_slice',
                           data_layout=self.data_layout)
    cell_update = lbann.Tanh(fc_slice,
                             name=name + '_cell_update',
                             data_layout=self.data_layout)
    sigmoid = lbann.Sigmoid(fc_slice,
                            name=name + '_sigmoid',
                            data_layout=self.data_layout)
    gate_slice = lbann.Slice(sigmoid,
                             slice_points=str_list([0, self.size,
                                                    2*self.size,
                                                    3*self.size]),
                             name=name + '_sigmoid_slice',
                             data_layout=self.data_layout)
    f = lbann.Identity(gate_slice, name=name + '_forget_gate',
                       data_layout=self.data_layout)
    i = lbann.Identity(gate_slice, name=name + '_input_gate',
                       data_layout=self.data_layout)
    o = lbann.Identity(gate_slice, name=name + '_output_gate',
                       data_layout=self.data_layout)

    # Cell state: forget * previous cell + input * cell update
    cell_forget = lbann.Multiply(f, prev_cell,
                                 name=name + '_cell_forget',
                                 data_layout=self.data_layout)
    cell_input = lbann.Multiply(i, cell_update,
                                name=name + '_cell_input',
                                data_layout=self.data_layout)
    cell = lbann.Add(cell_forget, cell_input,
                     name=name + '_cell',
                     data_layout=self.data_layout)

    # Output: output gate * tanh(cell)
    cell_act = lbann.Tanh(cell,
                          name=name + '_cell_activation',
                          data_layout=self.data_layout)
    output = lbann.Multiply(o, cell_act,
                            name=name,
                            data_layout=self.data_layout)

    # Return output and state
    return output, (output, cell)
def forward(self, x, prev_state):
    """Run a single GRU time step.

    Args:
        x (Layer): Input for this step.
        prev_state (Layer): Hidden state produced by the previous GRU step.

    Returns:
        (Layer, Layer): The step output and the new hidden state. Both are
        the same layer, so the second element can be passed straight back
        in as ``prev_state`` for the next step.

    """
    self.step += 1
    name = '{0}_step{1}'.format(self.name, self.step)
    layout = self.data_layout

    # Affine maps of the input and of the previous hidden state
    input_fc = self.ih_fc(x)
    hidden_fc = self.hh_fc(prev_state)

    # Split each affine result into reset / update / new pieces.
    # NOTE(review): each Identity appears to take the next piece of its
    # Slice in creation order, so the suffix order below must be kept.
    boundaries = str_list([0, self.size, 2*self.size, 3*self.size])
    input_slice = lbann.Slice(input_fc, slice_points=boundaries,
                              name=name + '_fc1_slice', data_layout=layout)
    Wir_x, Wiz_x, Win_x = (
        lbann.Identity(input_slice, name=name + suffix, data_layout=layout)
        for suffix in ('_Wrx', '_Wzx', '_Wnx')
    )
    hidden_slice = lbann.Slice(hidden_fc, slice_points=boundaries,
                               name=name + '_fc2_slice', data_layout=layout)
    Whr_prev, Whz_prev, Whn_prev = (
        lbann.Identity(hidden_slice, name=name + suffix, data_layout=layout)
        for suffix in ('_Wrh', '_Wzh', '_Wnh')
    )

    # Reset gate: sigmoid(input part + hidden part)
    reset_sum = lbann.Add(Wir_x, Whr_prev, data_layout=layout)
    rt = lbann.Sigmoid(reset_sum, name=name + '_reset_gate',
                       data_layout=layout)

    # Update gate: sigmoid(input part + hidden part)
    update_sum = lbann.Add(Wiz_x, Whz_prev, data_layout=layout)
    zt = lbann.Sigmoid(update_sum, name=name + '_update_gate',
                       data_layout=layout)

    # New gate: tanh(input part + reset * hidden part)
    reset_hidden = lbann.Multiply(rt, Whn_prev, data_layout=layout)
    new_sum = lbann.Add(Win_x, reset_hidden, data_layout=layout)
    nt = lbann.Tanh(new_sum, name=name + '_new_gate', data_layout=layout)

    # Hidden state: (ones - zt) * nt + zt * prev_state.
    # NOTE(review): assumes self.ones is a tensor of ones, making the
    # WeightedSum equal to (1 - zt) -- confirm against the constructor.
    keep_new = lbann.WeightedSum(self.ones, zt, scaling_factors='1 -1',
                                 data_layout=layout)
    new_contrib = lbann.Multiply(keep_new, nt, data_layout=layout)
    old_contrib = lbann.Multiply(zt, prev_state, data_layout=layout)
    ht = lbann.Add(new_contrib, old_contrib, name=name+ '_output',
                   data_layout=layout)

    # Output and state are the same layer
    return ht, ht
def forward(self, x, prev_state):
    """Apply one GRU step channelwise.

    Args:
        x (Layer): Input (shape: (num_channels, *)).
        prev_state (Layer): State from previous GRU step
            (shape: (num_channels, size)).

    Returns:
        (Layer, Layer): The output (out) and state (hn). The state can
        be passed directly into the next GRU step.

    """
    self.step += 1
    name = f"{self.name}_step{self.step}"

    # Per-gate tensors and the state are all reshaped to (num_channels, size).
    dims = str_list([self.num_channels, self.size])
    slice_points = str_list([0, self.size, 2 * self.size, 3 * self.size])

    prev_state = lbann.Reshape(prev_state, dims=dims,
                               name=name + "_prev_state_reshape")

    fc1 = self.ih_fc(x)
    fc2 = self.hh_fc(prev_state)

    def _gate_piece(slice_layer, suffix):
        # Take the next piece of `slice_layer` and reshape it to `dims`.
        return lbann.Reshape(lbann.Identity(slice_layer), dims=dims,
                             name=name + suffix)

    # Split the affine outputs along axis 1 into reset / update / new
    # pieces (presumably each fc output is (num_channels, 3*size)).
    # NOTE(review): each Identity appears to consume the next piece of its
    # Slice in creation order, so the order below must be preserved.
    fc1_slice = lbann.Slice(fc1, axis=1, slice_points=slice_points,
                            name=name + '_fc1_slice')
    Wir_x = _gate_piece(fc1_slice, '_Wir_x')
    Wiz_z = _gate_piece(fc1_slice, '_Wiz_z')
    Win_x = _gate_piece(fc1_slice, '_Win_x')

    fc2_slice = lbann.Slice(fc2, axis=1, slice_points=slice_points,
                            name=name + '_fc2_slice')
    Whr_x = _gate_piece(fc2_slice, '_Whr_x')
    Whz_z = _gate_piece(fc2_slice, '_Whz_z')
    Whn_x = _gate_piece(fc2_slice, '_Whn_x')

    # Reset gate
    rt = lbann.Sigmoid(
        lbann.Add(Wir_x, Whr_x, data_layout=self.data_layout),
        name=name + '_reset_gate',
        data_layout=self.data_layout,
    )
    # Update gate
    zt = lbann.Sigmoid(
        lbann.Add(Wiz_z, Whz_z, data_layout=self.data_layout),
        name=name + '_update_gate',
        data_layout=self.data_layout,
    )
    # New gate: tanh(input part + reset * hidden part)
    nt = lbann.Tanh(
        lbann.Add(
            Win_x,
            lbann.Multiply(rt, Whn_x, data_layout=self.data_layout),
            data_layout=self.data_layout,
        ),
        name=name + '_new_gate',
        data_layout=self.data_layout,
    )
    # Hidden state: (ones - zt) * nt + zt * prev_state.
    # NOTE(review): assumes self.ones holds ones so the WeightedSum is
    # (1 - zt) -- confirm against the constructor.
    ht = lbann.Add(
        lbann.Multiply(
            lbann.WeightedSum(
                self.ones,
                zt,
                scaling_factors='1 -1',
                data_layout=self.data_layout,
            ),
            nt,
            data_layout=self.data_layout,
        ),
        lbann.Multiply(zt, prev_state, data_layout=self.data_layout),
        name=name + '_output',
        data_layout=self.data_layout,
    )
    ht = lbann.Reshape(ht, dims=dims)
    return ht, ht