Ejemplo n.º 1
0
    def forward(self, hidden, y_std, x, masked_update=True):
        """Run one gated update of ``hidden`` from the input ``x``.

        A gain ``K = sigmoid(x_K(x) + h_K(hidden))`` moves the hidden state
        towards the candidate ``z = tanh(x_z(x))``. The result is squashed
        with tanh. Batch norm (``self.use_BN``) is applied to the gate
        pre-activations of observed rows only.

        Args:
            hidden: current hidden state. The projection comments describe a
                ``[1, batch_size, latent_dim]`` layout -- TODO confirm.
            y_std: passed through unchanged.
            x: input whose last dimension is the data followed by its
                observation mask (first half data, second half mask).
            masked_update: if True, rows with no observed feature keep their
                previous ``hidden`` value.

        Returns:
            ``(h_new, y_std)``.

        Raises:
            RuntimeError: if the masked update produces NaNs.
        """
        # Split x into its data and mask halves. Reduce the mask to one flag
        # per row: 1.0 if any feature is observed at this time point.
        n_data_dims = x.size(-1) // 2
        mask = x[:, :, n_data_dims:]
        utils.check_mask(x[:, :, :n_data_dims], mask)
        mask = (torch.sum(mask, -1, keepdim=True) > 0).float()

        # The projections are documented as returning
        # [1, batch_size, latent_dim]. Drop only that leading dim. A bare
        # squeeze() would also drop batch_size == 1 or latent_dim == 1, which
        # breaks the boolean row selection used for batch norm below.
        gate_x_K = self.x_K(x).squeeze(0)
        gate_x_z = self.x_z(x).squeeze(0)
        gate_h_K = self.h_K(hidden).squeeze(0)

        if self.use_BN:
            # Row selector with the same leading dims as the squeezed gates.
            observed = mask.squeeze(-1).squeeze(0).bool()
            # Normalise observed rows only. Presumably the `> 1` guard exists
            # because batch norm needs more than one sample for statistics.
            if observed.sum() > 1:
                gate_x_K[observed] = self.bn_x_K(gate_x_K[observed])
                gate_x_z[observed] = self.bn_x_z(gate_x_z[observed])
                gate_h_K[observed] = self.bn_h_K(gate_h_K[observed])

        K_gain = torch.sigmoid(gate_x_K + gate_h_K)
        z = torch.tanh(gate_x_z)

        # Move hidden towards the candidate z by the gain K, then squash.
        h_new = torch.tanh(hidden + K_gain * (z - hidden))

        if masked_update:
            # Rows with no observed feature keep their previous hidden state.
            h_new = mask * h_new + (1 - mask) * hidden

            # Fail loudly instead of terminating the interpreter with exit().
            if torch.isnan(h_new).any():
                raise RuntimeError(
                    f"h_new is nan!\nmask={mask}\nhidden={hidden}\n"
                    f"h_new={h_new}")

        return h_new, y_std
