def forward(self, root, inputs):
    """Blend the subgoal encodings of a tree into one encoding per time step.

    Every segment yielded by iterating ``root`` contributes its subgoal time
    index (``segment.subgoal.ind``) and encoding (``segment.subgoal.e_g_prime``).
    Each subgoal is weighted per target time step with a softmax over the
    negative absolute time difference, and the weighted encodings are summed
    over the subgoal axis.

    :param root: iterable of segments exposing ``subgoal.ind`` and ``subgoal.e_g_prime``
    :param inputs: forwarded unchanged to ``self._dense_decode``
    :return: AttrDict holding ``encodings`` plus everything ``self._dense_decode`` returns
    """
    outputs = AttrDict()
    # TODO implement soft interpolation

    sg_times, sg_encs = [], []
    for segment in root:
        sg_times.append(segment.subgoal.ind)
        sg_encs.append(segment.subgoal.e_g_prime)
    # Stack along a new subgoal axis (dim=1).
    sg_times = torch.stack(sg_times, dim=1)
    sg_encs = torch.stack(sg_encs, dim=1)

    # compute time difference weights
    seq_length = self._hp.max_seq_len
    # Create the target indices on the same device as the subgoal times;
    # otherwise the subtraction below fails for non-CPU tensors.
    target_ind = torch.arange(end=seq_length, dtype=sg_times.dtype, device=sg_times.device)
    time_diffs = torch.abs(target_ind[None, None, :] - sg_times[:, :, None])
    # Normalised over the time axis (dim=-1), separately for every subgoal.
    weights = nn.functional.softmax(-time_diffs, dim=-1)

    # compute weighted sum outputs
    # Broadcasting expands the singleton time axis, so no explicit repeat is needed.
    # assumes each stacked encoding has three trailing feature dims — TODO confirm
    weighted_sg = weights[:, :, :, None, None, None] * sg_encs.unsqueeze(2)
    outputs.encodings = torch.sum(weighted_sg, dim=1)
    outputs.update(self._dense_decode(inputs, outputs.encodings, seq_length))
    return outputs
def forward(self, root, inputs):
    """Roll out the LSTM predictor from the start encoding and decode the result.

    ``root`` is accepted for interface compatibility and is not used here.

    :param root: unused
    :param inputs: AttrDict with ``e_0``, ``e_g`` and ``I_0``; optionally
        ``enc_traj_seq``, ``z`` and (if action-conditioned) ``enc_action_seq``
    :return: the LSTM outputs with ``x`` renamed to ``encodings``, extended by the
        decoder outputs, and with ``I_0`` prepended to ``images`` along dim 1
    """
    hp = self._hp
    initial_inputs = AttrDict(x=inputs.e_0)
    # Start and goal encodings concatenated along dim 1 form the context.
    context = torch.cat([inputs.e_0, inputs.e_g], dim=1)

    # Per-step inputs: only the entries that are present in `inputs`.
    lstm_inputs = AttrDict()
    if 'enc_traj_seq' in inputs:
        lstm_inputs.x_prime = inputs.enc_traj_seq[:, 1:]
    if 'z' in inputs:
        lstm_inputs.z = inputs.z

    # Inputs that stay constant over the rollout.
    static_inputs = AttrDict()
    if hp.context_every_step:
        static_inputs.context = context

    if hp.action_conditioned_pred:
        assert 'enc_action_seq' in inputs  # need to feed actions for action conditioned predictor
        lstm_inputs.more_context = inputs.enc_action_seq

    extra_context = lstm_inputs.get('more_context', None)
    self.lstm.cell.init_state(initial_inputs.x, context, extra_context)

    # Note: the last image is also produced. The actions are defined as going to the image
    n_steps = hp.max_seq_len - 1
    outputs = self.lstm(inputs=lstm_inputs,
                        initial_inputs=initial_inputs,
                        static_inputs=static_inputs,
                        length=n_steps)

    encodings = outputs.pop('x')
    outputs.encodings = encodings
    outputs.update(self.decoder.decode_seq(inputs, encodings))
    # Put the given first image in front of the decoded images.
    first_image = inputs.I_0[:, None]
    outputs.images = torch.cat([first_image, outputs.images], dim=1)
    return outputs
def forward(self, x, output_length, conditioning_length, context=None):
    """Generate a sequence and drop the conditioning part from every output.

    :param x: the modelled sequence, batch x time x x_dim
    :param output_length: the desired length of the output sequence. Note, this includes all
        conditioning frames except 1
    :param conditioning_length: the length on which the prediction will be conditioned.
        Ground truth data are observed for this length
    :param context: a context sequence. Prediction is conditioned on all context up to and
        including this moment
    :return: generator outputs sliced to ``[:, conditioning_length:]``, plus ``q_z`` (unless
        sampling from the prior) and ``conditioning_length``
    """
    outputs = AttrDict()
    generator_inputs = AttrDict()
    if context is not None:
        generator_inputs.more_context = context

    if not self._sample_prior:
        # Infer the posterior from the sequence and feed one sample of it to the generator.
        posterior = self.inference(x, context)
        outputs.q_z = posterior
        generator_inputs.z = Gaussian(posterior).sample()

    total_length = output_length + conditioning_length
    generated = self.generator(inputs=generator_inputs,
                               length=total_length,
                               static_inputs=AttrDict(batch_size=x.shape[0]))
    outputs.update(generated)

    # The way the conditioning works now is by zeroing out the loss on the KL divergence and
    # returning less frames. That way the network can pass the info directly through z.
    # I can also implement conditioning by feeding the frames directly into predictor. That
    # would require passing previous frames to the VRNNCell and using a fake frame to
    # condition the 0th frame on.
    def drop_conditioning(tensor):
        return tensor[:, conditioning_length:]

    outputs = rmap(drop_conditioning, outputs)
    outputs.conditioning_length = conditioning_length
    return outputs
