コード例 #1
0
    def forward(self, data, time_steps, run_backwards=True, save_info=False):
        """Encode observed trajectories into the parameters of the initial latent state z0.

        IMPORTANT: ``data`` is expected to already have the observation mask
        concatenated to it.

        Args:
            data: observations, unpacked as (n_traj, n_tp, n_dims).
            time_steps: time stamps of the observations.
            run_backwards: forwarded to ``self.run_odernn``.
            save_info: forwarded to ``self.run_odernn``; when True, the auxiliary
                info it returns is stored on ``self.extra_info``.

        Returns:
            (mean_z0, std_z0): the two halves produced by ``utils.split_last_dim``
            on the output of ``self.transform_z0``; ``std_z0`` is made
            non-negative with ``abs``.
        """
        for tensor in (data, time_steps):
            assert (not torch.isnan(tensor).any())

        n_traj, _, _ = data.size()
        latent_shape = (1, n_traj, self.latent_dim)

        if len(time_steps) != 1:
            last_mean, last_std, _, extra_info = self.run_odernn(
                data, time_steps, run_backwards=run_backwards, save_info=save_info)
        else:
            # Only one observation: apply a single GRU update starting from an
            # all-zero hidden mean and std instead of running the ODE-RNN.
            zero_mean = torch.zeros(latent_shape).to(self.device)
            zero_std = torch.zeros(latent_shape).to(self.device)
            first_obs = data[:, 0, :].unsqueeze(0)
            last_mean, last_std = self.GRU_update(zero_mean, zero_std, first_obs)
            extra_info = None

        # transform_z0 sees the final hidden mean and std side by side.
        hidden = torch.cat(
            (last_mean.reshape(latent_shape), last_std.reshape(latent_shape)), -1)
        mean_z0, std_z0 = utils.split_last_dim(self.transform_z0(hidden))

        if save_info:
            self.extra_info = extra_info

        return mean_z0, std_z0.abs()
コード例 #2
0
ファイル: encoder_decoder.py プロジェクト: zlannnn/latent_ode
    def forward(self, data, time_steps, run_backwards=True, mask=None):
        """Encode trajectories with a GRU and map its last output to z0 parameters.

        Args:
            data: observations shaped [n_traj, n_tp, n_dims]. When ``mask`` is
                None, ``data`` is assumed to already have the mask concatenated
                to it.
            time_steps: observation times; assumed 1-D of length n_tp when
                ``self.use_delta_t`` is set (TODO confirm for other layouts).
            run_backwards: if True, the sequence is fed to the GRU from the
                last time point to the first.
            mask: optional mask shaped like ``data``. When given, it is
                concatenated to ``data`` along the last dimension. (Previously
                this name was read without being a parameter, so every call
                raised UnboundLocalError.)

        Returns:
            (mean, std) with a leading dimension of size 1 added; ``std`` is
            made non-negative with ``abs``.

        Side effects:
            Sets ``self.extra_info`` to the GRU outputs and the time points.
        """
        n_traj = data.size(0)

        assert (not torch.isnan(data).any())
        assert (not torch.isnan(time_steps).any())

        if mask is not None:
            assert (not torch.isnan(mask).any())

        # The RNN expects (seq_len, batch, input_size): move time to dim 0.
        data = data.permute(1, 0, 2)
        if mask is not None:
            mask = mask.permute(1, 0, 2)
            data = torch.cat((data, mask), -1)

        if run_backwards:
            # Look at data in the reverse order: from later points to the first
            data = utils.reverse(data)

        if self.use_delta_t:
            delta_t = time_steps[1:] - time_steps[:-1]
            if run_backwards:
                # Going backwards in time, so the gaps must be reversed as well.
                delta_t = utils.reverse(delta_t)
            # Append a zero gap for the last fed time point so there is one gap
            # per time step; allocate it like delta_t (same device and dtype).
            delta_t = torch.cat((delta_t, delta_t.new_zeros(1)))
            # [n_tp] -> [n_tp, n_traj, 1], prepended as an extra input feature.
            delta_t = delta_t.unsqueeze(1).repeat((1, n_traj)).unsqueeze(-1)
            data = torch.cat((delta_t, data), -1)

        outputs, _ = self.gru_rnn(data)

        # GRU output shape: (seq_len, batch, num_directions * hidden_size);
        # the last time step summarizes the whole (possibly reversed) sequence.
        last_output = outputs[-1]

        self.extra_info = {"rnn_outputs": outputs, "time_points": time_steps}

        mean, std = utils.split_last_dim(self.hiddens_to_z0(last_output))
        std = std.abs()

        assert (not torch.isnan(mean).any())
        assert (not torch.isnan(std).any())

        return mean.unsqueeze(0), std.unsqueeze(0)