Ejemplo n.º 2
0
    def forward(self, y_mean, y_std, x, masked_update=True):
        """GRU-style update of a hidden mean ``y_mean`` and spread ``y_std``.

        The update and reset gates are computed from ``[y_mean, y_std, x]``.
        ``self.new_state_net`` produces a candidate state. Its output is
        split by ``utils.split_last_dim`` into a mean part and a std part, and
        the std part is made non-negative with ``abs()``. The combined std is
        passed through softplus.

        Args:
            y_mean: current hidden mean.
            y_std: current hidden std.
            x: input whose last dimension is the data followed by its
                observation mask (first half data, second half mask).
            masked_update: if True, rows with no observed feature keep their
                previous ``y_mean``.

        Returns:
            ``(new_y, new_y_std)``.

        Raises:
            AssertionError: if the unmasked update contains NaNs.
            RuntimeError: if the masked update produces NaNs.
        """
        y_concat = torch.cat([y_mean, y_std, x], -1)

        update_gate = self.update_gate(y_concat)
        reset_gate = self.reset_gate(y_concat)
        concat = torch.cat([y_mean * reset_gate, y_std * reset_gate, x], -1)

        new_state, new_state_std = utils.split_last_dim(
            self.new_state_net(concat))
        new_state_std = new_state_std.abs()

        # Convex blend between the candidate and the previous values.
        new_y = (1 - update_gate) * new_state + update_gate * y_mean
        new_y_std = F.softplus((1 - update_gate) * new_state_std +
                               update_gate * y_std)

        assert (not torch.isnan(new_y).any())

        if masked_update:
            # Reduce the mask half of x to one flag per row: 1.0 if any
            # feature is observed at this time point.
            n_data_dims = x.size(-1) // 2
            mask = x[:, :, n_data_dims:]
            utils.check_mask(x[:, :, :n_data_dims], mask)
            mask = (torch.sum(mask, -1, keepdim=True) > 0).float()

            new_y = mask * new_y + (1 - mask) * y_mean
            # NOTE(review): new_y_std is intentionally left unmasked, so the
            # std is updated even at unobserved time points. The masked
            # variant was disabled upstream; confirm this is still intended.

            # Fail loudly instead of terminating the interpreter with exit().
            # A NaN can still appear here, e.g. from 0 * inf in the blend.
            if torch.isnan(new_y).any():
                raise RuntimeError(
                    f"new_y is nan!\nmask={mask}\ny_mean={y_mean}\n"
                    f"new_y={new_y}")

        return new_y, new_y_std