def get_node_loss(self, inputs, outputs):
    """Collect the reconstruction and KL divergence losses of the tree nodes.

    :param inputs: model inputs, forwarded to the binding's reconstruction loss
    :param outputs: model outputs; must carry ``tree``
    :return: AttrDict merging both loss dictionaries
    """
    tree = outputs.tree
    losses = AttrDict()

    reconstruction = self.binding.reconstruction_loss(inputs, outputs)
    losses.update(reconstruction)

    # The breadth-first view is read only after the reconstruction loss has run,
    # matching the original evaluation order.
    posterior, prior = tree.bf.q_z, tree.bf.p_z
    losses.update(self.inference.loss(posterior, prior))
    return losses
def reconstruction_loss(self, inputs, outputs, weights):
    """Compute the reconstruction loss via ``self.criterion``.

    Side effect: stores the soft estimates in ``outputs.soft_matched_estimates``.

    :param inputs: model inputs; ``traj_seq`` and ``pad_mask`` are read
    :param outputs: model outputs; ``gt_match_dists`` and ``tree.bf.images`` are read
    :param weights: loss weights, forwarded to the criterion
    :return: AttrDict with the criterion's losses
    """
    soft_estimates = self.criterion.get_soft_estimates(outputs.gt_match_dists,
                                                       outputs.tree.bf.images)
    outputs.soft_matched_estimates = soft_estimates

    criterion_losses = self.criterion.loss(outputs,
                                           inputs.traj_seq,
                                           weights,
                                           inputs.pad_mask,
                                           self._hp.dense_img_rec_weight,
                                           self.decoder.log_sigma)
    losses = AttrDict()
    losses.update(criterion_losses)
    return losses
def forward(self, root, inputs):
    """Run the LSTM over the tree-derived inputs and densely decode its outputs.

    :param root: forwarded to ``self._get_lstm_inputs``
    :param inputs: forwarded to ``self._get_lstm_inputs`` and ``self._dense_decode``
    :return: AttrDict with ``encodings`` (LSTM outputs stacked along dim 1) plus
        everything ``self._dense_decode`` returns
    """
    # TODO implement stopping probability prediction
    # TODO make the low-level network not predict subgoals
    seq_len = self._hp.max_seq_len  # the batch size was read here before but never used
    outputs = AttrDict()

    lstm_inputs = self._get_lstm_inputs(root, inputs)
    lstm_outputs = self.lstm(lstm_inputs, seq_len)
    outputs.encodings = torch.stack(lstm_outputs, dim=1)
    outputs.update(self._dense_decode(inputs, outputs.encodings, seq_len))
    return outputs
def _create_initial_nodes(self, inputs):
    """Build the start and end nodes from the encoded first and goal frames.

    :param inputs: AttrDict with ``enc_e_0``, ``enc_e_g``, ``I_0`` and ``I_g``
    :return: tuple ``(start_node, end_node)`` of AttrDicts
    """
    hp = self._hp
    start_node = AttrDict(e_g_prime=inputs.enc_e_0, images=inputs.I_0)
    end_node = AttrDict(e_g_prime=inputs.enc_e_g, images=inputs.I_g)

    # Matched time steps are attached whenever any attention variant is enabled.
    needs_match_timesteps = (hp.forced_attention
                             or hp.timestep_cond_attention
                             or hp.supervise_attn_weight > 0.0)
    if needs_match_timesteps:
        start_match, end_match = self.one_step_planner.matcher.get_init_inds(inputs)
        start_node.update(match_timesteps=start_match)
        end_node.update(match_timesteps=end_match)

    if hp.tree_lstm:
        # Hidden states start out empty for the tree LSTM.
        start_node.hidden_state = None
        end_node.hidden_state = None
    return start_node, end_node
def _create_initial_nodes(self, inputs):
    """Build the start and end nodes from the first and goal encodings.

    :param inputs: AttrDict with ``e_0``, ``e_g``, ``I_0`` and ``I_g``
    :return: tuple ``(start_node, end_node)`` of AttrDicts
    """
    start_node = AttrDict(e_g_prime=inputs.e_0, images=inputs.I_0)
    end_node = AttrDict(e_g_prime=inputs.e_g, images=inputs.I_g)

    if not self._hp.attentive_inference:
        # Without attentive inference the nodes carry their matched time steps.
        start_match, end_match = self.tree_module.binding.get_init_inds(inputs)
        start_node.update(match_timesteps=start_match)
        end_node.update(match_timesteps=end_match)

    if self._hp.tree_lstm:
        # Hidden states start out empty for the tree LSTM.
        start_node.hidden_state = None
        end_node.hidden_state = None
    return start_node, end_node
