def forward(self, data, time_steps, run_backwards=True, mask=None):
    """Encode a batch of trajectories into a mean/std pair via the RNN.

    The sequence is fed to ``self.gru_rnn``; the output at the last step is
    projected by ``self.hiddens_to_z0`` and split into a mean and a std.

    Args:
        data: Tensor laid out as [n_traj, n_tp, n_dims]. When ``mask`` is
            not supplied, ``data`` is assumed to already carry the
            observation mask concatenated on its last dimension.
        time_steps: Observation times. Only used to build time deltas when
            ``self.use_delta_t`` is set (and stored in ``self.extra_info``).
            Presumably 1-D of length n_tp and shared by all trajectories
            -- TODO confirm against callers.
        run_backwards: If true, the sequence (and the time deltas) are
            reversed with ``utils.reverse`` before being fed to the RNN.
        mask: Optional tensor with the same layout as ``data``. When given,
            it is NaN-checked, permuted like ``data`` and concatenated onto
            ``data`` along the last dimension. Previously this name was
            read without ever being defined, which raised ``NameError``;
            it is now an optional argument defaulting to ``None``.

    Returns:
        Tuple ``(mean, std)``; each has a new leading dimension of size 1
        and ``std`` is the absolute value of the raw network output.

    Side effects:
        Sets ``self.extra_info`` to a dict holding the full RNN outputs
        (``"rnn_outputs"``) and ``time_steps`` (``"time_points"``).
    """
    n_traj = data.size(0)

    assert (not torch.isnan(data).any())
    assert (not torch.isnan(time_steps).any())
    if mask is not None:
        assert (not torch.isnan(mask).any())

    # Move the time axis first: (n_traj, n_tp, n_dims) -> (n_tp, n_traj, n_dims),
    # i.e. the (seq_len, batch, input_size) layout the RNN is fed with.
    data = data.permute(1, 0, 2)

    if mask is not None:
        mask = mask.permute(1, 0, 2)
        data = torch.cat((data, mask), -1)

    if run_backwards:
        # Look at data in the reverse order: from later points to the first
        data = utils.reverse(data)

    if self.use_delta_t:
        delta_t = time_steps[1:] - time_steps[:-1]
        if run_backwards:
            # We are going backwards in time, so the deltas are reversed too.
            delta_t = utils.reverse(delta_t)
        # Append a zero delta at the end so there is one delta per time point.
        delta_t = torch.cat((delta_t, torch.zeros(1).to(self.device)))
        # Broadcast the shared deltas to every trajectory: (n_tp, n_traj, 1).
        delta_t = delta_t.unsqueeze(1).repeat((1, n_traj)).unsqueeze(-1)
        data = torch.cat((delta_t, data), -1)

    outputs, _ = self.gru_rnn(data)

    # RNN output shape (presumably, for a sequence-first recurrent layer):
    # (seq_len, batch, num_directions * hidden_size)
    last_output = outputs[-1]

    self.extra_info = {"rnn_outputs": outputs, "time_points": time_steps}

    mean, std = utils.split_last_dim(self.hiddens_to_z0(last_output))
    # Keep the std non-negative.
    std = std.abs()

    assert (not torch.isnan(mean).any())
    assert (not torch.isnan(std).any())

    return mean.unsqueeze(0), std.unsqueeze(0)
def get_reconstruction(self, time_steps_to_predict, data, truth_time_steps, mask=None, n_traj_samples=1, mode=None):
    """Encode the observed sequence with an RNN, sample z0, and decode.

    Args:
        time_steps_to_predict: Times at which the decoder produces outputs.
        data: Observed values; dim 0 is the batch and dim 1 is matched
            against the number of time deltas.
        truth_time_steps: Times of the observations in ``data``.
        mask: Observation mask; required (asserted to be not ``None``).
        n_traj_samples: Not referenced in this method.
        mode: Not referenced in this method.

    Returns:
        Tuple ``(outputs, extra_info)``. ``outputs`` is the decoder output
        after ``utils.shift_outputs``; per the original note its shape is
        [n_traj_samples, n_traj, n_tp, n_dims] -- TODO confirm.
        ``extra_info`` holds ``"first_point"`` (z0 mean, std and sample,
        each with a new leading dimension) and, when
        ``self.use_binary_classif`` is set, ``"label_predictions"``.
    """
    assert (mask is not None)

    n_batch = data.shape[0]
    zero_delta_t = torch.Tensor([0.]).to(self.device)

    # The encoder runs backwards when prediction starts before the last
    # observed time point.
    run_backwards = bool(time_steps_to_predict[0] < truth_time_steps[-1])

    enc_deltas = truth_time_steps[1:] - truth_time_steps[:-1]
    if run_backwards:
        # Look at everything in the reverse order: from later points to the first.
        # NOTE(review): data is indexed as [batch, time, ...] below -- confirm
        # which axis utils.reverse flips for these tensors.
        data = utils.reverse(data)
        mask = utils.reverse(mask)
        enc_deltas = utils.reverse(enc_deltas)
    enc_deltas = torch.cat((enc_deltas, zero_delta_t))

    if enc_deltas.dim() == 1:
        # Deltas are shared by all trajectories in the batch.
        assert (data.size(1) == enc_deltas.size(0))
        enc_deltas = enc_deltas.unsqueeze(-1).repeat((n_batch, 1, 1))

    input_decay_params = (
        (self.w_input_decay, self.b_input_decay)
        if self.input_space_decay
        else None
    )

    hidden_state, _ = run_rnn(
        data,
        enc_deltas,
        cell=self.rnn_cell_enc,
        mask=mask,
        input_decay_params=input_decay_params,
    )

    z0_mean, z0_std = utils.split_last_dim(self.z0_net(hidden_state))
    z0_std = z0_std.abs()
    z0_sample = utils.sample_standard_gaussian(z0_mean, z0_std)

    # ---- Decoder ----
    pred_gaps = time_steps_to_predict[1:] - time_steps_to_predict[:-1]
    dec_deltas = torch.cat((zero_delta_t, pred_gaps))
    if dec_deltas.dim() == 1:
        dec_deltas = dec_deltas.unsqueeze(-1).repeat((n_batch, 1, 1))

    _, all_hiddens = run_rnn(
        data,
        dec_deltas,
        cell=self.rnn_cell_dec,
        first_hidden=z0_sample,
        feed_previous=True,
        n_steps=time_steps_to_predict.shape[0],
        decoder=self.decoder,
        input_decay_params=input_decay_params,
    )

    outputs = self.decoder(all_hiddens)
    # Shift outputs for computing the loss: the first output is compared
    # to the second data point, and so on.
    first_point = data[:, 0, :]
    outputs = utils.shift_outputs(outputs, first_point)

    extra_info = {
        "first_point": (
            z0_mean.unsqueeze(0),
            z0_std.unsqueeze(0),
            z0_sample.unsqueeze(0),
        )
    }

    if self.use_binary_classif:
        if self.classif_per_tp:
            label_predictions = self.classifier(all_hiddens)
        else:
            label_predictions = self.classifier(z0_mean).reshape(1, -1)
        extra_info["label_predictions"] = label_predictions

    return outputs, extra_info