Ejemplo n.º 3
0
    def forward(self, y_mean, y_std, x, masked_update=True):
        """Run one LSTM step on ``(h, c) = y_mean`` driven by the input ``x``.

        Args:
            y_mean: ``(h, c)`` tuple. Each tensor is flattened with
                ``view(size(1), -1)``, which presumes a leading dim of 1,
                i.e. ``[1, batch, hidden_size]`` -- TODO confirm.
            y_std: returned unchanged as a placeholder (this cell has no std).
            x: input, flattened the same way. In masked mode its last
                dimension holds the data followed by the observation mask.
            masked_update: if True, rows with no observed feature keep the
                previous ``(h, c)``.

        Returns:
            ``((new_h, new_c), y_std)``, where ``new_h`` and ``new_c`` are
            reshaped to ``[1, batch, -1]``.

        Raises:
            RuntimeError: if the masked update produces NaNs.
        """
        h, c = y_mean
        h = h.view(h.size(1), -1)
        c = c.view(c.size(1), -1)
        x_short = x.view(x.size(1), -1)

        # Linear mappings: pre-activations for all gates in one tensor.
        preact = self.i2h(x_short) + self.h2h(h)

        # The first 3 * hidden_size columns are the sigmoid gates (input,
        # forget, output). The remainder is the tanh cell candidate.
        hs = self.hidden_size
        gates = preact[:, :3 * hs].sigmoid()
        g_t = preact[:, 3 * hs:].tanh()
        i_t = gates[:, :hs]
        f_t = gates[:, hs:2 * hs]
        o_t = gates[:, -hs:]

        c_t = torch.mul(c, f_t) + torch.mul(i_t, g_t)
        h_t = torch.mul(o_t, c_t.tanh())

        new_h = h_t.view(1, h_t.size(0), -1)
        new_c = c_t.view(1, c_t.size(0), -1)

        if masked_update:
            # Reduce the mask half of x to one flag per row: 1.0 if any
            # feature is observed at this time point.
            n_data_dims = x.size(-1) // 2
            mask = x[:, :, n_data_dims:]
            utils.check_mask(x[:, :, :n_data_dims], mask)
            mask = (torch.sum(mask, -1, keepdim=True) > 0).float()

            # Unobserved rows keep their previous hidden and cell state.
            new_h = mask * new_h + (1 - mask) * y_mean[0]
            new_c = mask * new_c + (1 - mask) * y_mean[1]

            # Fail loudly instead of terminating the interpreter with exit().
            for name, value in (("new_h", new_h), ("new_c", new_c)):
                if torch.isnan(value).any():
                    raise RuntimeError(
                        f"new_y is nan! ({name})\nmask={mask}\n"
                        f"{name}={value}\ny_mean={y_mean}")

        # y_std is not used by this cell; return it as a dummy.
        return (new_h, new_c), y_std
    def get_reconstruction(self,
                           time_steps_to_predict,
                           data,
                           truth_time_steps,
                           mask=None,
                           n_traj_samples=1,
                           mode=None):
        """Reconstruct ``data`` by running the RNN cell over its time points.

        Only interpolation is supported: ``time_steps_to_predict`` must equal
        ``truth_time_steps`` element-wise.

        Args:
            time_steps_to_predict: time points to produce outputs for. Must
                match ``truth_time_steps`` exactly.
            data: observations, unpacked as ``(n_traj, n_tp, n_dims)``.
            truth_time_steps: time points at which ``data`` was observed.
            mask: observation mask for ``data``. Required.
            n_traj_samples: unused here; kept for interface compatibility.
            mode: unused here; kept for interface compatibility.

        Returns:
            ``(outputs, extra_info)``. ``outputs`` are the decoded hidden
            states shifted by one step via ``utils.shift_outputs``.
            ``extra_info`` holds ``"first_point"`` and, when
            ``self.use_binary_classif`` is set, ``"label_predictions"``.

        Raises:
            AssertionError: if ``mask`` is None.
            Exception: if the prediction time points differ from the
                observed ones.
        """
        assert (mask is not None)
        # Unpacking also enforces that data is 3-D.
        n_traj, n_tp, n_dims = data.size()

        # Classic RNNs can only reconstruct the observed time points. Compare
        # element-wise: a sum of differences can cancel out (e.g. +1 and -1)
        # and hide a mismatch. The length check short-circuits first, so the
        # element-wise comparison never sees tensors of different lengths.
        if (len(truth_time_steps) != len(time_steps_to_predict)) or torch.any(
                time_steps_to_predict != truth_time_steps):
            raise Exception(
                "Extrapolation mode not implemented for RNN models")

        batch_size = n_traj
        zero_delta_t = torch.Tensor([0.]).to(self.device)

        # Gap from each observed time point to the next. The last point has
        # no successor and gets 0.
        delta_ts = truth_time_steps[1:] - truth_time_steps[:-1]
        delta_ts = torch.cat((delta_ts, zero_delta_t))
        if len(delta_ts.size()) == 1:
            # delta_ts are shared by all trajectories in a batch:
            # [n_tp] -> [batch_size, n_tp, 1].
            assert (n_tp == delta_ts.size(0))
            delta_ts = delta_ts.unsqueeze(-1).repeat((batch_size, 1, 1))

        input_decay_params = None
        if self.input_space_decay:
            input_decay_params = (self.w_input_decay, self.b_input_decay)

        # mask was asserted non-None above.
        utils.check_mask(data, mask)

        hidden_state, all_hiddens = run_rnn(
            data,
            delta_ts,
            cell=self.rnn_cell,
            mask=mask,
            input_decay_params=input_decay_params,
            feed_previous_w_prob=(0. if self.use_binary_classif else 0.5),
            decoder=self.decoder)

        outputs = self.decoder(all_hiddens)
        # Shift outputs for computing the loss: the first output is compared
        # with the second data point, and so on.
        first_point = data[:, 0, :]
        outputs = utils.shift_outputs(outputs, first_point)

        extra_info = {
            "first_point":
            (hidden_state.unsqueeze(0), 0.0, hidden_state.unsqueeze(0))
        }

        if self.use_binary_classif:
            if self.classif_per_tp:
                extra_info["label_predictions"] = self.classifier(all_hiddens)
            else:
                extra_info["label_predictions"] = self.classifier(
                    hidden_state).reshape(1, -1)

        # Original author's note: outputs shape is
        # [n_traj_samples, n_traj, n_tp, n_dims].
        return outputs, extra_info