def loss(self, inputs, model_output, log_error_arr=False):
    """Assemble all losses that are enabled by the hyperparameters.

    :param inputs: model inputs
    :param model_output: model outputs; ``dense_rec`` is always read
    :param log_error_arr: forwarded to the dense reconstruction loss
    :return: AttrDict of losses
    """
    hp = self._hp
    losses = AttrDict()

    # Length prediction loss
    if hp.regress_length:
        losses.update(self.length_pred.loss(inputs, model_output))

    # Dense Reconstruction loss
    dense_losses = self.dense_rec.loss(inputs, model_output.dense_rec, log_error_arr)
    losses.update(dense_losses)

    # Inverse Model loss
    if hp.attach_inv_mdl:
        losses.update(self.inv_mdl.loss(inputs, model_output, add_total=False))

    # Cost model loss
    if hp.attach_cost_mdl and hp.run_cost_mdl:
        losses.update(self.cost_mdl.loss(inputs, model_output))

    # State regressor cost
    if hp.attach_state_regressor:
        predicted_states = model_output.regressed_state
        reg_len = predicted_states.shape[1]
        # Targets and mask are cut to the length of the regressed sequence.
        target_states = inputs.demo_seq_states[:, :reg_len]
        mask = inputs.pad_mask[:, :reg_len][:, :, None]
        losses.state_regression = L2Loss(1.0)(predicted_states, target_states, weights=mask)

    # Negative Log-likelihood (upper bound)
    if 'dense_img_rec' in losses and 'kl' in losses:
        nll_value = losses.dense_img_rec.value + 1.0 * losses.kl.value
        # Stored with weight 0.0.
        losses.nll = AttrDict(value=nll_value, weight=0.0)

    return losses
def loss(self, inputs, outputs):
    """Compute node, binding and entropy losses for the predicted tree.

    :param inputs: model inputs
    :param outputs: model outputs; ``tree`` and ``entropy`` are read
    :return: empty dict if the tree has depth 0, otherwise an AttrDict of losses
    """
    if outputs.tree.depth == 0:
        return {}

    losses = AttrDict()
    # Node losses first, then the explaining (binding) loss.
    for loss_fn in (self.get_node_loss, self.binding.loss):
        losses.update(loss_fn(inputs, outputs))

    # entropy penalty
    entropy_penalty = PenaltyLoss(weight=self._hp.entropy_weight)
    losses.entropy = entropy_penalty(outputs.entropy)
    return losses
def get_node_loss(self, inputs, outputs):
    """Collect reconstruction and KL divergence losses, optionally weighted per layer.

    With ``equal_weight_layer`` set, every node of layer ``l`` gets the weight
    ``2 ** (hierarchy_levels - 1) / 2 ** l``. Otherwise both weights are the scalar 1.

    :param inputs: model inputs, forwarded to the matcher's reconstruction loss
    :param outputs: model outputs; must carry ``tree``
    :return: AttrDict merging both loss dictionaries
    """
    tree = outputs.tree
    losses = AttrDict()

    # Weight of the loss
    weights = 1
    kl_weights = 1
    if self._hp.equal_weight_layer:
        n_levels = self._hp.hierarchy_levels
        top = 2 ** (n_levels - 1)

        # For each layer, divide the weight by the number of elements in the layer
        per_layer = []
        for level in range(n_levels):
            n_nodes = 2 ** level
            per_layer.append(np.full((n_nodes,), top / n_nodes))
        flat_weights = np.concatenate(per_layer)

        weights = torch.from_numpy(flat_weights).to(self._hp.device).float()[None, :, None]
        # The KL weights get two extra trailing singleton dims.
        kl_weights = weights[..., None, None]

    losses.update(self.matcher.reconstruction_loss(inputs, outputs, weights))
    posterior, prior = tree.bf.q_z, tree.bf.p_z
    losses.update(self.inference.loss(posterior, prior, weights=kl_weights))
    return losses
def forward(self, inputs, full_seq=None):
    """Sample target frames from the demo sequence and predict them from the first frame.

    Side effects on ``inputs``: ``I_0`` is repeated (once per batch) and ``I_target``
    is written.

    :param inputs: AttrDict with ``I_0``, ``demo_seq`` and ``end_ind``
    :param full_seq: accepted but not used in this method
    :return: AttrDict with the outputs of ``self.net`` where ``mu`` is renamed to
        ``e_target_prime``, plus the decoded ``I_target_prime``
    """
    n_repeats = 32

    if inputs.I_0.shape[0] == inputs.demo_seq.shape[0]:
        # Repeat the I_0 so that the visualization is correct
        # This should only be done once per batch!!
        inputs.I_0 = inputs.I_0.repeat_interleave(n_repeats, 0)

    sampled = self.sample_target(inputs.demo_seq, inputs.end_ind, n_repeats)
    # Merge the two leading dims into one.
    I_target = sampled.reshape((-1,) + sampled.shape[2:])
    inputs.I_target = I_target

    # Note, the decoder supports skips and pixel copying!
    e_0, _ = self.encoder(inputs.I_0)
    e_g, _ = self.encoder(I_target)

    outputs = AttrDict()
    outputs.update(self.net(e_g, e_0))
    e_target_prime = outputs.pop('mu')
    outputs.e_target_prime = e_target_prime
    outputs.I_target_prime = self.decoder(e_target_prime)
    return outputs
