import torch

# AttrDict, Gaussian, rmap, and batchwise_index below are utility helpers assumed to be
# provided by the surrounding codebase; they are not defined in this excerpt.


def forward(self, x, output_length, conditioning_length, context=None):
    """
    :param x: the modelled sequence, batch x time x x_dim
    :param output_length: the desired length of the output sequence. Note, this includes all
     conditioning frames except one
    :param conditioning_length: the length on which the prediction will be conditioned. Ground
     truth data are observed for this length
    :param context: a context sequence. Prediction is conditioned on all context up to and
     including this moment
    :return: an AttrDict of generated sequences, cropped to exclude the conditioning frames
    """
    lstm_inputs = AttrDict()
    outputs = AttrDict()
    if context is not None:
        lstm_inputs.more_context = context

    if not self._sample_prior:
        outputs.q_z = self.inference(x, context)
        lstm_inputs.z = Gaussian(outputs.q_z).sample()

    outputs.update(self.generator(inputs=lstm_inputs,
                                  length=output_length + conditioning_length,
                                  static_inputs=AttrDict(batch_size=x.shape[0])))

    # Conditioning currently works by zeroing out the KL-divergence loss on the conditioning
    # frames and returning fewer frames; this lets the network pass the ground-truth information
    # directly through z. Conditioning could alternatively be implemented by feeding the frames
    # directly into the predictor, which would require passing previous frames to the VRNNCell
    # and using a fake frame to condition the 0th frame on.
    outputs = rmap(lambda ten: ten[:, conditioning_length:], outputs)
    outputs.conditioning_length = conditioning_length
    return outputs
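
# Usage sketch (hedged): calling the sequence-level forward above. The module handle `model`,
# its construction, and all tensor sizes are illustrative assumptions, not part of this code.
#
#   x = torch.randn(16, 20, 64)            # batch x time x x_dim
#   out = model(x, output_length=15, conditioning_length=5)
#   out.x.shape                            # time axis cropped by conditioning_length
#   out.conditioning_length                # -> 5
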
def forward(self, context=None, x_prime=None, more_context=None, z=None):
    """
    :param context: context input, fed at each timestep
    :param x_prime: observation at the next step. Despite the None default it is effectively
     required, since the prior reads its batch size from it
    :param more_context: additional context, also fed at each timestep
    :param z: (optional) if not None, z is used directly and not sampled
    :return:
    """
    # TODO to get rid of more_context, make an interface that allows context structures
    output = AttrDict()
    # The prior input is only used to read off the batch size at the moment
    output.p_z = self.prior(torch.zeros_like(x_prime))
    if x_prime is not None:
        output.q_z = self.inf(self.inf_lstm(x_prime, context, more_context).output)

    if z is None:
        if self._sample_prior:
            z = Gaussian(output.p_z).sample()
        else:
            z = Gaussian(output.q_z).sample()

    pred_input = [z, context, more_context]
    output.x = self.gen_lstm(*pred_input).output
    return output
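
# Usage sketch (hedged): one step of the cell above. `cell`, `some_latent`, and the tensor
# shapes are illustrative assumptions; the cell is assumed to be called once per timestep.
#
#   x_prime = torch.randn(16, 64)          # next-step observation
#   step = cell(x_prime=x_prime)           # z sampled from q_z unless _sample_prior is set
#   step = cell(x_prime=x_prime, z=some_latent)   # reuse a given latent instead of sampling
#   step.x                                 # predicted frame
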
def forward(self, inputs, e_l, e_r, start_ind, end_ind, timestep=None):
    # This variant attends over the encoded sequence, so no concrete timestep may be given
    assert timestep is None
    output = AttrDict()
    if self.deterministic:
        output.q_z = self.q(e_l)
        return output

    values = inputs.inf_enc_seq
    keys = inputs.inf_enc_key_seq

    # Get (initial) attention key from the two boundary encodings
    query_input = [e_l, e_r]
    e_tilde, output.gamma = self.attention(values, keys, query_input, start_ind, end_ind, inputs)
    output.q_z = self.q(e_l, e_r, e_tilde)
    return output
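
# Usage sketch (hedged): attention-based inference. `inf`, the encoding sizes, and the boundary
# encodings e_l / e_r are illustrative assumptions; `inputs` must carry the encoded sequence and
# its keys under the attribute names read above.
#
#   inputs = AttrDict(inf_enc_seq=torch.randn(16, 20, 128),
#                     inf_enc_key_seq=torch.randn(16, 20, 32))
#   out = inf(inputs, e_l, e_r, start_ind, end_ind)   # timestep stays None for this variant
#   out.q_z                                # posterior parameters
#   out.gamma                              # attention weights over the sequence
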
def forward(self, inputs, e_l, e_r, start_ind, end_ind, timestep):
    # This variant indexes the encoded sequence at known timesteps instead of attending over it
    assert timestep is not None
    output = AttrDict(gamma=None)
    if self.deterministic:
        output.q_z = self.q(e_l)
        return output

    values = inputs.inf_enc_seq
    keys = inputs.inf_enc_key_seq

    # If several timesteps are queried per batch element, gather them all at once and fold the
    # extra dimension back into the batch
    mult = int(timestep.shape[0] / keys.shape[0])
    if mult > 1:
        timestep = timestep.reshape(-1, mult)
        result = batchwise_index(values, timestep.long())
        e_tilde = result.reshape([-1] + list(result.shape[2:]))
    else:
        e_tilde = batchwise_index(values, timestep[:, 0].long())

    output.q_z = self.q(e_l, e_r, e_tilde)
    return output
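
# Usage sketch (hedged): the timestep-indexed variant, exercising the mult > 1 branch. Here
# each of the 16 batch elements queries 3 timesteps (48 = 16 * 3 query rows); names and shapes
# are illustrative assumptions.
#
#   inputs = AttrDict(inf_enc_seq=torch.randn(16, 20, 128),
#                     inf_enc_key_seq=torch.randn(16, 20, 32))
#   timestep = torch.randint(0, 20, (48, 1)).float()
#   out = inf(inputs, e_l, e_r, start_ind, end_ind, timestep)
#   out.gamma                              # None: no attention weights when indexing directly
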
def forward(self, x, context=None, x_prime=None, more_context=None, z=None):
    """
    :param x: observation at the current step
    :param context: context input, fed at each timestep
    :param x_prime: observation at the next step
    :param more_context: additional context, also fed at each timestep
    :param z: (optional) if not None, z is used directly and not sampled
    :return:
    """
    # TODO to get rid of more_context, make an interface that allows context structures
    output = AttrDict()
    if x_prime is None:
        x_prime = torch.zeros_like(x)  # used when the next-step observation isn't available

    output.q_z = self.inf(self.inf_lstm(x_prime, context, more_context).output)
    # The prior input is only used to read off the batch size at the moment
    output.p_z = self.prior(x)

    if z is not None:
        if self._hp.prior_type == 'learned':
            z = output.p_z.reparametrize(z)
        # otherwise use the provided z directly
    elif self._sample_prior:
        z = output.p_z.sample()
    else:
        z = output.q_z.sample()

    # Note: providing x might be unnecessary if it is provided in the init_state
    pred_input = [x, z, context, more_context]
    # x_t is fed back in as input. Technically this is unnecessary, but the LSTM is set up to
    # observe a frame at every step because it observes one in the beginning.
    output.x = self.gen_lstm(*pred_input).output
    return output
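
# Usage sketch (hedged): one step of the VRNN cell above, first with the next frame observed,
# then in open-loop rollout. `cell` and the shapes are illustrative assumptions.
#
#   x = torch.randn(16, 64)                # current observation
#   x_prime = torch.randn(16, 64)          # next observation (inference target)
#   step = cell(x, x_prime=x_prime)        # z from q_z when _sample_prior is False
#   rollout = cell(step.x)                 # x_prime=None -> zero placeholder frame
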