Example #1
    def forward(self, context=None, x_prime=None, more_context=None, z=None):
        """

        :param x: observation at current step
        :param context: to be fed at each timestep
        :param x_prime: observation at next step
        :param more_context: also to be fed at each timestep.
        :param z: (optional) if not None z is used directly and not sampled
        :return:
        """
        # TODO to get rid of more_context, make an interface that allows context structures
        output = AttrDict()

        # note: x_prime is needed here even at generation time, since it sizes the prior input
        output.p_z = self.prior(torch.zeros_like(
            x_prime))  # the input is only used to read the batch size at the moment
        if x_prime is not None:
            output.q_z = self.inf(
                self.inf_lstm(x_prime, context, more_context).output)

        if z is None:
            # sample from the prior at generation time, from the inferred posterior otherwise
            # (the posterior branch requires q_z, i.e. x_prime must have been passed in)
            if self._sample_prior:
                z = Gaussian(output.p_z).sample()
            else:
                z = Gaussian(output.q_z).sample()

        pred_input = [z, context, more_context]

        output.x = self.gen_lstm(*pred_input).output
        return output
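For orientation, a call to this variant might look like the sketch below. It is an assumption-laden illustration: `cell`, `ctx`, and `obs_next` are hypothetical names standing in for an instance of this module and suitably shaped tensors, none of which appear in the source.

    # hypothetical usage (all names here are illustrative assumptions)
    out = cell(context=ctx, x_prime=obs_next)  # fills in out.p_z, out.q_z, out.x
    pred = out.x                               # prediction from the generative LSTM
    # setting cell._sample_prior beforehand would draw z from p_z instead of q_z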
Example #2
    def forward(self,
                x,
                context=None,
                x_prime=None,
                more_context=None,
                z=None):
        """
        
        :param x: observation at current step
        :param context: to be fed at each timestep
        :param x_prime: observation at next step
        :param more_context: also to be fed at each timestep.
        :param z: (optional) if not None z is used directly and not sampled
        :return:
        """
        # TODO to get rid of more_context, make an interface that allows context structures
        output = AttrDict()
        if x_prime is None:
            x_prime = torch.zeros_like(x)  # Used when sequence isn't available

        output.q_z = self.inf(
            self.inf_lstm(x_prime, context, more_context).output)
        output.p_z = self.prior(
            x)  # the input is only used to read the batch size at the moment

        if z is not None:
            # use the provided z directly; with a learned prior, treat z as unit-Gaussian
            # noise and map it through the prior via the reparametrization
            if self._hp.prior_type == 'learned':
                z = output.p_z.reparametrize(z)
        elif self._sample_prior:
            z = output.p_z.sample()
        else:
            z = output.q_z.sample()

        # Note: providing x might be unnecessary if it is provided in the init_state
        pred_input = [x, z, context, more_context]

        # x_t is fed back in as input. Technically this is unnecessary, but the LSTM is set up
        # to observe a frame at every step because it observes one at the beginning.
        output.x = self.gen_lstm(*pred_input).output
        return output
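The `reparametrize` branch above maps an externally supplied noise tensor through the learned prior, i.e. the standard reparametrization trick. Below is a minimal, self-contained sketch of that mechanic; this `Gaussian` is a stand-in with an assumed interface, not the class used in this codebase:

    import torch

    class Gaussian:
        """Stand-in with an assumed interface, not the codebase's Gaussian."""
        def __init__(self, mu, log_sigma):
            self.mu, self.log_sigma = mu, log_sigma

        def sample(self):
            # draw fresh unit-Gaussian noise and push it through reparametrize
            return self.reparametrize(torch.randn_like(self.mu))

        def reparametrize(self, eps):
            # deterministic map: z = mu + sigma * eps
            return self.mu + self.log_sigma.exp() * eps

    p_z = Gaussian(torch.zeros(2, 4), torch.zeros(2, 4))
    z = p_z.reparametrize(torch.randn(2, 4))  # what the 'learned' prior branch does with a given z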
Example #3
    def forward(self, context=None, more_context=None, z=None, batch_size=None):
        """
        :param context: to be fed at each timestep
        :param more_context: also to be fed at each timestep
        :param z: (optional) if not None, z is used directly and not sampled
        :param batch_size: used to size the prior input, since no observation is passed in
        :return: AttrDict with the prior p_z and the predicted observation x
        """
        # TODO to get rid of more_context, make an interface that allows context structures
        outputs = AttrDict()

        outputs.p_z = self.prior(self.gen_lstm.output.weight.new_zeros(
            batch_size))  # the input is only used to read the batch size at the moment
        if z is None:
            # generation-only variant: there is no inference network, so z always
            # comes from the prior unless it is supplied explicitly
            z = Gaussian(outputs.p_z).sample()

        pred_input = [z, context, more_context]

        outputs.x = self.gen_lstm(*pred_input).output
        return outputs
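The `new_zeros` call is how the dummy prior input gets built when there is no observation tensor to mirror: it allocates zeros with the same dtype and device as the tensor it is called on. A self-contained illustration (the `Linear` layer here is just an example, not from the source):

    import torch

    layer = torch.nn.Linear(8, 8)
    dummy = layer.weight.new_zeros(16)  # zeros matching layer.weight's dtype and device
    print(dummy.shape, dummy.dtype, dummy.device)  # torch.Size([16]) torch.float32 cpu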