def _default_hparams(self):
    """Create an ``HParams`` object populated with the base default hyperparameters.

    :return: ``HParams`` with one entry per default listed below
    """
    # Data Dimensions
    data_dims = {
        'batch_size': -1,
        'max_seq_len': -1,
        'n_actions': -1,
        'state_dim': -1,
        'img_sz': 32,  # image resolution
        'input_nc': 3,  # number of input feature maps
        'n_conv_layers': None,  # Number of conv layers. Can be of format 'n-<int>' for any int for relative spec
    }

    # Network params
    network = {
        'use_convs': True,
        'use_batchnorm': True,  # TODO deprecate
        'normalization': 'batch',
        'predictor_normalization': 'group',
    }

    # Misc params
    misc = {
        'filter_repeated_tail': False,  # whether to remove repeated states from the dataset
        'rep_tail': False,
        'dataset_class': None,
        'standardize': None,
        'split': None,
        'subsampler': None,
        'subsample_args': None,
        'checkpt_path': None,
    }

    default_dict = AttrDict(data_dims)
    default_dict.update(network)
    default_dict.update(misc)

    # add new params to parent params
    parent_params = HParams()
    for name, value in default_dict.items():
        parent_params.add_hparam(name, value)
    return parent_params
def forward(self, x, output_length, conditioning_length, context=None):
    """Predict a sequence with the LSTM, conditioned on the first ground-truth frames.

    :param x: the modelled sequence, batch x time x x_dim
    :param output_length: the desired length of the output sequence. Note, this includes all
        conditioning frames except 1
    :param conditioning_length: the length on which the prediction will be conditioned.
        Ground truth data are observed for this length
    :param context: a context sequence. Prediction is conditioned on all context up to and
        including this moment
    :return: LSTM outputs sliced to ``[:, conditioning_length - 1:]``
    """
    lstm_inputs = AttrDict(x_prime=x[:, 1:])

    cell_context = context
    if context is not None:
        # Pad one step at the front; the cell gets the padded sequence, while the
        # per-step inputs skip the first padded entry.
        cell_context = pad(context, pad_front=1, dim=1)
        lstm_inputs.update(more_context=cell_context[:, 1:])

    conditioning_frames = AttrDict(x=x[:, :conditioning_length])
    self.lstm.cell.init_state(x[:, 0], more_context=cell_context)

    n_steps = output_length + conditioning_length - 1
    outputs = self.lstm(inputs=lstm_inputs,
                        initial_seq_inputs=conditioning_frames,
                        length=n_steps)

    def drop_conditioning(tensor):
        return tensor[:, conditioning_length - 1:]

    return rmap(drop_conditioning, outputs)
def loss(self, inputs, model_output):
    """Compute node, extra, matcher and entropy losses for the predicted tree.

    :param inputs: model inputs
    :param model_output: model outputs; ``tree`` and ``entropy`` are read
    :return: empty dict if the tree has depth 0, otherwise an AttrDict of losses
    """
    if model_output.tree.depth == 0:
        return {}

    # potentially already computed in forward pass
    # NOTE(review): the key checked here is 'gt_matching_dists'; confirm it is the
    # same key that compute_matching writes.
    if 'gt_matching_dists' not in model_output:
        self.compute_matching(inputs, model_output)

    losses = AttrDict()
    node_losses = self.get_node_loss(inputs, model_output)
    losses.update(node_losses)
    extra_losses = self.get_extra_losses(inputs, model_output)
    losses.update(extra_losses)

    # Explaining loss
    matcher_losses = self.matcher.loss(inputs, model_output)
    losses.update(matcher_losses)

    # entropy penalty
    entropy_penalty = PenaltyLoss(weight=self._hp.entropy_weight)
    losses.entropy = entropy_penalty(model_output.entropy)
    return losses
def produce_subgoal(self, inputs, layerwise_inputs, start_ind, end_ind, left_parent, right_parent, depth=None):
    """Divide a subsequence by producing a subgoal between its two parent nodes.

    This is one step of the model's recursion.

    :param inputs: model inputs; ``e_0`` / ``e_g`` are read if ``context_every_step`` is set
    :param layerwise_inputs: may carry a latent ``z`` to use instead of sampling
    :param start_ind: start index of the subsequence
    :param end_ind: end index of the subsequence
    :param left_parent: node left of the subgoal; ``e_g_prime`` is read
    :param right_parent: node right of the subgoal; ``e_g_prime`` is read
    :param depth: accepted but not used in this method
    :return: tuple ``(subgoal, left_parent, right_parent)``
    """
    hp = self._hp
    subgoal = AttrDict()
    enc_left = left_parent.e_g_prime
    enc_right = right_parent.e_g_prime

    subgoal.p_z = self.prior(enc_left, enc_right)

    # Pick the latent: given externally, sampled from the prior, or inferred.
    if 'z' in layerwise_inputs:
        z = layerwise_inputs.z
        if hp.prior_type == 'learned':
            # reparametrize if learned prior is used
            z = subgoal.p_z.reparametrize(z)
    elif self._sample_prior:
        z = subgoal.p_z.sample()
    else:
        ## Inference
        if hp.attentive_inference:
            inferred = self.inference(inputs, enc_left, enc_right, start_ind, end_ind)
        else:
            subgoal.match_timesteps = self.binding.comp_timestep(
                left_parent.match_timesteps, right_parent.match_timesteps)
            inferred = self.inference(inputs, enc_left, enc_right, start_ind, end_ind,
                                      subgoal.match_timesteps.float())
        subgoal.update(inferred)
        z = subgoal.q_z.sample()

    ## Predict the next node
    pred_input = [enc_left, enc_right, z]
    if hp.context_every_step:
        # z may hold several entries per input sample; repeat the context to match.
        mult = int(z.shape[0] / inputs.e_0.shape[0])
        pred_input.extend([
            inputs.e_0.repeat_interleave(mult, 0),
            inputs.e_g.repeat_interleave(mult, 0),
        ])

    if hp.tree_lstm:
        if left_parent.hidden_state is None and right_parent.hidden_state is None:
            # Neither parent has a hidden state yet: initialise both.
            hidden_left, hidden_right = self.lstm_initializer(enc_left, enc_right, z)
            left_parent.hidden_state = hidden_left
            right_parent.hidden_state = hidden_right

        hidden, encoding = self.subgoal_pred(left_parent.hidden_state,
                                             right_parent.hidden_state,
                                             *pred_input)
        subgoal.hidden_state = hidden
        subgoal.e_g_prime = encoding
    else:
        preact = self.subgoal_pred(*pred_input)
        subgoal.e_g_prime_preact = preact
        subgoal.e_g_prime = torch.tanh(preact)

    # gets overwritten w/ argmax of matching at training time (in loss)
    subgoal.ind = (start_ind + end_ind) / 2
    return subgoal, left_parent, right_parent