コード例 #3
0
    def forward(self, y_mean, y_std, x, masked_update=True):
        """Apply one gated (GRU-style) update to a hidden mean/std given an observation.

        Args:
            y_mean: current hidden-state mean.
            y_std: current hidden-state std.
            x: observation. When ``masked_update`` is True, its last dimension
                must hold the data in the first half and the mask in the second
                half; it is indexed as a 3-D tensor.
            masked_update: if True, positions whose mask features are all zero
                at this time point keep their previous mean ``y_mean``.

        Returns:
            (new_y, new_y_std). ``new_y_std`` is passed through softplus and is
            not affected by the mask.

        Raises:
            RuntimeError: if the masked mean contains NaN. This replaces the
                previous behavior of printing tensors and terminating the
                interpreter with ``exit()``.
        """
        y_concat = torch.cat([y_mean, y_std, x], -1)

        update_gate = self.update_gate(y_concat)
        reset_gate = self.reset_gate(y_concat)
        concat = torch.cat([y_mean * reset_gate, y_std * reset_gate, x], -1)

        new_state, new_state_std = utils.split_last_dim(
            self.new_state_net(concat))
        new_state_std = new_state_std.abs()

        # Blend the candidate state with the previous one, weighted by the update gate.
        new_y = (1 - update_gate) * new_state + update_gate * y_mean
        new_y_std = F.softplus((1 - update_gate) * new_state_std +
                               update_gate * y_std)

        assert (not torch.isnan(new_y).any())

        if masked_update:
            # IMPORTANT: assumes that x contains both data and mask.
            # Update the hidden state only where at least one feature is
            # observed at the current time point.
            n_data_dims = x.size(-1) // 2
            mask = x[:, :, n_data_dims:]
            utils.check_mask(x[:, :, :n_data_dims], mask)

            # 1.0 where any feature is observed, 0.0 where all are missing.
            mask = (torch.sum(mask, -1, keepdim=True) > 0).float()

            assert (not torch.isnan(mask).any())

            new_y = mask * new_y + (1 - mask) * y_mean
            # NOTE: only the mean is masked; new_y_std is returned as computed
            # above even for time points with no observed features.

            if torch.isnan(new_y).any():
                # Fail loudly but recoverably instead of killing the process.
                raise RuntimeError(
                    "new_y is nan! ({} NaN entries in new_y, {} in y_mean)".format(
                        int(torch.isnan(new_y).sum()),
                        int(torch.isnan(y_mean).sum())))

        return new_y, new_y_std
