def forward(self, data, time_steps, run_backwards=True, save_info=False):
    """Encode observed trajectories into the mean/std of the initial latent state z0.

    IMPORTANT: assumes that ``data`` already has the observation mask
    concatenated to it.

    Args:
        data: observations, unpacked as ``(n_traj, n_tp, n_dims)``.
        time_steps: time stamps of the observations.
        run_backwards: forwarded to ``self.run_odernn`` when there is more
            than one time point.
        save_info: forwarded to ``self.run_odernn``; when true, the extra info
            it returns (``None`` for a single time point) is stored on
            ``self.extra_info``.

    Returns:
        ``(mean_z0, std_z0)`` as produced by ``self.transform_z0`` and split
        along the last dimension; ``std_z0`` is made non-negative with ``abs()``.
    """
    assert (not torch.isnan(data).any())
    assert (not torch.isnan(time_steps).any())

    # Unpacking (rather than data.size(0)) keeps the implicit 3-D check.
    n_traj, _, _ = data.size()
    state_shape = (1, n_traj, self.latent_dim)

    if len(time_steps) != 1:
        last_yi, last_yi_std, _, extra_info = self.run_odernn(
            data, time_steps, run_backwards=run_backwards,
            save_info=save_info)
    else:
        # Only one observation: a single GRU update starting from a zero state.
        zero_mean = torch.zeros(state_shape).to(self.device)
        zero_std = torch.zeros(state_shape).to(self.device)
        first_obs = data[:, 0].unsqueeze(0)
        last_yi, last_yi_std = self.GRU_update(zero_mean, zero_std, first_obs)
        extra_info = None

    # Feed the final hidden mean and std jointly through the z0 transform.
    joint_state = torch.cat(
        (last_yi.reshape(state_shape), last_yi_std.reshape(state_shape)), -1)
    mean_z0, raw_std_z0 = utils.split_last_dim(self.transform_z0(joint_state))

    if save_info:
        self.extra_info = extra_info

    return mean_z0, raw_std_z0.abs()
def forward(self, data, time_steps, run_backwards=True, mask=None):
    """Encode trajectories with an RNN into the mean/std of the initial latent state z0.

    Previously this method referenced an undefined ``mask`` (it was assigned
    later in the body, so the first read raised ``UnboundLocalError`` on every
    call); ``mask`` is now an optional parameter.

    Args:
        data: observations, described by the original code as
            ``[n_traj, n_tp, n_dims]``; ``data.size(0)`` is taken as n_traj and
            the tensor is permuted with ``(1, 0, 2)`` into the time-major
            layout the RNN expects.
        time_steps: observation times; when ``self.use_delta_t`` is set their
            consecutive differences are prepended to every input vector.
        run_backwards: feed the sequence to the RNN in reversed order (from
            later points to the first).
        mask: optional observation mask with the same layout as ``data``. If
            given, it is NaN-checked, permuted like ``data`` and concatenated
            to it on the last dimension. If ``None``, ``data`` is assumed to
            already have the mask concatenated to it.

    Returns:
        ``(mean, std)`` from ``self.hiddens_to_z0`` applied to the last RNN
        output, split along the last dimension, each with a new leading
        dimension of size 1; ``std`` is made non-negative with ``abs()``.

    Side effects:
        Sets ``self.extra_info`` to the RNN outputs and the time points.
    """
    # t0: not used here
    n_traj = data.size(0)

    assert (not torch.isnan(data).any())
    assert (not torch.isnan(time_steps).any())
    if mask is not None:
        assert (not torch.isnan(mask).any())

    # [n_traj, n_tp, n_dims] -> (seq_len, batch, input_size) for the RNN.
    data = data.permute(1, 0, 2)
    if mask is not None:
        data = torch.cat((data, mask.permute(1, 0, 2)), -1)

    if run_backwards:
        # Look at data in the reverse order: from later points to the first
        data = utils.reverse(data)

    if self.use_delta_t:
        delta_t = time_steps[1:] - time_steps[:-1]
        if run_backwards:
            # we are going backwards in time
            delta_t = utils.reverse(delta_t)
        # append zero delta t in the end so there is one delta per time point
        delta_t = torch.cat((delta_t, torch.zeros(1).to(self.device)))
        # (n_tp,) -> (n_tp, n_traj, 1), matching the time-major data layout
        delta_t = delta_t.unsqueeze(1).repeat((1, n_traj)).unsqueeze(-1)
        data = torch.cat((delta_t, data), -1)

    # NOTE(review): assumes self.gru_rnn behaves like torch.nn.GRU, i.e.
    # outputs is (seq_len, batch, num_directions * hidden_size) — confirm.
    outputs, _ = self.gru_rnn(data)
    last_output = outputs[-1]

    self.extra_info = {"rnn_outputs": outputs, "time_points": time_steps}

    mean, std = utils.split_last_dim(self.hiddens_to_z0(last_output))
    std = std.abs()

    assert (not torch.isnan(mean).any())
    assert (not torch.isnan(std).any())

    return mean.unsqueeze(0), std.unsqueeze(0)
