def run_odernn(self, data, time_steps, run_backwards=True, save_info=False):
    """Run an ODE-RNN over the observations and return the final latent state.

    Between consecutive observation times the latent state is propagated with
    ``self.z0_diffeq_solver`` (a single linear step through its ``ode_func``
    when the gap is shorter than ``minimum_step``). At every observation it is
    then updated with ``self.GRU_update``.

    IMPORTANT: assumes that ``data`` already has the mask concatenated to it.

    Args:
        data: tensor unpacked as ``(n_traj, n_tp, n_dims)``; ``data[:, i, :]``
            is the observation at ``time_steps[i]``.
        time_steps: 1-D tensor of observation times, presumably sorted in
            increasing order (TODO confirm with callers).
        run_backwards: if True (default), sweep from the last observation to
            the first; otherwise sweep from the first to the last.
        save_info: if True, append a dict of detached per-step tensors to
            ``extra_info``.

    Returns:
        Tuple ``(yi, yi_std, latent_ys, extra_info)``:
        - ``yi`` and ``yi_std`` as returned by the final ``GRU_update``;
        - the per-step ``yi`` stacked along dim 1;
        - the list of per-step diagnostic dicts (empty unless ``save_info``).

    Raises:
        RuntimeError: if an ODE solution does not start at the current latent
            state.
    """
    n_traj, _, _ = data.size()
    n_tp = len(time_steps)
    extra_info = []

    device = get_device(data)
    prev_y = torch.zeros((1, n_traj, self.latent_dim)).to(device)
    prev_std = torch.zeros((1, n_traj, self.latent_dim)).to(device)

    # The first segment is a short 0.01 lead-in that ends at the first
    # observation visited, so the sweep starts with a near-zero-length step.
    if run_backwards:
        prev_t, t_i = time_steps[-1] + 0.01, time_steps[-1]
    else:
        prev_t, t_i = time_steps[0] - 0.01, time_steps[0]

    interval_length = time_steps[-1] - time_steps[0]
    minimum_step = interval_length / 50

    assert not torch.isnan(data).any()
    assert not torch.isnan(time_steps).any()

    latent_ys = []
    time_points_iter = range(n_tp)
    if run_backwards:
        time_points_iter = reversed(time_points_iter)

    for i in time_points_iter:
        # Gap length is direction-independent: prev_t > t_i when sweeping
        # backwards, prev_t < t_i when sweeping forwards.
        gap = torch.abs(prev_t - t_i)
        if gap < minimum_step:
            # Short gap: one linear step along the ODE gradient.
            time_points = torch.stack((prev_t, t_i))
            inc = self.z0_diffeq_solver.ode_func(prev_t, prev_y) * (t_i - prev_t)
            assert not torch.isnan(inc).any()
            ode_sol = prev_y + inc
            ode_sol = torch.stack((prev_y, ode_sol), 2).to(device)
            assert not torch.isnan(ode_sol).any()
        else:
            n_intermediate_tp = max(2, (gap / minimum_step).int())
            time_points = utils.linspace_vector(prev_t, t_i, n_intermediate_tp)
            ode_sol = self.z0_diffeq_solver(prev_y, time_points)
            assert not torch.isnan(ode_sol).any()

        # Sanity check that the trajectory starts at prev_y. abs() keeps
        # negative or cancelling deviations from hiding inside the mean.
        first_point_error = torch.mean(torch.abs(ode_sol[:, :, 0, :] - prev_y))
        if first_point_error >= 0.001:
            raise RuntimeError(
                "Error: first point of the ODE is not equal to initial value "
                "(mean abs deviation {})".format(first_point_error.item()))

        yi_ode = ode_sol[:, :, -1, :]
        xi = data[:, i, :].unsqueeze(0)

        yi, yi_std = self.GRU_update(yi_ode, prev_std, xi)
        prev_y, prev_std = yi, yi_std

        # Move to the next gap in the sweep direction.
        if run_backwards:
            # At i == 0, time_steps[i - 1] wraps to the last element; that
            # value is never used because the loop ends.
            prev_t, t_i = time_steps[i], time_steps[i - 1]
        else:
            prev_t = time_steps[i]
            if i + 1 < n_tp:
                t_i = time_steps[i + 1]

        latent_ys.append(yi)

        if save_info:
            extra_info.append({
                "yi_ode": yi_ode.detach(),
                "yi": yi.detach(),
                "yi_std": yi_std.detach(),
                "time_points": time_points.detach(),
                "ode_sol": ode_sol.detach(),
            })

    latent_ys = torch.stack(latent_ys, 1)

    assert not torch.isnan(yi).any()
    assert not torch.isnan(yi_std).any()

    return yi, yi_std, latent_ys, extra_info
