def _create_initial_nodes(self, inputs):
    """Build the boundary tree nodes from the start and goal observations.

    Returns a pair of AttrDicts (start_node, end_node) holding the encoded
    start/goal latents and images, plus optional matching/LSTM fields.
    """
    start_node = AttrDict(e_g_prime=inputs.e_0, images=inputs.I_0)
    end_node = AttrDict(e_g_prime=inputs.e_g, images=inputs.I_g)

    # Non-attentive inference tracks explicit matched timesteps per node,
    # so seed the boundary nodes with the initial indices.
    if not self._hp.attentive_inference:
        init_inds = self.tree_module.binding.get_init_inds(inputs)
        start_node.update(AttrDict(match_timesteps=init_inds[0]))
        end_node.update(AttrDict(match_timesteps=init_inds[1]))

    # Tree-LSTM hidden states are created lazily on the first prediction step.
    if self._hp.tree_lstm:
        for node in (start_node, end_node):
            node.hidden_state = None

    return start_node, end_node
def _create_initial_nodes(self, inputs):
    """Build the boundary tree nodes from the encoded start and goal frames.

    Returns a pair of AttrDicts (start_node, end_node) with the encoded
    latents and images, plus optional matching/LSTM bookkeeping fields.
    """
    start_node = AttrDict(e_g_prime=inputs.enc_e_0, images=inputs.I_0)
    end_node = AttrDict(e_g_prime=inputs.enc_e_g, images=inputs.I_g)

    # Any attention variant that needs explicit matched timesteps
    # (forced / timestep-conditioned / supervised attention) seeds the
    # boundary nodes with the initial matcher indices.
    needs_match_inds = (self._hp.forced_attention
                        or self._hp.timestep_cond_attention
                        or self._hp.supervise_attn_weight > 0.0)
    if needs_match_inds:
        init_inds = self.one_step_planner.matcher.get_init_inds(inputs)
        start_node.update(AttrDict(match_timesteps=init_inds[0]))
        end_node.update(AttrDict(match_timesteps=init_inds[1]))

    # Tree-LSTM hidden states are created lazily on the first prediction step.
    if self._hp.tree_lstm:
        for node in (start_node, end_node):
            node.hidden_state = None

    return start_node, end_node
def produce_subgoal(self, inputs, layerwise_inputs, start_ind, end_ind, left_parent, right_parent, depth=None):
    """ Divides the subsequence by producing a subgoal inside it. This function represents one step of
    recursion of the model: given the two parent nodes bounding a subsequence, it samples a latent z
    (from given inputs, the prior, or inference) and predicts the subgoal embedding between them.

    :param inputs: full model inputs (provides e_0/e_g context and data for inference).
    :param layerwise_inputs: per-layer inputs; may carry a fixed latent under key 'z'.
    :param start_ind, end_ind: indices bounding the subsequence.
    :param left_parent, right_parent: AttrDict nodes whose e_g_prime embeddings condition the prediction.
    :param depth: recursion depth (unused here).
    :return: (subgoal AttrDict, left_parent, right_parent) — parents may gain hidden_state in tree-LSTM mode.
    """
    subgoal = AttrDict()
    e_l = left_parent.e_g_prime
    e_r = right_parent.e_g_prime

    # Prior over the latent, conditioned on both parent embeddings.
    subgoal.p_z = self.prior(e_l, e_r)

    if 'z' in layerwise_inputs:
        # Latent supplied externally (e.g. during planning/decoding).
        z = layerwise_inputs.z
        if self._hp.prior_type == 'learned':  # reparametrize if learned prior is used
            z = subgoal.p_z.reparametrize(z)
    elif self._sample_prior:
        # Sampling mode: draw the latent from the prior.
        z = subgoal.p_z.sample()
    else:
        ## Inference
        if self._hp.attentive_inference:
            # Attention-based inference: no explicit matched timesteps needed.
            subgoal.update(self.inference(inputs, e_l, e_r, start_ind, end_ind))
        else:
            # Compute the subgoal's matched timestep from the parents' timesteps
            # and condition inference on it.
            subgoal.match_timesteps = self.binding.comp_timestep(left_parent.match_timesteps,
                                                                 right_parent.match_timesteps)
            subgoal.update(self.inference(inputs, e_l, e_r, start_ind, end_ind,
                                          subgoal.match_timesteps.float()))
        z = subgoal.q_z.sample()

    ## Predict the next node
    pred_input = [e_l, e_r, z]
    if self._hp.context_every_step:
        # Append start/goal context; repeat to match z's (possibly expanded) batch dim.
        mult = int(z.shape[0] / inputs.e_0.shape[0])
        pred_input += [inputs.e_0.repeat_interleave(mult, 0),
                       inputs.e_g.repeat_interleave(mult, 0)]

    if self._hp.tree_lstm:
        # Lazily initialize parent hidden states on the first call for this subtree.
        if left_parent.hidden_state is None and right_parent.hidden_state is None:
            left_parent.hidden_state, right_parent.hidden_state = self.lstm_initializer(e_l, e_r, z)

        subgoal.hidden_state, subgoal.e_g_prime = \
            self.subgoal_pred(left_parent.hidden_state, right_parent.hidden_state, *pred_input)
    else:
        subgoal.e_g_prime_preact = self.subgoal_pred(*pred_input)
        subgoal.e_g_prime = torch.tanh(subgoal.e_g_prime_preact)

    subgoal.ind = (start_ind + end_ind) / 2  # gets overwritten w/ argmax of matching at training time (in loss)
    return subgoal, left_parent, right_parent