def forward(self, y_mean, y_std, x, masked_update=True):
    """Apply one GRU-style update to a latent state's mean and std given input x.

    Args:
        y_mean: current hidden-state mean.
        y_std: current hidden-state std.
        x: input at this time point. When ``masked_update`` is true it must
            hold the data in the first half of its last dimension and the
            observation mask in the second half; it is sliced as a 3-D tensor.
        masked_update: if true, only positions where at least one feature is
            observed (mask sum > 0) take the new mean; the others keep
            ``y_mean``.

    Returns:
        ``(new_y, new_y_std)``; ``new_y_std`` is passed through softplus, so it
        is strictly positive.

    Raises:
        RuntimeError: if the masked update produces NaNs. This previously
            printed the tensors and called ``exit()``, which terminated the
            whole process.
    """
    y_concat = torch.cat([y_mean, y_std, x], -1)

    update_gate = self.update_gate(y_concat)
    reset_gate = self.reset_gate(y_concat)
    concat = torch.cat([y_mean * reset_gate, y_std * reset_gate, x], -1)

    new_state, new_state_std = utils.split_last_dim(self.new_state_net(concat))
    new_state_std = new_state_std.abs()

    new_y = (1 - update_gate) * new_state + update_gate * y_mean
    new_y_std = F.softplus((1 - update_gate) * new_state_std + update_gate * y_std)

    assert (not torch.isnan(new_y).any())

    if masked_update:
        # IMPORTANT: assumes that x contains both data and mask.
        # Update the hidden state only if at least one feature is present
        # for the current time point.
        n_data_dims = x.size(-1) // 2
        mask = x[:, :, n_data_dims:]
        utils.check_mask(x[:, :, :n_data_dims], mask)

        # 1.0 where any feature is observed, 0.0 otherwise (keepdim for broadcasting).
        mask = (torch.sum(mask, -1, keepdim=True) > 0).float()
        assert (not torch.isnan(mask).any())

        new_y = mask * new_y + (1 - mask) * y_mean
        # NOTE(review): unlike new_y, new_y_std is NOT blended back to y_std
        # for unobserved points (that line was commented out in the original).
        # Confirm this is intended.

        if torch.isnan(new_y).any():
            raise RuntimeError(
                "new_y is nan!\n"
                f"mask: {mask}\n"
                f"y_mean: {y_mean}\n"
                f"new_y: {new_y}")

    return new_y, new_y_std