def get_default_gcp_hyperparameters():
    """Return the default GCP hyperparameters as a single AttrDict.

    The defaults are grouped into sections and merged in the order listed at the
    bottom of the function.

    :return: AttrDict mapping hyperparameter name to default value
    """
    # Params that actually should be elsewhere
    misplaced = {
        'randomize_length': False,
        'randomize_start': False,
    }

    # Network size
    network_size = {
        'ngf': 4,  # number of feature maps in shallowest level
        'nz_enc': 32,  # number of dimensions in encoder-latent space
        'nz_vae': 32,  # number of dimensions in vae-latent space
        'nz_vae2': 256,  # number of dimensions in 2nd level vae-latent space (if used)
        'nz_mid': 32,  # number of dimensions for internal feature spaces
        'nz_mid_lstm': 32,
        'n_lstm_layers': 1,
        'n_processing_layers': 3,  # Number of layers in MLPs
        'conv_inf_enc_kernel_size': 3,  # kernel size of convolutional inference encoder
        'conv_inf_enc_layers': 1,  # number of layers in convolutional inference encoder
        'n_attention_heads': 1,  # number of attention heads (needs to divide nz_enc evenly)
        'n_attention_layers': 1,  # number of layers in attention network
        'nz_attn_key': 32,  # dimensionality of the attention key
        'init_mlp_layers': 3,  # number of layers in the LSTM initialization MLP (if used)
        'init_mlp_mid_sz': 32,  # size of hidden layers inside LSTM initialization MLP (if used)
        'n_conv_layers': None,  # Number of conv layers. Can be of format 'n-<int>' for any int for relative spec
    }

    # Network params
    network_params = {
        'action_activation': None,
        'device': None,
        'context_every_step': True,
    }

    # Loss weights
    loss_weights = {
        'kl_weight': 1.,
        'kl_weight_burn_in': None,
        'entropy_weight': .0,
        'length_pred_weight': 1.,
        'dense_img_rec_weight': 1.,
        'dense_action_rec_weight': 1.,
        'free_nats': 0,
    }

    # Architecture params
    architecture = {
        'use_skips': True,  # only works with conv encoder/decoder
        'skips_stride': 2,
        'add_weighted_pixel_copy': False,  # if True, adds pixel copying stream for decoder
        'pixel_shift_decoder': False,
        'skip_from_parents': False,  # If True, parents are added to the pixel copy/shift sources
        'seq_enc': 'none',  # Manner of sequence encoding. ['none', 'conv', 'lstm']
        'regress_actions': False,  # whether or not to regress actions
        'learn_attn_temp': True,  # if True, makes attention temperature a trainable variable
        'attention_temperature': 1.0,  # temperature param used in attention softmax
        'attach_inv_mdl': False,  # if True, attaches an inverse model to output that computes actions
        'attach_cost_mdl': False,  # if True, attaches a cost function MLP that estimates cost from pairs of states
        'run_cost_mdl': True,  # if False, does not run cost model (but might still build it)
        'attach_state_regressor': False,  # if True, attaches network that regresses states from pre-decoding-latents
        'action_conditioned_pred': False,  # if True, conditions prediction on actions
        'learn_beta': True,  # if True, learns beta
        'initial_sigma': 1.0,  # initial sigma value (presumably of the decoder distribution; confirm)
        'separate_cnn_start_goal_encoder': False,  # if True, builds separate conv encoder for start/goal image
        'decoder_distribution': 'gaussian',  # [gaussian, categorical]
    }

    # RNN params
    rnn = {
        'use_conv_lstm': False,
    }

    # Variational inference parameters
    variational_inference = {
        'prior_type': 'learned',  # type of prior to be used ['fixed', 'learned']
        'var_inf': 'standard',  # type of variation inference ['standard', '2layer', 'deterministic']
    }

    # RecPlan params
    recplan = {
        'hierarchy_levels': 3,  # number of levels in the subgoal tree
        'one_hot_attn_time_cond': False,  # if True, conditions attention on one-hot time step index
        'attentive_inference': False,  # if True, forces attention to single matching frame
        'non_goal_conditioned': False,  # if True, does not condition prediction on goal frame
        'tree_lstm': '',  # ['', 'sum' or 'linear']
        'lstm_init': 'zero',  # defines how treeLSTM is initialized, ['zero', 'mlp'], #, 'warmup']
        'matching_temp': 1.0,  # temperature used in TAP-style matching softmax
        'matching_temp_tenthlife': -1,
        'matching_temp_min': 1e-3,
        'matching_type': 'latent',  # evidence binding procedure
                                    # ['image', 'latent', 'fraction', 'balanced', 'tap']
        'leaves_bias': 0.0,
        'top_bias': 1.0,
        'n_top_bias_nodes': 1,
        'supervise_match_weight': 0.0,
        'regress_index': False,
        'regress_length': False,
        'inv_mdl_params': {},  # params for the inverse model, if attached
        'train_inv_mdl_full_seq': False,  # if True, omits sampling for inverse model and trains on full seq
        'cost_mdl_params': {},  # cost model parameters
        'act_cond_inference': False,  # if True, conditions inference on actions
        'train_on_action_seqs': False,  # if True, trains the predictive network on action sequences
        'learned_pruning_threshold': 0.5,  # confidence thresh for learned pruning network after which node gets pruned
        'untied_layers': False,
        'supervised_decoder': False,
        'states_inference': False,
    }

    # Outdated GCP params
    outdated = {
        'dense_rec_type': 'none',  # ['none', 'discrete', 'softmax', 'linear', 'node_prob', action_prob].
        'one_step_planner': 'discrete',  # ['discrete', 'continuous', 'sh_pred']. Always 'sh_pred' for HEDGE models
        'mask_inf_attention': False,  # if True, masks out inference attention outside the current subsegment
        'binding': 'frames',  # Matching loss form ['loss', 'frames', 'exp', 'lf']. Always 'loss'.
    }

    # Matching params
    matching = {
        'learn_matching_temp': True,  # If true, the matching temperature is learned
    }

    # Logging params
    logging_params = {
        'dump_encodings': '',  # Specifies the directory where to dump the encodings
        'dump_encodings_inv_model': '',  # Specifies the directory where to dump the encodings
        'log_states_2d': False,  # if True, logs 2D plot of first two state dimensions
        'log_cartgripper': False,  # if True, logs sawyer from states
        'data_dir': '',  # necessary for sawyer logging
    }

    # Hyperparameters that shouldn't exist
    should_not_exist = {
        'log_d2b_3x3maze': 0,  # NOTE(review): meaning not documented; the old comment was copied from dump_encodings
    }

    # Merge in the original order so that the resulting key order is unchanged.
    default_dict = AttrDict(misplaced)
    for section in (network_size, network_params, loss_weights, architecture, rnn,
                    variational_inference, recplan, outdated, matching, logging_params,
                    should_not_exist):
        default_dict.update(section)
    return default_dict
