Example #1
    def forward(self, data, time_steps, run_backwards=True, mask=None):
        # IMPORTANT: if 'mask' is not passed separately, assumes that 'data' already has the mask concatenated to it

        # data shape: [n_traj, n_tp, n_dims]
        # shape required for rnn: (seq_len, batch, input_size)
        # t0: not used here
        n_traj = data.size(0)

        assert (not torch.isnan(data).any())
        assert (not torch.isnan(time_steps).any())

        if mask is not None:
            assert (not torch.isnan(mask).any())

        data = data.permute(1, 0, 2)
        if mask is not None:
            mask = mask.permute(1, 0, 2)
            data = torch.cat((data, mask), -1)

        if run_backwards:
            # Look at data in the reverse order: from later points to the first
            data = utils.reverse(data)

        if self.use_delta_t:
            delta_t = time_steps[1:] - time_steps[:-1]
            if run_backwards:
                # we are going backwards in time
                delta_t = utils.reverse(delta_t)
            # append zero delta t in the end
            delta_t = torch.cat((delta_t, torch.zeros(1).to(self.device)))
            delta_t = delta_t.unsqueeze(1).repeat((1, n_traj)).unsqueeze(-1)
            data = torch.cat((delta_t, data), -1)

        # Run the GRU over the whole (possibly reversed) sequence at once
        outputs, _ = self.gru_rnn(data)

        # GRU output shape: (seq_len, batch, num_directions * hidden_size)
        last_output = outputs[-1]

        self.extra_info = {"rnn_outputs": outputs, "time_points": time_steps}

        mean, std = utils.split_last_dim(self.hiddens_to_z0(last_output))
        std = std.abs()

        assert (not torch.isnan(mean).any())
        assert (not torch.isnan(std).any())

        return mean.unsqueeze(0), std.unsqueeze(0)
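
# --- Sketch (not part of the original class): plausible implementations of the
# `utils` helpers referenced above. The real `utils.reverse` and
# `utils.split_last_dim` are defined elsewhere in the project; the versions
# below only illustrate the behaviour the encoder appears to rely on.
import torch


def reverse_sketch(tensor):
    # Reverse a tensor along its first (time) dimension, so the RNN reads the
    # sequence from the last observation back to the first.
    return torch.flip(tensor, dims=[0])


def split_last_dim_sketch(data):
    # Split the last dimension in half, e.g. a joint projection of size
    # 2 * latent_dim into a (mean, std) pair of size latent_dim each.
    half = data.size(-1) // 2
    return data[..., :half], data[..., half:]
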
    def get_reconstruction(self,
                           time_steps_to_predict,
                           data,
                           truth_time_steps,
                           mask=None,
                           n_traj_samples=1,
                           mode=None):

        assert (mask is not None)

        batch_size = data.size(0)
        zero_delta_t = torch.Tensor([0.]).to(self.device)

        # run encoder backwards
        run_backwards = bool(time_steps_to_predict[0] < truth_time_steps[-1])

        if run_backwards:
            # Look at data in the reverse order: from later points to the first
            data = utils.reverse(data)
            mask = utils.reverse(mask)

        delta_ts = truth_time_steps[1:] - truth_time_steps[:-1]
        if run_backwards:
            # we are going backwards in time
            delta_ts = utils.reverse(delta_ts)

        delta_ts = torch.cat((delta_ts, zero_delta_t))
        if len(delta_ts.size()) == 1:
            # delta_ts are shared for all trajectories in a batch
            assert (data.size(1) == delta_ts.size(0))
            delta_ts = delta_ts.unsqueeze(-1).repeat((batch_size, 1, 1))

        input_decay_params = None
        if self.input_space_decay:
            input_decay_params = (self.w_input_decay, self.b_input_decay)

        # Encoder: run the recurrent cell over the observations to get the final hidden state
        hidden_state, _ = run_rnn(data,
                                  delta_ts,
                                  cell=self.rnn_cell_enc,
                                  mask=mask,
                                  input_decay_params=input_decay_params)

        # Map the final hidden state to the mean/std of q(z0) and sample the initial latent state
        z0_mean, z0_std = utils.split_last_dim(self.z0_net(hidden_state))
        z0_std = z0_std.abs()
        z0_sample = utils.sample_standard_gaussian(z0_mean, z0_std)

        # Decoder # # # # # # # # # # # # # # # # # # # #
        delta_ts = torch.cat(
            (zero_delta_t,
             time_steps_to_predict[1:] - time_steps_to_predict[:-1]))
        if len(delta_ts.size()) == 1:
            delta_ts = delta_ts.unsqueeze(-1).repeat((batch_size, 1, 1))

        _, all_hiddens = run_rnn(data,
                                 delta_ts,
                                 cell=self.rnn_cell_dec,
                                 first_hidden=z0_sample,
                                 feed_previous=True,
                                 n_steps=time_steps_to_predict.size(0),
                                 decoder=self.decoder,
                                 input_decay_params=input_decay_params)

        outputs = self.decoder(all_hiddens)
        # Shift outputs for computing the loss -- we should compare the first output to the second data point, etc.
        first_point = data[:, 0, :]
        outputs = utils.shift_outputs(outputs, first_point)

        extra_info = {
            "first_point":
            (z0_mean.unsqueeze(0), z0_std.unsqueeze(0), z0_sample.unsqueeze(0))
        }

        if self.use_binary_classif:
            if self.classif_per_tp:
                extra_info["label_predictions"] = self.classifier(all_hiddens)
            else:
                extra_info["label_predictions"] = self.classifier(
                    z0_mean).reshape(1, -1)

        # outputs shape: [n_traj_samples, n_traj, n_tp, n_dims]
        return outputs, extra_info
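
# --- Sketch (not part of the original class): assumed behaviour of the remaining
# `utils` helpers used by get_reconstruction. `sample_standard_gaussian` is taken
# to be the usual reparameterization trick, and `shift_outputs` to prepend the
# first observed point while dropping the last prediction, so that output t is
# compared against data point t + 1 as the comment above describes.
import torch


def sample_standard_gaussian_sketch(mu, sigma):
    # z = mu + sigma * eps with eps ~ N(0, I), keeping the sample
    # differentiable with respect to mu and sigma.
    eps = torch.randn_like(sigma)
    return mu + sigma * eps


def shift_outputs_sketch(outputs, first_datapoint=None):
    # outputs: [n_traj_samples, n_traj, n_tp, n_dims]
    # Drop the last prediction and prepend the first observed point so that
    # the prediction at step t lines up with the ground truth at step t + 1.
    outputs = outputs[:, :, :-1, :]
    if first_datapoint is not None:
        n_traj, n_dims = first_datapoint.size()
        first_datapoint = first_datapoint.reshape(1, n_traj, 1, n_dims)
        outputs = torch.cat((first_datapoint, outputs), dim=2)
    return outputs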