def get_reconstruction(self, time_steps_to_predict, data, truth_time_steps, mask=None, n_traj_samples=1, mode=None):
    """Encode observed data into z0 with an RNN, then decode outputs at the requested times.

    Args:
        time_steps_to_predict: time points to decode; the decoder runs for
            ``time_steps_to_predict.size(0)`` steps.
        data: observed values. ``data.size(0)`` is used as the batch size and
            ``data.size(1)`` must equal the number of observed time points.
        truth_time_steps: time stamps of the observations in ``data``.
        mask: observation mask for the encoder. It is required despite the
            ``None`` default (asserted below).
        n_traj_samples: not used in this method.
        mode: not used in this method.

    Returns:
        ``(outputs, extra_info)``.
        - ``outputs``: the decoder outputs after ``utils.shift_outputs``
          against the first data point.
        - ``extra_info["first_point"]``: ``(z0_mean, z0_std, z0_sample)``,
          each with a leading dimension of size 1 added.
        - ``extra_info["label_predictions"]``: present only when
          ``self.use_binary_classif`` is set.
    """
    assert (mask is not None)
    batch_size = data.size(0)
    zero_delta_t = torch.Tensor([0.]).to(self.device)

    # run encoder backwards when the first prediction time precedes the last observed time
    run_backwards = bool(time_steps_to_predict[0] < truth_time_steps[-1])

    if run_backwards:
        # Look at data in the reverse order: from later points to the first
        data = utils.reverse(data)
        mask = utils.reverse(mask)

    delta_ts = truth_time_steps[1:] - truth_time_steps[:-1]
    if run_backwards:
        # we are going backwards in time
        delta_ts = utils.reverse(delta_ts)

    # Pad with a trailing zero so there is one delta per observed time point.
    delta_ts = torch.cat((delta_ts, zero_delta_t))
    if len(delta_ts.size()) == 1:
        # delta_ts are shared for all trajectories in a batch
        assert (data.size(1) == delta_ts.size(0))
        # (n_tp,) -> (batch_size, n_tp, 1)
        delta_ts = delta_ts.unsqueeze(-1).repeat((batch_size, 1, 1))

    input_decay_params = None
    if self.input_space_decay:
        input_decay_params = (self.w_input_decay, self.b_input_decay)

    # Encoder pass; only the first returned value is used to parameterize z0.
    hidden_state, _ = run_rnn(data, delta_ts,
                              cell=self.rnn_cell_enc, mask=mask,
                              input_decay_params=input_decay_params)

    z0_mean, z0_std = utils.split_last_dim(self.z0_net(hidden_state))
    z0_std = z0_std.abs()
    z0_sample = utils.sample_standard_gaussian(z0_mean, z0_std)

    # Decoder
    # Deltas between consecutive prediction times, with a leading zero for the first step.
    delta_ts = torch.cat(
        (zero_delta_t, time_steps_to_predict[1:] - time_steps_to_predict[:-1]))
    if len(delta_ts.size()) == 1:
        delta_ts = delta_ts.unsqueeze(-1).repeat((batch_size, 1, 1))

    # Decoder pass starting from the sampled z0; feed_previous=True and the
    # decoder module are passed through to run_rnn (semantics defined there).
    _, all_hiddens = run_rnn(data, delta_ts,
                             cell=self.rnn_cell_dec,
                             first_hidden=z0_sample, feed_previous=True,
                             n_steps=time_steps_to_predict.size(0),
                             decoder=self.decoder,
                             input_decay_params=input_decay_params)

    outputs = self.decoder(all_hiddens)
    # Shift outputs for computing the loss -- we should compare the first output to the second data point, etc.
    # NOTE(review): `data` may have been reversed above, so this is the first
    # point in the (possibly reversed) order. Confirm this is intended.
    first_point = data[:, 0, :]
    outputs = utils.shift_outputs(outputs, first_point)

    extra_info = {
        "first_point": (z0_mean.unsqueeze(0), z0_std.unsqueeze(0), z0_sample.unsqueeze(0))
    }

    if self.use_binary_classif:
        if self.classif_per_tp:
            # One prediction per time point, computed from all decoder hidden states.
            extra_info["label_predictions"] = self.classifier(all_hiddens)
        else:
            # One prediction per trajectory, computed from the z0 mean.
            extra_info["label_predictions"] = self.classifier(
                z0_mean).reshape(1, -1)

    # outputs shape: [n_traj_samples, n_traj, n_tp, n_dims]
    return outputs, extra_info