def forward(self, input_batch):
    """Encode a batch of sequences into z_0 and latent-parameter statistics.

    Args:
        input_batch: tensor whose dims from 2 onward are flattened into one
            feature axis before the MLP (dims 2 and 3 for the usual 4-D
            input). Dim 0 is used as the batch size of the RNN initial
            hidden state. NOTE(review): presumably (batch, time, H, W) —
            confirm against callers.

    Returns:
        Tuple ``(z_0_loc, z_0_log_var, latent_params_loc,
        latent_params_log_var)``. The first pair comes from ``self.rnn`` run
        over the reversed feature sequence; the second pair comes from
        ``self.lstm`` run over the un-reversed features.
    """
    batch_size = input_batch.size(0)

    # Create data batch for the RNN: flatten trailing dims into one feature
    # axis. Unlike .view, flatten also accepts non-contiguous inputs (it
    # copies only when the strides require it).
    out = input_batch.flatten(start_dim=2)

    # MLP feature extractor; layers 2 and 3 add their output to their input.
    out = self.relu(self.first_layer(out))
    out = out + self.relu(self.second_layer(out))
    out = out + self.relu(self.third_layer(out))
    out = self.relu(self.fourth_layer(out))

    # RNN consumes the batch backwards to create z0.
    reversed_out = utils.reverse_sequences_torch(out)
    # Explicit zero initial state with self.rnn_layers layers. new_zeros
    # inherits dtype AND device from `out`; torch.zeros would default to
    # float32 and fail against e.g. a float64 model.
    # NOTE(review): assumes self.rnn is a GRU/RNN (single tensor hidden
    # state) and unidirectional — confirm.
    h0 = out.new_zeros(self.rnn_layers, batch_size, self.rnn.hidden_size)
    rnn_output, _ = self.rnn(reversed_out, h0)
    # Output at the last index of dim 1 (assumes a batch_first RNN — confirm).
    rnn_last = rnn_output[:, -1]
    z_0_loc = self.rnn_to_z0_loc(rnn_last)
    z_0_log_var = self.rnn_to_z0_log_var(rnn_last)

    # LSTM creates params from the features in their original order.
    lstm_all_output, _ = self.lstm(out)
    lstm_last = lstm_all_output[:, -1]
    latent_params_loc = self.lstm_to_latent_loc(lstm_last)
    latent_params_log_var = self.lstm_to_latent_log_var(lstm_last)

    return z_0_loc, z_0_log_var, latent_params_loc, latent_params_log_var
def forward(self, mini_batch):
    """Encode ``mini_batch`` into the (loc, log_var) pair for z_0.

    The batch is embedded by ``self.input_to_rnn_net`` and passed through
    ``utils.reverse_sequences_torch`` before ``self.rnn`` reads it. Only the
    recurrent output at index -1 of dim 1 feeds the two output heads.
    NOTE(review): that indexing implies a batch_first RNN — confirm.

    Returns:
        ``(z_0_loc, z_0_log_var)`` from ``rnn_to_latent_loc`` and
        ``rnn_to_latent_log_var`` respectively.
    """
    embedded = self.input_to_rnn_net(mini_batch)
    states, _ = self.rnn(utils.reverse_sequences_torch(embedded))
    final_state = states[:, -1]
    return (
        self.rnn_to_latent_loc(final_state),
        self.rnn_to_latent_log_var(final_state),
    )
def forward(self, mini_batch):
    """Encode a batch of sequences into the (loc, log_var) pair for z_0.

    Args:
        mini_batch: tensor whose dims from 2 onward are flattened into one
            feature axis (dims 2 and 3 for the usual 4-D input).
            NOTE(review): presumably (batch, time, H, W) — confirm.

    Returns:
        ``(z_0_loc, z_0_log_var)`` produced by ``rnn_to_latent_loc`` and
        ``rnn_to_latent_log_var`` from the RNN output at the last index of
        dim 1, after the RNN has read the reversed feature sequence.
    """
    # Flatten trailing dims into one feature axis. Unlike .view, flatten also
    # accepts non-contiguous inputs (it copies only when required).
    features = mini_batch.flatten(start_dim=2)

    # MLP feature extractor; layers 2 and 3 add their output to their input.
    features = self.relu(self.first_layer(features))
    features = features + self.relu(self.second_layer(features))
    features = features + self.relu(self.third_layer(features))
    features = self.relu(self.fourth_layer(features))

    reversed_features = utils.reverse_sequences_torch(features)
    rnn_output, _ = self.rnn(reversed_features)
    # Output at the last index of dim 1 (assumes a batch_first RNN — confirm).
    rnn_last = rnn_output[:, -1]
    z_0_loc = self.rnn_to_latent_loc(rnn_last)
    z_0_log_var = self.rnn_to_latent_log_var(rnn_last)
    return z_0_loc, z_0_log_var
def forward(self, mini_batch):
    """Encode ``mini_batch`` into z_0 and latent-parameter statistics.

    The batch is embedded once by ``self.input_to_rnn_net``. The embedding is
    then read twice: by ``self.rnn`` after ``utils.reverse_sequences_torch``,
    and by ``self.lstm`` in its given order. From each recurrent output only
    index -1 of dim 1 is used. NOTE(review): this implies batch_first
    modules — confirm.

    Returns:
        ``(latent_z_0_loc, latent_z_0_log_var, latent_params_loc,
        latent_params_log_var)``.
    """
    embedded = self.input_to_rnn_net(mini_batch)

    # Reversed-order pass -> z_0 heads.
    reversed_states, _ = self.rnn(utils.reverse_sequences_torch(embedded))
    reversed_final = reversed_states[:, -1]
    z0_loc = self.rnn_to_latent_loc(reversed_final)
    z0_log_var = self.rnn_to_latent_log_var(reversed_final)

    # Original-order pass -> parameter heads.
    ordered_states, _ = self.lstm(embedded)
    ordered_final = ordered_states[:, -1]
    params_loc = self.lstm_to_latent_loc(ordered_final)
    params_log_var = self.lstm_to_latent_log_var(ordered_final)

    return z0_loc, z0_log_var, params_loc, params_log_var