示例#1
0
    def run_odernn(self,
                   data,
                   time_steps,
                   run_backwards=True,
                   save_info=False):
        """Encode ``data`` with an ODE-RNN, walking the time grid backwards.

        Between consecutive observation times the latent state is propagated
        with ``self.z0_diffeq_solver``. When the gap is shorter than
        ``minimum_step``, a single Euler step with its ``ode_func`` is used
        instead. The state is then updated with ``self.GRU_update`` using the
        observation at that time index.

        IMPORTANT: assumes that ``data`` already has the mask concatenated
        to it.

        Args:
            data: tensor unpacked as ``(n_traj, n_tp, n_dims)``.
            time_steps: 1-D tensor of observation times, indexed in parallel
                with ``data[:, i, :]``.
            run_backwards: iterate from the last time index to the first.
                NOTE(review): the initial ``prev_t``/``t_i`` and the update
                ``t_i = time_steps[i - 1]`` only make sense for the backward
                pass; forward mode looks unsupported -- confirm before use.
            save_info: if True, collect per-step diagnostics in
                ``extra_info``.

        Returns:
            ``(yi, yi_std, latent_ys, extra_info)``. These are the final
            latent mean and std, the per-step latents stacked along dim 1,
            and the list of diagnostic dicts (empty unless ``save_info``).

        Raises:
            RuntimeError: if an ODE solution does not start at the previous
                latent state.
        """
        n_traj, n_tp, n_dims = data.size()
        extra_info = []

        device = get_device(data)

        prev_y = torch.zeros((1, n_traj, self.latent_dim)).to(device)
        prev_std = torch.zeros((1, n_traj, self.latent_dim)).to(device)

        # The first step integrates from just after the last observation
        # (time_steps[-1] + 0.01) down to time_steps[-1].
        prev_t, t_i = time_steps[-1] + 0.01, time_steps[-1]

        interval_length = time_steps[-1] - time_steps[0]
        # Gaps shorter than this use one Euler step instead of the solver.
        minimum_step = interval_length / 50

        assert (not torch.isnan(data).any())
        assert (not torch.isnan(time_steps).any())

        latent_ys = []
        # Run ODE backwards and combine the y(t) estimates using gating
        time_points_iter = range(0, len(time_steps))
        if run_backwards:
            time_points_iter = reversed(time_points_iter)

        for i in time_points_iter:
            if (prev_t - t_i) < minimum_step:
                # Short gap: single explicit Euler step with the ODE gradient.
                time_points = torch.stack((prev_t, t_i))
                inc = self.z0_diffeq_solver.ode_func(prev_t,
                                                     prev_y) * (t_i - prev_t)

                assert (not torch.isnan(inc).any())

                ode_sol = prev_y + inc
                ode_sol = torch.stack((prev_y, ode_sol), 2).to(device)

                assert (not torch.isnan(ode_sol).any())
            else:
                n_intermediate_tp = max(2,
                                        ((prev_t - t_i) / minimum_step).int())

                time_points = utils.linspace_vector(prev_t, t_i,
                                                    n_intermediate_tp)
                ode_sol = self.z0_diffeq_solver(prev_y, time_points)

                assert (not torch.isnan(ode_sol).any())

            # The trajectory must start at the previous latent state. Use the
            # absolute deviation so that negative drift is detected as well.
            start_error = torch.mean(torch.abs(ode_sol[:, :, 0, :] - prev_y))
            if start_error >= 0.001:
                raise RuntimeError(
                    "Error: first point of the ODE is not equal to initial "
                    "value (mean abs deviation {})".format(start_error.item()))

            yi_ode = ode_sol[:, :, -1, :]
            xi = data[:, i, :].unsqueeze(0)

            yi, yi_std = self.GRU_update(yi_ode, prev_std, xi)

            prev_y, prev_std = yi, yi_std
            prev_t, t_i = time_steps[i], time_steps[i - 1]

            latent_ys.append(yi)

            if save_info:
                d = {
                    "yi_ode": yi_ode.detach(),
                    "yi": yi.detach(),
                    "yi_std": yi_std.detach(),
                    "time_points": time_points.detach(),
                    "ode_sol": ode_sol.detach()
                }
                extra_info.append(d)

        latent_ys = torch.stack(latent_ys, 1)

        assert (not torch.isnan(yi).any())
        assert (not torch.isnan(yi_std).any())

        return yi, yi_std, latent_ys, extra_info