def produce_subgoal(self, inputs, layerwise_inputs, start_ind, end_ind, left_parent, right_parent, depth=None):
    """ Divides the subsequence by producing a subgoal inside it. This function represents one step of
    recursion of the model: given the two parent nodes bounding a subsequence, it samples a latent z
    (from given inputs, the prior, or inference) and predicts the subgoal embedding between them.

    :param inputs: full model inputs (provides enc_e_0/enc_e_g context and data for inference).
    :param layerwise_inputs: per-layer inputs; may carry a fixed latent under 'z' and, optionally,
        precomputed attention weights under 'attention_weights'.
    :param start_ind, end_ind: indices bounding the subsequence.
    :param left_parent, right_parent: AttrDict nodes whose e_g_prime embeddings condition the prediction.
    :param depth: recursion depth (unused here).
    :return: (subgoal AttrDict, left_parent, right_parent) — parents may gain hidden_state in tree-LSTM mode.
    """
    # Fix: removed dead local `batch_size = start_ind.shape[0]` — it was computed but never used.
    subgoal = AttrDict()
    e_l = left_parent.e_g_prime
    e_r = right_parent.e_g_prime

    # Prior over the latent, conditioned on both parent embeddings.
    subgoal.p_z = self.prior(e_l, e_r)

    if 'z' in layerwise_inputs:
        # Latent supplied externally (e.g. during planning/decoding).
        z = layerwise_inputs.z
        if self._hp.prior_type == 'learned':  # reparametrize if learned prior is used
            z = subgoal.p_z.reparametrize(z)
    elif self._sample_prior:
        # Sampling mode: draw the latent from the prior.
        z = subgoal.p_z.sample()
    else:
        ## Inference
        if (self._hp.timestep_cond_attention or self._hp.forced_attention):
            # Timestep-conditioned attention: compute the subgoal's matched timestep from the
            # parents' timesteps (optionally adjusted by a predicted fraction) and condition on it.
            subgoal.fraction = self.fraction_pred(e_l, e_r)[..., -1] if self.predict_fraction else None
            subgoal.match_timesteps = self.matcher.comp_timestep(
                left_parent.match_timesteps,
                right_parent.match_timesteps,
                subgoal.fraction[:, None] if subgoal.fraction is not None else None)
            subgoal.update(self.inference(inputs, e_l, e_r, start_ind, end_ind,
                                          subgoal.match_timesteps.float()))
        else:
            # Free attention: pass through any externally supplied attention weights
            # (`safe` access yields None if the key is absent — TODO confirm AttrDict.safe semantics).
            subgoal.update(self.inference(
                inputs, e_l, e_r, start_ind, end_ind,
                attention_weights=layerwise_inputs.safe.attention_weights))
        z = subgoal.q_z.sample()

    ## Predict the next node
    pred_input = [e_l, e_r, z]
    if self._hp.context_every_step:
        # Append start/goal context; repeat to match z's (possibly expanded) batch dim.
        mult = int(z.shape[0] / inputs.enc_e_0.shape[0])
        pred_input += [inputs.enc_e_0.repeat_interleave(mult, 0),
                       inputs.enc_e_g.repeat_interleave(mult, 0)]

    if self._hp.tree_lstm:
        # Lazily initialize parent hidden states on the first call for this subtree.
        if left_parent.hidden_state is None and right_parent.hidden_state is None:
            left_parent.hidden_state, right_parent.hidden_state = self.lstm_initializer(e_l, e_r, z)

            # Optionally "warm up" the freshly initialized hidden state by cycling it
            # through the subgoal predictor before the real prediction.
            if self._hp.lstm_warmup_cycles > 0:
                for _ in range(self._hp.lstm_warmup_cycles):
                    left_parent.hidden_state, __ = \
                        self.subgoal_pred(left_parent.hidden_state, right_parent.hidden_state, e_l, e_r, z)
                    right_parent.hidden_state = left_parent.hidden_state.clone()

        subgoal.hidden_state, subgoal.e_g_prime = \
            self.subgoal_pred(left_parent.hidden_state, right_parent.hidden_state, *pred_input)
    else:
        subgoal.e_g_prime_preact = self.subgoal_pred(*pred_input)
        subgoal.e_g_prime = torch.tanh(subgoal.e_g_prime_preact)

    ## Additional predicted values
    if self.predict_fraction and not self._hp.timestep_cond_attention:
        subgoal.fraction = self.fraction_pred(e_l, e_r, subgoal.e_g_prime)[..., -1]  # remove unnecessary dim

    # add attention target if trained with attention supervision
    if self._hp.supervise_attn_weight > 0.0:
        frac = subgoal.fraction[:, None] if 'fraction' in subgoal and subgoal.fraction is not None else None
        subgoal.match_timesteps = self.matcher.comp_timestep(left_parent.match_timesteps,
                                                             right_parent.match_timesteps, frac)

    subgoal.ind = (start_ind + end_ind) / 2  # gets overwritten w/ argmax of matching at training time (in loss)
    return subgoal, left_parent, right_parent