コード例 #4
0
    def get_reconstruction(self,
                           time_steps_to_predict,
                           data,
                           truth_time_steps,
                           mask=None,
                           n_traj_samples=1,
                           mode=None):
        """Encode observed data into z0 with an RNN, then decode predictions from it.

        The encoder (``self.rnn_cell_enc``) runs over ``data`` using the gaps
        between ``truth_time_steps``. Its final hidden state is mapped by
        ``self.z0_net`` to (z0_mean, z0_std), and a sample is drawn with
        ``utils.sample_standard_gaussian``. The decoder (``self.rnn_cell_dec``)
        starts from that sample and runs for ``time_steps_to_predict.size(0)``
        steps with ``feed_previous=True``.

        Args:
            time_steps_to_predict: times to predict at (presumably 1-D; only
                the 1-D case is expanded per batch below -- TODO confirm).
            data: observations; ``data.size(0)`` is used as the batch size.
            truth_time_steps: times of the observations in ``data``.
            mask: observation mask; required (asserted not None).
            n_traj_samples: not used by this method.
            mode: not used by this method.

        Returns:
            (outputs, extra_info). ``outputs`` is ``self.decoder(all_hiddens)``
            passed through ``utils.shift_outputs`` with the first data point.
            ``extra_info["first_point"]`` is (z0_mean, z0_std, z0_sample), each
            unsqueezed at dim 0. ``extra_info["label_predictions"]`` is added
            when ``self.use_binary_classif`` is set.
        """

        assert (mask is not None)

        batch_size = data.size(0)
        zero_delta_t = torch.Tensor([0.]).to(self.device)

        # Run the encoder backwards when the first prediction time precedes the
        # last observed time.
        run_backwards = bool(time_steps_to_predict[0] < truth_time_steps[-1])

        if run_backwards:
            # Look at data in the reverse order: from later points to the first.
            # NOTE(review): data is batch-first here (batch_size = data.size(0)).
            # Confirm utils.reverse flips the time axis for this layout; if it
            # reverses dim 0, this reverses the batch order instead of time.
            data = utils.reverse(data)
            mask = utils.reverse(mask)

        # Gaps between consecutive observation times.
        delta_ts = truth_time_steps[1:] - truth_time_steps[:-1]
        if run_backwards:
            # we are going backwards in time
            delta_ts = utils.reverse(delta_ts)

        # Pad with a trailing zero gap so there is one gap per observed point.
        delta_ts = torch.cat((delta_ts, zero_delta_t))
        if len(delta_ts.size()) == 1:
            # delta_ts are shared for all trajectories in a batch:
            # [n_tp] -> [batch_size, n_tp, 1]
            assert (data.size(1) == delta_ts.size(0))
            delta_ts = delta_ts.unsqueeze(-1).repeat((batch_size, 1, 1))

        input_decay_params = None
        if self.input_space_decay:
            input_decay_params = (self.w_input_decay, self.b_input_decay)

        hidden_state, _ = run_rnn(data,
                                  delta_ts,
                                  cell=self.rnn_cell_enc,
                                  mask=mask,
                                  input_decay_params=input_decay_params)

        z0_mean, z0_std = utils.split_last_dim(self.z0_net(hidden_state))
        z0_std = z0_std.abs()
        z0_sample = utils.sample_standard_gaussian(z0_mean, z0_std)

        # Decoder # # # # # # # # # # # # # # # # # # # #
        # Gaps between prediction times, with a leading zero for the first one.
        delta_ts = torch.cat(
            (zero_delta_t,
             time_steps_to_predict[1:] - time_steps_to_predict[:-1]))
        if len(delta_ts.size()) == 1:
            delta_ts = delta_ts.unsqueeze(-1).repeat((batch_size, 1, 1))

        # The decoder is seeded with the z0 sample and fed its own previous
        # output (feed_previous=True). How run_rnn uses `data` in this mode is
        # not visible here -- TODO confirm.
        _, all_hiddens = run_rnn(data,
                                 delta_ts,
                                 cell=self.rnn_cell_dec,
                                 first_hidden=z0_sample,
                                 feed_previous=True,
                                 n_steps=time_steps_to_predict.size(0),
                                 decoder=self.decoder,
                                 input_decay_params=input_decay_params)

        outputs = self.decoder(all_hiddens)
        # Shift outputs for computing the loss -- we should compare the first output to the second data point, etc.
        # `data` may have been reversed above, so this is the first point in
        # the order the encoder saw.
        first_point = data[:, 0, :]
        outputs = utils.shift_outputs(outputs, first_point)

        extra_info = {
            "first_point":
            (z0_mean.unsqueeze(0), z0_std.unsqueeze(0), z0_sample.unsqueeze(0))
        }

        if self.use_binary_classif:
            if self.classif_per_tp:
                # One prediction per decoded time point.
                extra_info["label_predictions"] = self.classifier(all_hiddens)
            else:
                # One prediction per trajectory, from the z0 mean.
                extra_info["label_predictions"] = self.classifier(
                    z0_mean).reshape(1, -1)

        # outputs shape (per the original author, not verified here):
        # [n_traj_samples, n_traj, n_tp, n_dims]
        return outputs, extra_info