示例#2
0
    def run_odernn(self,
                   data,
                   time_steps,
                   run_backwards=True,
                   save_info=False,
                   testing=False):
        """Encode ``data`` with an ODE-RNN (or plain RNN) over ``time_steps``.

        For every time index the hidden state is first carried from the
        previous time ``t_i`` to the current time ``prev_t``. When
        ``self.use_ODE`` is set this uses ``self.z0_diffeq_solver``, or one
        Euler step for gaps shorter than ``minimum_step``. Otherwise the
        state is passed through unchanged, and the time gap (or a positional
        encoding) is appended to the input instead. The state is then
        updated with ``self.RNN_update``. For ``self.RNNcell == 'lstm'`` the
        hidden and cell states are stored concatenated along the last axis.

        IMPORTANT: assumes that ``data`` already has the mask concatenated
        to it.

        Args:
            data: tensor unpacked as ``(n_traj, n_tp, n_dims)``.
            time_steps: 1-D tensor of observation times, indexed in parallel
                with ``data[:, i, :]``.
            run_backwards: iterate over the time indices in reverse order.
                NOTE(review): ``t_i`` is initialised just before
                ``time_steps[0]``, which only fits a forward pass -- confirm.
            save_info: collect per-step diagnostics in ``extra_info``.
            testing: also collect diagnostics, restricted to the first 10
                trajectories. The ODE step count is then derived from
                ``minimum_step`` instead of ``self.n_intermediate_tp``.

        Returns:
            ``(yi, yi_std, latent_ys, extra_info)``. These are the last state
            and its std, the per-step outputs stacked along dim 1 (batch
            normalised when ``self.use_BN``), and the list of diagnostic
            dicts.

        Raises:
            RuntimeError: if an ODE solution does not start at the previous
                state.
        """
        n_traj, n_tp, n_dims = data.size()
        extra_info = []
        save_latents = 10 if testing else 0
        # Trajectory slice used for diagnostics. When testing, keep only the
        # first ``save_latents`` trajectories; otherwise keep all of them,
        # since slicing with ``:0`` would keep none.
        keep = save_latents if save_latents else None

        device = get_device(data)

        # Initialize the hidden state with noise
        if self.RNNcell == 'lstm':
            # make some noise
            prev_h = torch.zeros(
                (1, n_traj,
                 self.latent_dim // 2)).data.normal_(0, 0.0001).to(device)
            prev_h_std = torch.zeros(
                (1, n_traj,
                 self.latent_dim // 2)).data.normal_(0, 0.0001).to(device)

            ci = torch.zeros(
                (1, n_traj,
                 self.latent_dim // 2)).data.normal_(0, 0.0001).to(device)
            ci_std = torch.zeros(
                (1, n_traj,
                 self.latent_dim // 2)).data.normal_(0, 0.0001).to(device)

            # concatenate hidden state and cell state
            prev_y = torch.cat([prev_h, ci], -1)
            prev_std = torch.cat([prev_h_std, ci_std], -1)
        else:
            # make some noise
            prev_y = torch.zeros(
                (1, n_traj, self.latent_dim)).data.normal_(0,
                                                           0.0001).to(device)
            prev_std = torch.zeros(
                (1, n_traj, self.latent_dim)).data.normal_(0,
                                                           0.0001).to(device)

        # Previous time: start just before the first observation.
        t_i = time_steps[0] - 0.00001

        interval_length = time_steps[-1] - time_steps[0]
        # Gaps shorter than this use one Euler step instead of the solver.
        minimum_step = interval_length / 200

        assert (not torch.isnan(data).any())
        assert (not torch.isnan(time_steps).any())

        latent_ys = []

        time_points_iter = range(0, len(time_steps))
        if run_backwards:
            time_points_iter = reversed(time_points_iter)

        # Get positional encoding
        position_encodings = utils.get_sinusoid_encoding_table(
            time_steps.cpu().numpy(), d_hid=2)

        # Experimental switch: flip the integration grid.
        experimental = False

        for i in time_points_iter:
            # Current observation time; the step runs between t_i and prev_t.
            prev_t = time_steps[i]

            n_intermediate_tp = self.n_intermediate_tp  # minimum is 2
            if save_latents != 0:
                # get more steps in between for testing
                n_intermediate_tp = max(
                    2, ((prev_t - t_i) / minimum_step).int())

            # NOTE(review): the grid starts at the current time prev_t and
            # ends at the previous time t_i, while ode_sol[:, :, 0] is
            # checked against the previous state -- confirm this ordering.
            time_points = utils.linspace_vector(prev_t, t_i, n_intermediate_tp)

            if experimental:
                time_points = time_points.flip(0)

            if self.use_ODE:

                if abs(prev_t - t_i) < minimum_step:
                    # short integration, linear approximation with the gradient
                    time_points = torch.stack((prev_t, t_i))
                    inc = self.z0_diffeq_solver.ode_func(
                        prev_t, prev_y) * (t_i - prev_t)

                    assert (not torch.isnan(inc).any())

                    ode_sol = prev_y + inc
                    ode_sol = torch.stack((prev_y, ode_sol), 2).to(device)

                    assert (not torch.isnan(ode_sol).any())

                else:
                    # complete integration using the differential equation solver
                    ode_sol = self.z0_diffeq_solver(prev_y, time_points)

                    assert (not torch.isnan(ode_sol).any())

                # The trajectory must start at the previous state. Use the
                # absolute deviation so that negative drift is caught too.
                start_error = torch.mean(
                    torch.abs(ode_sol[:, :, 0, :] - prev_y))
                if start_error >= 0.001:
                    raise RuntimeError(
                        "Error: first point of the ODE is not equal to "
                        "initial value (mean abs deviation {})".format(
                            start_error.item()))

                yi_ode = ode_sol[:, :, -1, :]

                xi = data[:, i, :].unsqueeze(0)

            else:

                # skipping ODE function and assign directly
                yi_ode = prev_y
                time_points = time_points[-1]

                # extract the mask for the current (single) time step
                single_mask = data[:, i, self.input_dim // 2]
                delta_ts = (prev_t - t_i).repeat(1, n_traj, 1).float()
                delta_ts[:, ~single_mask.bool(), :] = 0

                if self.nornnimputation:
                    delta_ts[:, :, :] = 0

                features = data[:, i, :self.input_dim // 2].unsqueeze(0)

                # Build the RNN input: features, then delta t (or positional
                # encoding), then a repeated mask.
                if not self.use_pos_encod:
                    new_mask = single_mask.unsqueeze(0).unsqueeze(2).repeat(
                        1, 1, self.input_dim // 2 + 1)
                    xi = torch.cat([features, delta_ts, new_mask], -1)
                else:
                    pos_encod = position_encodings[i].repeat(1, n_traj,
                                                             1).float()
                    pos_encod[:, ~single_mask.bool(), :] = 0
                    new_mask = single_mask.unsqueeze(0).unsqueeze(2).repeat(
                        1, 1, self.input_dim // 2 + 2)
                    xi = torch.cat([features, pos_encod, new_mask], -1)

            if self.RNNcell == 'lstm':
                # Split the concatenated state into hidden and cell parts.
                h_i_ode = yi_ode[:, :, :self.latent_dim // 2]
                c_i_ode = yi_ode[:, :, self.latent_dim // 2:]
                h_c_lstm = (h_i_ode, c_i_ode)

                # actually this is a LSTM update here:
                outi, yi_std = self.RNN_update(h_c_lstm, prev_std, xi)
                # outi is (h, c); only h is exposed as the output latent.
                h_i_, c_i_ = outi[0], outi[1]
                yi = torch.cat([h_i_, c_i_], -1)
                yi_out = h_i_

                if not self.use_ODE:
                    ode_sol = yi_out.unsqueeze(2)
                    time_points = time_points.unsqueeze(0)
            else:
                # GRU-unit or any other RNN cell: the output is directly the hidden state
                yi_ode, prev_std = self.RNN_update(yi_ode, prev_std, xi)
                yi, yi_std = yi_ode, prev_std
                yi_out = yi

                if not self.use_ODE:
                    ode_sol = yi_ode.unsqueeze(2)
                    time_points = time_points.unsqueeze(0)

            prev_y, prev_std = yi, yi_std
            t_i = time_steps[i]

            latent_ys.append(yi_out)
            if save_info or save_latents:
                if self.use_ODE:
                    # ODE-RNN case. Marker 0: RNN update, 1: ODE-only step.
                    # NOTE(review): ``xi[:, :, self.latent_dim:]`` is assumed
                    # to select the mask part of the input; confirm the offset.
                    # The marker width is n_intermediate_tp even when the
                    # Euler branch produced only 2 time points.
                    ODE_flags = (xi[:, :, self.latent_dim:].sum(
                        (0, 2)) == 0).cpu().detach().int().numpy()
                    marker = np.ones((n_traj, n_intermediate_tp))
                    marker[:, -1] = ODE_flags
                else:
                    # RNN case. Marker 0: RNN update, 2: no update at all.
                    marker = (
                        (xi[:, :, (self.latent_dim + 1):].sum(
                            (0, 2)) == 0).cpu().detach().int().numpy() * 2
                    )[:, np.newaxis]

                d = {
                    "yi_ode": yi_ode[:, :keep].cpu().detach(),
                    "yi": yi_out[:, :keep].cpu().detach(),
                    "yi_std": yi_std[:, :keep].cpu().detach(),
                    "time_points": time_points.cpu().detach().double(),
                    "ode_sol": ode_sol[:, :keep].cpu().detach().double(),
                    "marker": marker[:keep]
                }
                extra_info.append(d)

        latent_ys = torch.stack(latent_ys, 1)

        # BatchNormalization for the outputs
        if self.use_BN:

            # Experimental: selective BN of outputs at observed times only
            fancy_BN = False
            if fancy_BN:
                # not faster due to non-contiguous data
                obs_mask = data[:, :, self.input_dim // 2].permute(1, 0)
                latent_ys[:, obs_mask.bool()] = self.output_bn(
                    latent_ys[:, obs_mask.bool()].permute(0, 2,
                                                          1)).permute(0, 2, 1)
            else:
                # squeeze(0) drops only the leading axis added by stacking, so
                # a single trajectory or time step does not collapse a dim.
                latent_ys = self.output_bn(latent_ys.squeeze(0).permute(
                    0, 2, 1)).permute(0, 2, 1).unsqueeze(0)

        assert (not torch.isnan(yi).any())
        assert (not torch.isnan(yi_std).any())

        return yi, yi_std, latent_ys, extra_info
示例#3
0
    def draw_all_plots_one_dim(self,
                               data_dict,
                               model,
                               plot_name="",
                               save=False,
                               experimentID=0.):
        """Redraw every panel of the figure for a one-dimensional dataset.

        Panels drawn, in order:

        * observations plus reconstructions (mean +/- std over 10 posterior
          samples) for the first few trajectories;
        * samples from the prior (``LatentODE`` only);
        * a slice of the latent vector field;
        * the latent trajectories of the first posterior sample.

        Args:
            data_dict: batch dict. The keys read are "tp_to_predict",
                "observed_data", "observed_tp" and "observed_mask".
            model: model exposing ``get_reconstruction``,
                ``diffeq_solver.ode_func`` and ``latent_dim``.
            plot_name: file name used when ``save`` is True.
            save: write the figure to ``plots/<experimentID>/<plot_name>``.
            experimentID: sub-directory name under ``plots/``.
        """
        time_steps = data_dict["tp_to_predict"]

        observed_data = data_dict["observed_data"]
        observed_time_steps = data_dict["observed_tp"]
        observed_mask = data_dict["observed_mask"]

        device = get_device(time_steps)

        time_steps_to_predict = time_steps
        if isinstance(model, LatentODE):
            # Query the latent ODE on a dense 100-point grid spanning the
            # prediction interval.
            time_steps_to_predict = utils.linspace_vector(
                time_steps[0], time_steps[-1], 100).to(device)

        reconstructions, info = model.get_reconstruction(time_steps_to_predict,
                                                         observed_data,
                                                         observed_time_steps,
                                                         mask=observed_mask,
                                                         n_traj_samples=10)

        # Plot at most this many trajectories (one row of self.ax_traj each).
        n_traj_to_show = 3
        data_for_plotting = observed_data[:n_traj_to_show]
        mask_for_plotting = observed_mask[:n_traj_to_show]
        # Mean / std over the posterior-sample axis (dim 0).
        reconstructions_for_plotting = reconstructions.mean(
            dim=0)[:n_traj_to_show]
        reconstr_std = reconstructions.std(dim=0)[:n_traj_to_show]
        n_rows = min(n_traj_to_show, data_for_plotting.size(0))

        dim_to_show = 0
        # ``reconstructions`` has an extra leading sample axis, so the data
        # dimension is selected on the last axis with ``...``.
        max_y = max(data_for_plotting[..., dim_to_show].max().item(),
                    reconstructions[..., dim_to_show].max().item())
        min_y = min(data_for_plotting[..., dim_to_show].min().item(),
                    reconstructions[..., dim_to_show].min().item())

        ############################################
        # Plot observations and reconstructions

        cmap = plt.get_cmap('Set1')
        for traj_id in range(n_rows):
            # Plot observations
            plot_trajectories(
                self.ax_traj[traj_id],
                data_for_plotting[traj_id].unsqueeze(0),
                observed_time_steps,
                mask=mask_for_plotting[traj_id].unsqueeze(0),
                min_y=min_y,
                max_y=max_y,
                marker='o',
                linestyle='',
                dim_to_show=dim_to_show,
                color=cmap(2))
            # Plot reconstructions
            plot_trajectories(
                self.ax_traj[traj_id],
                reconstructions_for_plotting[traj_id].unsqueeze(0),
                time_steps_to_predict,
                min_y=min_y,
                max_y=max_y,
                title="Sample {} (data space)".format(traj_id),
                dim_to_show=dim_to_show,
                add_to_plot=True,
                marker='',
                color=cmap(3),
                linewidth=3)
            # Plot variance estimated over multiple samples from approx posterior
            plot_std(self.ax_traj[traj_id],
                     reconstructions_for_plotting[traj_id].unsqueeze(0),
                     reconstr_std[traj_id].unsqueeze(0),
                     time_steps_to_predict,
                     alpha=0.5,
                     color=cmap(3))
            self.set_plot_lims(self.ax_traj[traj_id], "traj_" + str(traj_id))

        ############################################
        # Plot trajectories from prior

        if isinstance(model, LatentODE):
            # Fixed seeds so prior samples are comparable between redraws.
            torch.manual_seed(1991)
            np.random.seed(1991)

            traj_from_prior = model.sample_traj_from_prior(
                time_steps_to_predict, n_traj_samples=3)
            # Since in this case n_traj = 1, n_traj_samples -- requested number of samples from the prior, squeeze n_traj dimension
            traj_from_prior = traj_from_prior.squeeze(1)

            plot_trajectories(self.ax_traj_from_prior,
                              traj_from_prior,
                              time_steps_to_predict,
                              marker='',
                              linewidth=3)
            self.ax_traj_from_prior.set_title(
                "Samples from prior (data space)", pad=20)

        ################################################
        # Show vector field
        self.ax_vector_field.cla()
        plot_vector_field(self.ax_vector_field, model.diffeq_solver.ode_func,
                          model.latent_dim, device)
        self.ax_vector_field.set_title("Slice of vector field (latent space)",
                                       pad=20)
        self.set_plot_lims(self.ax_vector_field, "vector_field")

        ################################################
        # Plot trajectories in the latent space

        # Take only the first sample from the approximate posterior.
        latent_traj = info["latent_traj"][0, :n_traj_to_show]

        self.ax_latent_traj.cla()
        cmap = plt.get_cmap('Accent')
        n_latent_dims = latent_traj.size(-1)

        custom_labels = {}
        for i in range(n_latent_dims):
            col = cmap(i)
            plot_trajectories(self.ax_latent_traj,
                              latent_traj,
                              time_steps_to_predict,
                              title="Latent trajectories z(t) (latent space)",
                              dim_to_show=i,
                              color=col,
                              marker='',
                              add_to_plot=True,
                              linewidth=3)
            custom_labels['dim ' + str(i)] = Line2D([0], [0], color=col)

        self.ax_latent_traj.set_ylabel("z")
        self.ax_latent_traj.set_title(
            "Latent trajectories z(t) (latent space)", pad=20)
        self.ax_latent_traj.legend(custom_labels.values(),
                                   custom_labels.keys(),
                                   loc='lower left')
        self.set_plot_lims(self.ax_latent_traj, "latent_traj")

        ################################################

        self.fig.tight_layout()
        plt.draw()

        if save:
            dirname = os.path.join("plots", str(experimentID))
            os.makedirs(dirname, exist_ok=True)
            self.fig.savefig(os.path.join(dirname, plot_name))

    def run_odernn(self, data, time_steps, minimum_step=None,
                   run_backwards=False, save_info=False):
        """Run a forward ODE-RNN pass that also integrates auxiliary states.

        Alongside the latent state, four auxiliary tensors (labelled A, C, K
        and V below) are integrated by ``self.diffeq_solver``. After each
        step the latent state is updated with ``self.GRU_update``.

        IMPORTANT: assumes that ``data`` already has the mask concatenated
        to it.

        Args:
            data: tensor unpacked as ``(n_traj, n_tp, n_dims)``.
            time_steps: 1-D tensor of observation times, indexed in parallel
                with ``data[:, i, :]``.
            minimum_step: gaps shorter than this use one Euler step instead
                of the solver. Defaults to 1/50 of the observed interval.
            run_backwards: accepted for interface compatibility; ignored
                (the pass always runs forward).
            save_info: if True, record each step's integration grid in
                ``extra_info``.

        Returns:
            ``(yi, latent_ys, context_vector, extra_info)``. These are the
            final latent state, the per-step latents stacked along dim 1,
            the per-step ratio of the two recorded auxiliary states stacked
            along dim 1, and the list of diagnostic dicts.

        Raises:
            RuntimeError: if an ODE solution does not start at the previous
                latent state.
        """
        n_traj, n_tp, n_dims = data.size()
        extra_info = []

        device = self.device

        prev_y = torch.zeros((1, n_traj, self.latent_dim)).to(device)
        init_condition = (
            torch.zeros((n_traj, 1)).to(device),  # A
            torch.zeros((n_traj, self.latent_dim)).to(device),  # C
            torch.zeros((n_traj, self.latent_dim)).to(device),  # K
            torch.zeros((n_traj, self.latent_dim,
                         self.latent_dim)).to(device))  # V

        # The first step integrates from time_steps[0] - 0.01 up to
        # time_steps[0].
        prev_t, t_i = time_steps[0] - 0.01, time_steps[0]

        # Only derive the default when the caller did not supply a value.
        if minimum_step is None:
            interval_length = time_steps[-1] - time_steps[0]
            minimum_step = interval_length / 50

        assert (not torch.isnan(data).any())
        assert (not torch.isnan(time_steps).any())

        latent_ys = []
        record_condition = []

        for i in range(len(time_steps)):
            if (t_i - prev_t) < minimum_step:
                # Short gap: a single explicit Euler step on every state.
                time_points = torch.stack((prev_t, t_i))
                dt = t_i - prev_t
                # squeeze(0) keeps the trajectory axis even when n_traj == 1
                # or latent_dim == 1, matching the (n_traj, ...) aux states.
                tuple_sol = self.diffeq_solver.ode_func(
                    prev_t, (prev_y.squeeze(0),) + init_condition)

                ode_sol = prev_y + tuple_sol[0].unsqueeze(0) * dt
                ode_sol = torch.stack((prev_y, ode_sol), 2).to(device)
                init_condition = tuple(
                    state + grad * dt
                    for state, grad in zip(init_condition, tuple_sol[1:]))
                assert (not torch.isnan(ode_sol).any())
            else:
                n_intermediate_tp = max(
                    2, ((t_i - prev_t) / minimum_step).int())

                time_points = utils.linspace_vector(prev_t, t_i,
                                                    n_intermediate_tp)
                tuple_sol = self.diffeq_solver(prev_y, time_points,
                                               init_condition=init_condition)
                ode_sol = tuple_sol[0]
                # Carry forward only the last value of each auxiliary state.
                init_condition = tuple(state[-1] for state in tuple_sol[1:])
                assert (not torch.isnan(ode_sol).any())

            # The trajectory must start at the previous latent state. Use the
            # absolute deviation so negative drift is caught as well.
            start_error = torch.mean(torch.abs(ode_sol[:, :, 0, :] - prev_y))
            if start_error >= 0.001:
                raise RuntimeError(
                    "Error: first point of the ODE is not equal to initial "
                    "value (mean abs deviation {})".format(start_error.item()))

            yi_ode = ode_sol[:, :, -1, :]
            xi = data[:, i, :].unsqueeze(0)

            yi = self.GRU_update(yi_ode, xi)

            prev_y = yi
            if i + 1 < len(time_steps):
                prev_t, t_i = time_steps[i], time_steps[i + 1]
            else:
                prev_t = time_steps[i]

            latent_ys.append(yi)
            # NOTE(review): an earlier comment described these entries as
            # ((n_traj, 1), (n_traj, n_dims)), which matches
            # init_condition[0:2] (A, C). The slice [1:3] actually records
            # (C, K), so context_vector below is K / C. Confirm the intent.
            record_condition.append(init_condition[1:3])

            if save_info:
                # The previous version referenced an undefined
                # ``sum_att_score`` here and raised NameError; only the
                # integration grid is available to record.
                extra_info.append(
                    {"time_points": time_points.detach().cpu().numpy()})

        latent_ys = torch.stack(latent_ys, 1)

        assert (not torch.isnan(yi).any())

        # Element-wise ratio second / first of the recorded states per step,
        # stacked along dim 1.
        context_vector = torch.stack(
            [second / first for first, second in record_condition], 1)
        return yi, latent_ys, context_vector, extra_info