def run_odernn(self, data, time_steps, run_backwards=True, save_info=False, testing=False):
    """Run an ODE-RNN (``self.use_ODE``) or a plain RNN over the observations.

    With ``self.use_ODE`` the latent state is propagated between observation
    times by ``self.z0_diffeq_solver`` (a single linear step through its
    ``ode_func`` when the gap is shorter than ``minimum_step``). It is then
    updated with ``self.RNN_update``.

    Without the ODE, the RNN input is rebuilt from the first
    ``self.input_dim // 2`` columns of ``data``, extended with either:
    - the elapsed time, or
    - a positional encoding (``self.use_pos_encod``).
    Column ``self.input_dim // 2`` is used as the observation mask. For an
    LSTM cell (``self.RNNcell == 'lstm'``) the latent state holds the hidden
    and cell state concatenated; only the hidden state is emitted as output.

    IMPORTANT: assumes that ``data`` already has the mask concatenated to it.

    Args:
        data: tensor unpacked as ``(n_traj, n_tp, n_dims)``.
        time_steps: 1-D tensor of observation times, indexed like ``data[:, i]``.
        run_backwards: if True (default), iterate the indices in reverse.
        save_info: if True, record per-step diagnostics in ``extra_info``.
        testing: if True, use a finer integration grid and keep diagnostics
            for the first 10 trajectories. Without ``testing``, the recorded
            per-trajectory tensors are sliced to length 0.

    Returns:
        Tuple ``(yi, yi_std, latent_ys, extra_info)``:
        - ``yi`` and ``yi_std``: the final latent state and its std;
        - ``latent_ys``: the per-step outputs stacked along dim 1 (optionally
          batch-normalised by ``self.output_bn``);
        - ``extra_info``: the list of diagnostic dicts.

    Raises:
        RuntimeError: if an ODE solution does not start at the current latent
            state.
    """
    n_traj, _, _ = data.size()
    extra_info = []
    # Number of trajectories kept in the diagnostics; 0 also keeps the default
    # (coarse) integration grid.
    save_latents = 10 if testing else 0
    device = get_device(data)

    # Initialize the hidden state with small Gaussian noise.
    if self.RNNcell == 'lstm':
        half_latent = self.latent_dim // 2
        prev_h = torch.zeros((1, n_traj, half_latent)).data.normal_(0, 0.0001).to(device)
        prev_h_std = torch.zeros((1, n_traj, half_latent)).data.normal_(0, 0.0001).to(device)
        ci = torch.zeros((1, n_traj, half_latent)).data.normal_(0, 0.0001).to(device)
        ci_std = torch.zeros((1, n_traj, half_latent)).data.normal_(0, 0.0001).to(device)
        # The latent state carries [hidden state, cell state] concatenated.
        prev_y = torch.cat([prev_h, ci], -1)
        prev_std = torch.cat([prev_h_std, ci_std], -1)
    else:
        prev_y = torch.zeros((1, n_traj, self.latent_dim)).data.normal_(0, 0.0001).to(device)
        prev_std = torch.zeros((1, n_traj, self.latent_dim)).data.normal_(0, 0.0001).to(device)

    # t_i holds the time of the previously processed observation and prev_t
    # the current one (note the inverted naming). Start just before time_steps[0].
    t_i = time_steps[0] - 0.00001
    interval_length = time_steps[-1] - time_steps[0]
    minimum_step = interval_length / 200  # may need tuning

    assert not torch.isnan(data).any()
    assert not torch.isnan(time_steps).any()

    latent_ys = []

    time_points_iter = range(0, len(time_steps))
    if run_backwards:
        time_points_iter = reversed(time_points_iter)

    # Sinusoidal encoding of the observation times (used when use_pos_encod).
    position_encodings = utils.get_sinusoid_encoding_table(
        time_steps.cpu().numpy(), d_hid=2)

    # Experimental switch: integrate over the reversed time grid.
    experimental = False

    for i in time_points_iter:
        prev_t = time_steps[i]

        n_intermediate_tp = self.n_intermediate_tp
        if save_latents != 0:
            # Finer grid between observations when testing (minimum is 2).
            n_intermediate_tp = max(2, ((prev_t - t_i) / minimum_step).int())
        time_points = utils.linspace_vector(prev_t, t_i, n_intermediate_tp)
        if experimental:
            time_points = time_points.flip(0)

        if self.use_ODE:
            if abs(prev_t - t_i) < minimum_step:
                # Short gap: one linear step along the ODE gradient.
                time_points = torch.stack((prev_t, t_i))
                inc = self.z0_diffeq_solver.ode_func(prev_t, prev_y) * (t_i - prev_t)
                assert not torch.isnan(inc).any()
                ode_sol = prev_y + inc
                ode_sol = torch.stack((prev_y, ode_sol), 2).to(device)
                assert not torch.isnan(ode_sol).any()
            else:
                # Full integration with the differential equation solver.
                ode_sol = self.z0_diffeq_solver(prev_y, time_points)
                assert not torch.isnan(ode_sol).any()

            # Sanity check that the trajectory starts at prev_y; abs() keeps
            # negative or cancelling deviations from hiding inside the mean.
            first_point_error = torch.mean(torch.abs(ode_sol[:, :, 0, :] - prev_y))
            if first_point_error >= 0.001:
                raise RuntimeError(
                    "Error: first point of the ODE is not equal to initial value "
                    "(mean abs deviation {})".format(first_point_error.item()))

            yi_ode = ode_sol[:, :, -1, :]
            xi = data[:, i, :].unsqueeze(0)
        else:
            # No ODE: carry the state over unchanged and feed the elapsed time
            # (or a positional encoding) to the RNN instead.
            yi_ode = prev_y
            time_points = time_points[-1]

            n_features = self.input_dim // 2
            # Observation mask of the current (single) time step.
            single_mask = data[:, i, n_features]
            delta_ts = (prev_t - t_i).repeat(1, n_traj, 1).float()
            delta_ts[:, ~single_mask.bool(), :] = 0
            if self.nornnimputation:
                delta_ts[:, :, :] = 0

            features = data[:, i, :n_features].unsqueeze(0)
            if not self.use_pos_encod:
                new_mask = single_mask.unsqueeze(0).unsqueeze(2).repeat(1, 1, n_features + 1)
                xi = torch.cat([features, delta_ts, new_mask], -1)
            else:
                pos_encod = position_encodings[i].repeat(1, n_traj, 1).float()
                pos_encod[:, ~single_mask.bool(), :] = 0
                new_mask = single_mask.unsqueeze(0).unsqueeze(2).repeat(1, 1, n_features + 2)
                xi = torch.cat([features, pos_encod, new_mask], -1)

        if self.RNNcell == 'lstm':
            # Split the latent into hidden and cell state for the LSTM update.
            h_i_ode = yi_ode[:, :, :self.latent_dim // 2]
            c_i_ode = yi_ode[:, :, self.latent_dim // 2:]
            outi, yi_std = self.RNN_update((h_i_ode, c_i_ode), prev_std, xi)
            h_i_, c_i_ = outi[0], outi[1]
            yi = torch.cat([h_i_, c_i_], -1)
            # Only the hidden state is exposed as the latent output.
            yi_out = h_i_
            if not self.use_ODE:
                ode_sol = yi_out.unsqueeze(2)
                time_points = time_points.unsqueeze(0)
        else:
            # GRU or other RNN cell: the output is directly the hidden state.
            yi_ode, prev_std = self.RNN_update(yi_ode, prev_std, xi)
            yi, yi_std = yi_ode, prev_std
            yi_out = yi
            if not self.use_ODE:
                ode_sol = yi_ode.unsqueeze(2)
                time_points = time_points.unsqueeze(0)

        prev_y, prev_std = yi, yi_std
        t_i = time_steps[i]

        latent_ys.append(yi_out)

        if save_info or save_latents:
            if self.use_ODE:
                # zero: RNN-update, one: ODE-update
                ODE_flags = (xi[:, :, self.latent_dim:].sum((0, 2)) == 0).cpu().detach().int().numpy()
                marker = np.ones((n_traj, n_intermediate_tp))
                marker[:, -1] = ODE_flags
            else:
                # zero: RNN-update, two: no update at all
                marker = ((xi[:, :, (self.latent_dim + 1):].sum((0, 2)) == 0)
                          .cpu().detach().int().numpy() * 2)[:, np.newaxis]

            # Previously this dict was built and then discarded, so the
            # diagnostics never reached the caller.
            extra_info.append({
                "yi_ode": yi_ode[:, :save_latents].cpu().detach(),
                "yi": yi_out[:, :save_latents].cpu().detach(),
                "yi_std": yi_std[:, :save_latents].cpu().detach(),
                "time_points": time_points.cpu().detach().double(),
                "ode_sol": ode_sol[:, :save_latents].cpu().detach().double(),
                "marker": marker[:save_latents],
            })

    latent_ys = torch.stack(latent_ys, 1)

    # Batch-normalize the outputs.
    if self.use_BN:
        # Experimental: BN only over observed time steps (not faster because
        # the selected data is non-contiguous).
        fancy_BN = False
        if fancy_BN:
            obs_mask = data[:, :, self.input_dim // 2].permute(1, 0)
            latent_ys[:, obs_mask.bool()] = self.output_bn(
                latent_ys[:, obs_mask.bool()].permute(0, 2, 1)).permute(0, 2, 1)
        else:
            # squeeze(0), not squeeze(): drop only the leading singleton so a
            # single trajectory or a single time step keeps its axis.
            latent_ys = self.output_bn(
                latent_ys.squeeze(0).permute(0, 2, 1)).permute(0, 2, 1).unsqueeze(0)

    assert not torch.isnan(yi).any()
    assert not torch.isnan(yi_std).any()

    return yi, yi_std, latent_ys, extra_info
def draw_all_plots_one_dim(self, data_dict, model, plot_name="", save=False, experimentID=0.):
    """Draw the diagnostic figure for a one-dimensional dataset.

    Panels drawn on the axes already stored on ``self``:
    - observations, the mean reconstruction and its std band for the first
      few trajectories;
    - samples from the prior (``LatentODE`` models only);
    - a slice of the latent vector field;
    - the latent trajectories of the first posterior sample.

    Args:
        data_dict: dict providing ``"tp_to_predict"``, ``"observed_data"``,
            ``"observed_tp"`` and ``"observed_mask"``.
        model: model exposing ``get_reconstruction``, ``diffeq_solver`` and
            ``latent_dim`` (and ``sample_traj_from_prior`` for ``LatentODE``).
        plot_name: file name used when ``save`` is True.
        save: if True, write the figure to ``plots/<experimentID>/<plot_name>``.
        experimentID: experiment identifier used in the output directory.

    Note:
        For ``LatentODE`` models this reseeds the global torch and numpy RNGs
        (seed 1991) so the prior samples are reproducible.
    """
    time_steps = data_dict["tp_to_predict"]
    observed_data = data_dict["observed_data"]
    observed_time_steps = data_dict["observed_tp"]
    observed_mask = data_dict["observed_mask"]

    device = get_device(time_steps)

    time_steps_to_predict = time_steps
    if isinstance(model, LatentODE):
        # Sample 100 evenly spaced points over the prediction interval.
        time_steps_to_predict = utils.linspace_vector(
            time_steps[0], time_steps[-1], 100).to(device)

    reconstructions, info = model.get_reconstruction(
        time_steps_to_predict, observed_data, observed_time_steps,
        mask=observed_mask, n_traj_samples=10)

    # Plot only the first few trajectories.
    n_traj_to_show = 3
    data_for_plotting = observed_data[:n_traj_to_show]
    mask_for_plotting = observed_mask[:n_traj_to_show]
    reconstructions_for_plotting = reconstructions.mean(dim=0)[:n_traj_to_show]
    reconstr_std = reconstructions.std(dim=0)[:n_traj_to_show]

    dim_to_show = 0
    # reconstructions has a leading sample axis (it is averaged over dim 0
    # above), so the feature must be selected on the LAST axis; [:, :, d]
    # would pick a time index instead.
    observed_vals = data_for_plotting[:, :, dim_to_show].cpu().numpy()
    reconstructed_vals = reconstructions[..., dim_to_show].cpu().numpy()
    max_y = max(observed_vals.max(), reconstructed_vals.max())
    min_y = min(observed_vals.min(), reconstructed_vals.min())

    ############################################
    # Plot reconstructions with their spread over posterior samples.
    # plt.get_cmap works on all matplotlib versions; plt.cm.get_cmap was
    # removed in matplotlib 3.9.
    cmap = plt.get_cmap('Set1')
    for traj_id in range(min(n_traj_to_show, data_for_plotting.size(0))):
        # Observations
        plot_trajectories(
            self.ax_traj[traj_id],
            data_for_plotting[traj_id].unsqueeze(0), observed_time_steps,
            mask=mask_for_plotting[traj_id].unsqueeze(0),
            min_y=min_y, max_y=max_y,
            marker='o', linestyle='', dim_to_show=dim_to_show,
            color=cmap(2))
        # Mean reconstruction
        plot_trajectories(
            self.ax_traj[traj_id],
            reconstructions_for_plotting[traj_id].unsqueeze(0), time_steps_to_predict,
            min_y=min_y, max_y=max_y,
            title="Sample {} (data space)".format(traj_id), dim_to_show=dim_to_show,
            add_to_plot=True, marker='', color=cmap(3), linewidth=3)
        # Spread estimated over multiple samples from the approximate posterior
        plot_std(self.ax_traj[traj_id],
                 reconstructions_for_plotting[traj_id].unsqueeze(0),
                 reconstr_std[traj_id].unsqueeze(0),
                 time_steps_to_predict, alpha=0.5, color=cmap(3))
        self.set_plot_lims(self.ax_traj[traj_id], "traj_" + str(traj_id))

    ############################################
    # Plot trajectories from the prior
    if isinstance(model, LatentODE):
        torch.manual_seed(1991)
        np.random.seed(1991)

        traj_from_prior = model.sample_traj_from_prior(time_steps_to_predict, n_traj_samples=3)
        # Here n_traj = 1 and n_traj_samples is the number of prior samples,
        # so squeeze out the n_traj dimension.
        traj_from_prior = traj_from_prior.squeeze(1)

        plot_trajectories(self.ax_traj_from_prior, traj_from_prior, time_steps_to_predict,
                          marker='', linewidth=3)
        self.ax_traj_from_prior.set_title("Samples from prior (data space)", pad=20)

    ################################################
    # Show vector field
    self.ax_vector_field.cla()
    plot_vector_field(self.ax_vector_field, model.diffeq_solver.ode_func, model.latent_dim, device)
    self.ax_vector_field.set_title("Slice of vector field (latent space)", pad=20)
    self.set_plot_lims(self.ax_vector_field, "vector_field")

    ################################################
    # Plot trajectories in the latent space, using only the first sample
    # from the approximate posterior.
    latent_traj = info["latent_traj"][0, :n_traj_to_show]

    self.ax_latent_traj.cla()
    cmap = plt.get_cmap('Accent')
    n_latent_dims = latent_traj.size(-1)

    custom_labels = {}
    for i in range(n_latent_dims):
        col = cmap(i)
        plot_trajectories(self.ax_latent_traj, latent_traj, time_steps_to_predict,
                          title="Latent trajectories z(t) (latent space)", dim_to_show=i,
                          color=col, marker='', add_to_plot=True, linewidth=3)
        custom_labels['dim ' + str(i)] = Line2D([0], [0], color=col)

    self.ax_latent_traj.set_ylabel("z")
    self.ax_latent_traj.set_title("Latent trajectories z(t) (latent space)", pad=20)
    self.ax_latent_traj.legend(custom_labels.values(), custom_labels.keys(), loc='lower left')
    self.set_plot_lims(self.ax_latent_traj, "latent_traj")

    ################################################
    self.fig.tight_layout()
    plt.draw()

    if save:
        dirname = "plots/" + str(experimentID) + "/"
        os.makedirs(dirname, exist_ok=True)
        self.fig.savefig(dirname + plot_name)
def run_odernn(self, data, time_steps, minimum_step=None, run_backwards=False, save_info=False):
    """Run the ODE-RNN forward in time, integrating auxiliary states alongside.

    Between consecutive observation times, ``self.diffeq_solver`` propagates
    both the latent state and a tuple of auxiliary states (named A, C, K, V
    below). When the gap is shorter than ``minimum_step``, a single linear
    step through ``self.diffeq_solver.ode_func`` is used instead. At every
    observation the latent is updated with ``self.GRU_update``.

    IMPORTANT: assumes that ``data`` already has the mask concatenated to it.

    Args:
        data: tensor unpacked as ``(n_traj, n_tp, n_dims)``; ``data[:, i, :]``
            is the observation at ``time_steps[i]``.
        time_steps: 1-D tensor of observation times.
        minimum_step: gaps shorter than this use the linear step; longer gaps
            are integrated on a grid of about ``gap / minimum_step`` points.
            Defaults to 1/50 of the observed time interval.
        run_backwards: currently ignored; the sweep always runs forward in time.
        save_info: if True, append ``{"time_points": ...}`` per step to
            ``extra_info``.

    Returns:
        Tuple ``(yi, latent_ys, context_vector, extra_info)``:
        - ``yi``: the final latent state;
        - ``latent_ys``: the per-step latents stacked along dim 1;
        - ``context_vector``: the per-step ratio of the recorded auxiliary
          states, stacked along dim 1;
        - ``extra_info``: the diagnostics list.

    Raises:
        RuntimeError: if an ODE solution does not start at the current latent
            state.
    """
    n_traj, _, _ = data.size()
    n_tp = len(time_steps)
    extra_info = []
    device = self.device

    prev_y = torch.zeros((1, n_traj, self.latent_dim)).to(device)
    # Auxiliary states integrated together with the latent state.
    init_condition = (
        torch.zeros((n_traj, 1)).to(device),                                  # A
        torch.zeros((n_traj, self.latent_dim)).to(device),                    # C
        torch.zeros((n_traj, self.latent_dim)).to(device),                    # K
        torch.zeros((n_traj, self.latent_dim, self.latent_dim)).to(device),   # V
    )

    # Start with a short 0.01 lead-in segment ending at the first observation.
    prev_t, t_i = time_steps[0] - 0.01, time_steps[0]

    # Honour a caller-supplied minimum_step; previously it was always
    # overwritten with the default below.
    if minimum_step is None:
        interval_length = time_steps[-1] - time_steps[0]
        minimum_step = interval_length / 50

    assert not torch.isnan(data).any()
    assert not torch.isnan(time_steps).any()

    latent_ys = []
    record_condition = []

    for i in range(n_tp):
        dt = t_i - prev_t
        if dt < minimum_step:
            # Short gap: one linear step for the latent and all auxiliary states.
            time_points = torch.stack((prev_t, t_i))
            # squeeze(0): drop only the leading singleton so n_traj == 1 or
            # latent_dim == 1 keep their axes.
            derivs = self.diffeq_solver.ode_func(prev_t, (prev_y.squeeze(0),) + init_condition)
            ode_sol = prev_y + derivs[0].unsqueeze(0) * dt
            ode_sol = torch.stack((prev_y, ode_sol), 2).to(device)
            init_condition = tuple(
                state + grad * dt for state, grad in zip(init_condition, derivs[1:]))
            assert not torch.isnan(ode_sol).any()
        else:
            n_intermediate_tp = max(2, (dt / minimum_step).int())
            time_points = utils.linspace_vector(prev_t, t_i, n_intermediate_tp)
            tuple_sol = self.diffeq_solver(prev_y, time_points, init_condition=init_condition)
            ode_sol = tuple_sol[0]
            # Keep the last time point of each auxiliary trajectory.
            init_condition = tuple(traj[-1] for traj in tuple_sol[1:])
            assert not torch.isnan(ode_sol).any()

        # Sanity check that the trajectory starts at prev_y; abs() keeps
        # negative or cancelling deviations from hiding inside the mean.
        first_point_error = torch.mean(torch.abs(ode_sol[:, :, 0, :] - prev_y))
        if first_point_error >= 0.001:
            raise RuntimeError(
                "Error: first point of the ODE is not equal to initial value "
                "(mean abs deviation {})".format(first_point_error.item()))

        yi_ode = ode_sol[:, :, -1, :]
        xi = data[:, i, :].unsqueeze(0)

        yi = self.GRU_update(yi_ode, xi)
        prev_y = yi

        if i + 1 < n_tp:
            prev_t, t_i = time_steps[i], time_steps[i + 1]
        else:
            prev_t = time_steps[i]

        latent_ys.append(yi)
        # NOTE(review): this records (C, K). The original shape comment
        # "[n_tp, ((n_traj, 1), (n_traj, n_dims))]" matches (A, C), i.e.
        # init_condition[0:2], instead -- confirm which pair is intended.
        record_condition.append(init_condition[1:3])

        if save_info:
            # NOTE(review): the original also stored "att_score" from an
            # undefined name (sum_att_score), which raised NameError;
            # re-add it once that score is actually computed here.
            extra_info.append({"time_points": time_points.detach().cpu().numpy()})

    latent_ys = torch.stack(latent_ys, 1)

    assert not torch.isnan(yi).any()

    # Element-wise ratio second/first of each recorded pair, stacked along
    # dim 1 -> (n_traj, n_tp, n_dims).
    context_vector = torch.stack([second / first for first, second in record_condition], 1)
    return yi, latent_ys, context_vector, extra_info