from blox import AttrDict
from experiments.prediction.base_configs import base_tree as base_conf

# Experiment configuration: copied unchanged from the base tree config.
configuration = AttrDict(base_conf.configuration)

# Model configuration: base tree config with image-based DTW matching, a fixed
# (not learned) matching temperature and attentive inference.
model_config = AttrDict(base_conf.model_config)
model_config.update(
    matching_type='dtw_image',
    learn_matching_temp=False,
    attentive_inference=True,
)
from blox import AttrDict
from experiments.prediction.base_configs import base_tree as base_conf

# Experiment configuration: base tree config with the 'pruned_dtw' metric pruning scheme.
configuration = AttrDict(base_conf.configuration)
configuration.update(metric_pruning_scheme='pruned_dtw')

# Model configuration: base tree config with balanced matching and forced attention.
model_config = AttrDict(base_conf.model_config)
model_config.update(
    matching_type='balanced',
    forced_attention=True,
)
import imp import os from blox import AttrDict from gcp.datasets.data_loader import MazeTopRenderedGlobalSplitVarLenVideoDataset from gcp.planning.cem_policy.utils.cost_fcn import EuclideanPathLength current_dir = os.path.dirname(os.path.realpath(__file__)) from experiments.prediction.base_configs import gcp_sequential as base_conf configuration = AttrDict(base_conf.configuration) configuration.update({ 'dataset_name': 'nav_9rooms', 'batch_size': 16, 'lr': 2e-4, 'epoch_cycles_train': 2, 'n_rooms': 9, 'metric_pruning_scheme': 'basic', }) model_config = AttrDict(base_conf.model_config) model_config.update({ 'ngf': 16, 'nz_mid_lstm': 512, 'n_lstm_layers': 3, 'nz_mid': 128, 'nz_enc': 128, 'nz_vae': 256, 'regress_length': True, 'attach_state_regressor': True, 'attach_cost_mdl': True,
from blox import AttrDict
from experiments.prediction.base_configs import base_tree as base_conf

# Experiment configuration: copied unchanged from the base tree config.
configuration = AttrDict(base_conf.configuration)

# Model configuration: base tree config with image-based DTW matching and a fixed
# (not learned) matching temperature.
model_config = AttrDict(base_conf.model_config)
model_config.update(
    matching_type='dtw_image',
    learn_matching_temp=False,
)
def produce_subgoal(self, inputs, layerwise_inputs, start_ind, end_ind, left_parent, right_parent, depth=None):
    """
    Divides the subsequence by producing a subgoal inside it.
    This function represents one step of recursion of the model

    :param inputs: model inputs; ``enc_e_0`` / ``enc_e_g`` are read if ``context_every_step`` is set
    :param layerwise_inputs: may carry a latent ``z``; otherwise ``safe.attention_weights`` is
        read when no timestep-conditioned / forced attention is used
    :param start_ind: start index of the subsequence
    :param end_ind: end index of the subsequence
    :param left_parent: node left of the subgoal (``e_g_prime``, ``match_timesteps``, ``hidden_state``)
    :param right_parent: node right of the subgoal (same fields as ``left_parent``)
    :param depth: accepted but not used in this method
    :return: tuple ``(subgoal, left_parent, right_parent)``; the parents may have received
        hidden states when ``tree_lstm`` is enabled
    """
    subgoal = AttrDict()
    batch_size = start_ind.shape[0]  # NOTE(review): not used anywhere below

    e_l = left_parent.e_g_prime
    e_r = right_parent.e_g_prime

    subgoal.p_z = self.prior(e_l, e_r)

    # Pick the latent: given externally, sampled from the prior, or inferred.
    if 'z' in layerwise_inputs:
        z = layerwise_inputs.z
        if self._hp.prior_type == 'learned':  # reparametrize if learned prior is used
            z = subgoal.p_z.reparametrize(z)
    elif self._sample_prior:
        z = subgoal.p_z.sample()
    else:
        ## Inference
        if (self._hp.timestep_cond_attention or self._hp.forced_attention):
            # The fraction is only predicted here if the model predicts fractions at all.
            subgoal.fraction = self.fraction_pred(e_l, e_r)[..., -1] if self.predict_fraction else None
            subgoal.match_timesteps = self.matcher.comp_timestep(left_parent.match_timesteps,
                                                                  right_parent.match_timesteps,
                                                                  subgoal.fraction[:, None] if subgoal.fraction is not None else None)
            subgoal.update(self.inference(inputs, e_l, e_r, start_ind, end_ind, subgoal.match_timesteps.float()))
        else:
            subgoal.update(self.inference(
                inputs, e_l, e_r, start_ind, end_ind, attention_weights=layerwise_inputs.safe.attention_weights))

        z = subgoal.q_z.sample()

    ## Predict the next node
    pred_input = [e_l, e_r, z]
    if self._hp.context_every_step:
        # z may hold several entries per input sample; repeat the context to match.
        mult = int(z.shape[0] / inputs.enc_e_0.shape[0])
        pred_input += [inputs.enc_e_0.repeat_interleave(mult, 0),
                       inputs.enc_e_g.repeat_interleave(mult, 0)]

    if self._hp.tree_lstm:
        if left_parent.hidden_state is None and right_parent.hidden_state is None:
            left_parent.hidden_state, right_parent.hidden_state = self.lstm_initializer(e_l, e_r, z)
            # NOTE(review): the nesting of the warm-up below (inside the initialisation
            # branch, with the clone inside the loop) was reconstructed from a flattened
            # source — confirm against the original file.
            if self._hp.lstm_warmup_cycles > 0:
                for _ in range(self._hp.lstm_warmup_cycles):
                    left_parent.hidden_state, __ = \
                        self.subgoal_pred(left_parent.hidden_state, right_parent.hidden_state, e_l, e_r, z)
                    right_parent.hidden_state = left_parent.hidden_state.clone()

        subgoal.hidden_state, subgoal.e_g_prime = \
            self.subgoal_pred(left_parent.hidden_state,
                              right_parent.hidden_state, *pred_input)
    else:
        subgoal.e_g_prime_preact = self.subgoal_pred(*pred_input)
        subgoal.e_g_prime = torch.tanh(subgoal.e_g_prime_preact)

    ## Additional predicted values
    if self.predict_fraction and not self._hp.timestep_cond_attention:
        subgoal.fraction = self.fraction_pred(e_l, e_r, subgoal.e_g_prime)[..., -1]  # remove unnecessary dim

    # add attention target if trained with attention supervision
    if self._hp.supervise_attn_weight > 0.0:
        frac = subgoal.fraction[:, None] if 'fraction' in subgoal and subgoal.fraction is not None else None
        subgoal.match_timesteps = self.matcher.comp_timestep(left_parent.match_timesteps,
                                                              right_parent.match_timesteps,
                                                              frac)

    subgoal.ind = (start_ind + end_ind) / 2  # gets overwritten w/ argmax of matching at training time (in loss)
    return subgoal, left_parent, right_parent
def forward(self, inputs, phase='train'):
    """
    Forward pass of the full model: encode, predict the sequence, run optional heads.

    Note that ``inputs`` is modified in place (goal information may be zeroed,
    ``reference_tensor`` and ``model_enc_seq`` are written).

    :param inputs: AttrDict of model inputs. From the original docstring:
        images shape = batch x time x height x width x channel
        pad mask shape = batch x time, 1 indicates actual image 0 is padded
    :param phase: 'train' enables the inverse-model and cost-model heads and selects the
        matched (rather than predicted) pruned sequences
    :return: AttrDict with the model outputs
    """
    if self._hp.non_goal_conditioned:
        # Remove all goal information from the inputs.
        # NOTE(review): the nesting inside this branch was reconstructed from a flattened
        # source — confirm against the original file.
        if 'demo_seq' in inputs:
            inputs.demo_seq[torch.arange(inputs.demo_seq.shape[0]), inputs.end_ind] = 0.0
            inputs.demo_seq_images[torch.arange(inputs.demo_seq.shape[0]), inputs.end_ind] = 0.0
        inputs.I_g = torch.zeros_like(inputs.I_g)
        if "I_g_image" in inputs:
            inputs.I_g_image = torch.zeros_like(inputs.I_g_image)
        if inputs.I_0.shape[-1] == 5:  # special hack for maze
            inputs.I_0[..., -2:] = 0.0
            if "demo_seq" in inputs:
                inputs.demo_seq[..., -2:] = 0.0

    # swap in actions if we want to train action sequence decoder
    if self._hp.train_on_action_seqs:
        inputs.demo_seq = torch.cat([inputs.actions, torch.zeros_like(inputs.actions[:, :1])], dim=1)

    model_output = AttrDict()
    inputs.reference_tensor = find_tensor(inputs)

    # Default to start index 0 for every sample if none is given.
    if 'start_ind' not in inputs:
        start_ind = torch.zeros(self._hp.batch_size, dtype=torch.long, device=inputs.reference_tensor.device)
    else:
        start_ind = inputs.start_ind

    self.run_encoder(inputs, start_ind)

    end_ind = inputs.end_ind if 'end_ind' in inputs else None
    if self._hp.regress_length:
        # predict total sequence length
        model_output.update(self.length_pred(inputs.enc_e_0, inputs.enc_e_g))
        if self._use_pred_length and (self._hp.length_pred_weight > 0 or end_ind is None):
            end_ind = torch.argmax(model_output.seq_len_pred.sample().long(), dim=1)
            if self._hp.action_conditioned_pred or self._hp.non_goal_conditioned:
                # don't use predicted length when action conditioned
                end_ind = torch.ones_like(end_ind) * (self._hp.max_seq_len - 1)
    # TODO clean this up. model_output.end_ind is not currently used anywhere
    model_output.end_ind = end_ind

    # Run the model to generate sequences
    model_output.update(self.predict_sequence(inputs, model_output, start_ind, end_ind, phase))

    if self.prune_sequences:
        if phase == 'train':
            inputs.model_enc_seq = self.get_matched_pruned_seqs(inputs, model_output)
        else:
            inputs.model_enc_seq = self.get_predicted_pruned_seqs(inputs, model_output)
        inputs.model_enc_seq = pad_sequence(inputs.model_enc_seq, batch_first=True)
        # A 5-D result is reduced by taking index 0 of its last two dims.
        if len(inputs.model_enc_seq.shape) == 5:
            inputs.model_enc_seq = inputs.model_enc_seq[..., 0, 0]

    # NOTE(review): the optional heads below are placed at function level (not under
    # `prune_sequences`); this nesting was reconstructed — confirm against the original file.
    if self._hp.attach_inv_mdl and phase == 'train':
        model_output.update(self.inv_mdl(inputs, full_seq=self._inv_mdl_full_seq or self._hp.train_inv_mdl_full_seq))
    if self._hp.attach_state_regressor:
        regressor_inputs = inputs.model_enc_seq
        if not self._hp.supervised_decoder:
            # Detach so the regressor does not backpropagate into the encodings.
            regressor_inputs = regressor_inputs.detach()
        model_output.regressed_state = batch_apply(regressor_inputs, self.state_regressor)
    if self._hp.attach_cost_mdl and self._hp.run_cost_mdl and phase == 'train':
        # There is an issue here since SVG doesn't output a latent for the first image
        # Beyond conceptual problems, this breaks if end_ind = 199
        model_output.update(self.cost_mdl(inputs))

    return model_output
image_width=32, start_goal_confs=os.environ['GCP_DATA_DIR'] + '/nav_25rooms/start_goal_configs/raw', ) h_config = AttrDict(base_conf.model_config) h_config.update({ 'state_dim': 2, 'ngf': 16, 'max_seq_len': 200, 'hierarchy_levels': 8, 'nz_mid_lstm': 512, 'n_lstm_layers': 3, 'nz_mid': 128, 'nz_enc': 128, 'nz_vae': 256, 'regress_length': True, 'attach_state_regressor': True, 'attach_inv_mdl': True, 'inv_mdl_params': AttrDict( n_actions=2, use_convs=False, build_encoder=False, ), 'untied_layers': True, 'decoder_distribution': 'discrete_logistic_mixture', }) h_config.pop("add_weighted_pixel_copy") cem_params = AttrDict( prune_final=True, horizon=200,
from blox import AttrDict
from experiments.prediction.base_configs import base_tree as base_conf

# Experiment configuration: base tree config with the 'pruned_dtw' metric pruning scheme.
configuration = AttrDict(base_conf.configuration)
configuration.update(metric_pruning_scheme='pruned_dtw')

# Model configuration: base tree config with balanced matching.
model_config = AttrDict(base_conf.model_config)
model_config.update(matching_type='balanced')