Example #1
 def forward(self, x, output_length, conditioning_length, context=None):
     """
     
     :param x: the modelled sequence, batch x time x  x_dim
     :param length: the desired length of the output sequence. Note, this includes all conditioning frames except 1
     :param conditioning_length: the length on which the prediction will be conditioned. Ground truth data are observed
     for this length
     :param context: a context sequence. Prediction is conditioned on all context up to and including this moment
     :return:
     """
     lstm_inputs = AttrDict()
     outputs = AttrDict()
     if context is not None:
         lstm_inputs.more_context = context
 
     if not self._sample_prior:
         outputs.q_z = self.inference(x, context)
         lstm_inputs.z = Gaussian(outputs.q_z).sample()
 
     outputs.update(self.generator(inputs=lstm_inputs,
                                   length=output_length + conditioning_length,
                                   static_inputs=AttrDict(batch_size=x.shape[0])))
     # The way the conditioning works now is by zeroing out the loss on the KL divergence and returning fewer frames.
     # That way the network can pass the info directly through z. I could also implement conditioning by feeding
     # the frames directly into the predictor; that would require passing previous frames to the VRNNCell and
     # using a fake frame to condition the 0th frame on.
     outputs = rmap(lambda ten: ten[:, conditioning_length:], outputs)
     outputs.conditioning_length = conditioning_length
     return outputs
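
A minimal sketch of the conditioning trim described in the comment above: the generator produces output_length + conditioning_length steps, and every output tensor is then sliced so only the predicted part is returned. The rmap_sketch helper and the tensor names below are illustrative stand-ins, not part of the original code.

import torch

def rmap_sketch(fn, d):
    # Recursively apply fn to every tensor in a (possibly nested) dict, mirroring rmap's (fn, collection) call order.
    return {k: rmap_sketch(fn, v) if isinstance(v, dict) else fn(v) for k, v in d.items()}

batch, conditioning_length, output_length, x_dim = 4, 3, 5, 16
outputs = {
    'x': torch.randn(batch, conditioning_length + output_length, x_dim),   # generated sequence
    'q_z': torch.randn(batch, conditioning_length + output_length, 8),     # per-step posterior params
}
# Drop the conditioning frames so only the predicted frames are returned.
outputs = rmap_sketch(lambda ten: ten[:, conditioning_length:], outputs)
assert outputs['x'].shape == (batch, output_length, x_dim)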
Example #2
    def forward(self, root, inputs):
        lstm_inputs = AttrDict()
        initial_inputs = AttrDict(x=inputs.e_0)
        context = torch.cat([inputs.e_0, inputs.e_g], dim=1)
        static_inputs = AttrDict()

        if 'enc_traj_seq' in inputs:
            lstm_inputs.x_prime = inputs.enc_traj_seq[:, 1:]
        if 'z' in inputs:
            lstm_inputs.z = inputs.z
        if self._hp.context_every_step:
            static_inputs.context = context
        if self._hp.action_conditioned_pred:
        assert 'enc_action_seq' in inputs  # need to feed actions for the action-conditioned predictor
            lstm_inputs.update(more_context=inputs.enc_action_seq)

        self.lstm.cell.init_state(initial_inputs.x, context,
                                  lstm_inputs.get('more_context', None))
        # Note: the last image is also produced. The actions are defined as the ones leading to the image.
        outputs = self.lstm(inputs=lstm_inputs,
                            initial_inputs=initial_inputs,
                            static_inputs=static_inputs,
                            length=self._hp.max_seq_len - 1)
        outputs.encodings = outputs.pop('x')
        outputs.update(self.decoder.decode_seq(inputs, outputs.encodings))
        outputs.images = torch.cat([inputs.I_0[:, None], outputs.images],
                                   dim=1)
        return outputs
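
As a small shape sketch of the final concatenation above: I_0[:, None] inserts a time axis so the first observed image can be prepended to the decoded predictions along dim=1. The tensor sizes are illustrative assumptions, not taken from the original model.

import torch

batch, seq_len, C, H, W = 2, 9, 3, 64, 64
I_0 = torch.randn(batch, C, H, W)                # first observed image
decoded = torch.randn(batch, seq_len, C, H, W)   # decoded predictions for the remaining steps
# Adding a singleton time dimension lets I_0 be concatenated in front of the predictions.
images = torch.cat([I_0[:, None], decoded], dim=1)
assert images.shape == (batch, seq_len + 1, C, H, W)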