Пример #1
0
    def forward(self, root, inputs):
        """Soft-interpolate subgoal encodings over the full time axis and decode them.

        Every target time step receives a convex combination of all subgoal
        encodings. The weights are a softmax over subgoals of the negative absolute
        distance between that time step and each subgoal's index.

        :param root: iterable of tree segments; each must expose ``segment.subgoal.ind``
            and ``segment.subgoal.e_g_prime``
        :param inputs: model inputs, forwarded to ``self._dense_decode``
        :return: AttrDict with ``encodings`` (one interpolated encoding per time step)
            plus whatever ``self._dense_decode`` returns
        """
        # TODO implement soft interpolation
        outputs = AttrDict()

        segments = list(root)
        sg_times = torch.stack([segment.subgoal.ind for segment in segments], dim=1)       # (batch, n_sg)
        sg_encs = torch.stack([segment.subgoal.e_g_prime for segment in segments], dim=1)  # (batch, n_sg, ...)
        # Indices may be integer (e.g. after argmax matching); softmax needs a float
        # dtype, and matching the encodings' dtype keeps the weighted sum homogeneous.
        sg_times = sg_times.to(sg_encs.dtype)

        # compute time difference weights
        seq_length = self._hp.max_seq_len
        # Build the target indices on the same device as the subgoal times; otherwise
        # the broadcasted subtraction fails when the model runs on GPU.
        target_ind = torch.arange(end=seq_length, dtype=sg_times.dtype, device=sg_times.device)
        time_diffs = torch.abs(target_ind[None, None, :] - sg_times[:, :, None])  # (batch, n_sg, seq_len)
        # Normalize over the subgoal axis (dim=1) -- the same axis the weighted sum
        # below reduces over -- so each time step gets weights summing to one.
        weights = nn.functional.softmax(-time_diffs, dim=1)

        # compute weighted sum outputs; broadcasting over the new time axis replaces
        # an explicit repeat (assumes 6-D weighted product, as the original indexing did)
        weighted_sg = weights[:, :, :, None, None, None] * sg_encs.unsqueeze(2)
        outputs.encodings = torch.sum(weighted_sg, dim=1)
        outputs.update(self._dense_decode(inputs, outputs.encodings, seq_length))
        return outputs
Пример #2
0
    def forward(self, root, inputs):
        """Roll out the sequential LSTM predictor starting from the first encoding.

        :param root: not used by this method
        :param inputs: AttrDict with at least ``e_0``, ``e_g`` and ``I_0``; optionally
            ``enc_traj_seq``, ``z`` and (required when ``action_conditioned_pred``)
            ``enc_action_seq``
        :return: AttrDict of LSTM outputs where ``x`` is renamed to ``encodings``,
            updated with the decoder's outputs; ``images`` is prefixed with ``I_0``
        """
        start_enc = inputs.e_0
        first_step = AttrDict(x=start_enc)
        context = torch.cat([start_enc, inputs.e_g], dim=1)

        per_step = AttrDict()
        if 'enc_traj_seq' in inputs:
            # every encoding after the first one is fed step by step
            per_step.x_prime = inputs.enc_traj_seq[:, 1:]
        if 'z' in inputs:
            per_step.z = inputs.z

        constant = AttrDict()
        if self._hp.context_every_step:
            constant.context = context

        if self._hp.action_conditioned_pred:
            # need to feed actions for action conditioned predictor
            assert 'enc_action_seq' in inputs
            per_step.update(more_context=inputs.enc_action_seq)

        self.lstm.cell.init_state(first_step.x, context, per_step.get('more_context', None))

        # Note: the last image is also produced. The actions are defined as going to the image
        rollout = self.lstm(inputs=per_step,
                            initial_inputs=first_step,
                            static_inputs=constant,
                            length=self._hp.max_seq_len - 1)
        rollout.encodings = rollout.pop('x')
        rollout.update(self.decoder.decode_seq(inputs, rollout.encodings))
        first_image = inputs.I_0[:, None]
        rollout.images = torch.cat([first_image, rollout.images], dim=1)
        return rollout
Пример #3
0
 def forward(self, x, output_length, conditioning_length, context=None):
     """Run the latent-variable generator and drop the conditioning prefix.

     :param x: the modelled sequence, batch x time x x_dim
     :param output_length: the desired length of the output sequence. Note, this includes all
         conditioning frames except 1
     :param conditioning_length: the length on which the prediction will be conditioned. Ground
         truth data are observed for this length
     :param context: a context sequence. Prediction is conditioned on all context up to and
         including this moment
     :return: AttrDict of generator outputs (plus ``q_z`` when not sampling from the prior).
         Every entry is sliced along dim 1 to drop the first ``conditioning_length`` steps.
         ``conditioning_length`` is also stored on the result.
     """
     outputs = AttrDict()
     lstm_inputs = AttrDict() if context is None else AttrDict(more_context=context)

     if not self._sample_prior:
         # infer the posterior and feed one sample of it to the generator
         posterior = self.inference(x, context)
         outputs.q_z = posterior
         lstm_inputs.z = Gaussian(posterior).sample()

     total_length = output_length + conditioning_length
     generated = self.generator(inputs=lstm_inputs,
                                length=total_length,
                                static_inputs=AttrDict(batch_size=x.shape[0]))
     outputs.update(generated)

     # The way the conditioning works now is by zeroing out the loss on the KL divergence and returning less frames
     # That way the network can pass the info directly through z. I can also implement conditioning by feeding
     # the frames directly into predictor. that would require passing previous frames to the VRNNCell and
     # using a fake frame to condition the 0th frame on.
     trimmed = rmap(lambda ten: ten[:, conditioning_length:], outputs)
     trimmed.conditioning_length = conditioning_length
     return trimmed
Пример #4
0
    def get_node_loss(self, inputs, outputs):
        """Reconstruction and KL divergence loss of the tree nodes.

        Combines ``self.binding.reconstruction_loss`` with ``self.inference.loss``
        evaluated on the breadth-first posterior ``q_z`` and prior ``p_z``.

        :return: AttrDict with both modules' loss entries
        """
        node_tree = outputs.tree
        result = AttrDict(self.binding.reconstruction_loss(inputs, outputs))
        bf_nodes = node_tree.bf
        result.update(self.inference.loss(bf_nodes.q_z, bf_nodes.p_z))
        return result
Пример #5
0
 def reconstruction_loss(self, inputs, outputs, weights):
     """Compute the image reconstruction losses.

     Side effect: writes ``outputs.soft_matched_estimates``. These are computed by
     ``self.criterion.get_soft_estimates`` from ``outputs.gt_match_dists`` and the
     breadth-first tree images.

     :param weights: loss weights forwarded to ``self.criterion.loss``
     :return: AttrDict holding the criterion's loss entries
     """
     result = AttrDict()
     criterion = self.criterion

     estimates = criterion.get_soft_estimates(outputs.gt_match_dists, outputs.tree.bf.images)
     outputs.soft_matched_estimates = estimates

     rec_losses = criterion.loss(outputs,
                                 inputs.traj_seq,
                                 weights,
                                 inputs.pad_mask,
                                 self._hp.dense_img_rec_weight,
                                 self.decoder.log_sigma)
     result.update(rec_losses)
     return result
Пример #6
0
    def forward(self, root, inputs):
        """Predict a dense sequence of encodings with an LSTM and decode it.

        :param root: tree root, consumed by ``self._get_lstm_inputs``
        :param inputs: model inputs
        :return: AttrDict with ``encodings`` (the LSTM outputs stacked along dim 1)
            plus the outputs of ``self._dense_decode``
        """
        # TODO implement stopping probability prediction
        # TODO make the low-level network not predict subgoals
        seq_len = self._hp.max_seq_len
        outputs = AttrDict()

        lstm_inputs = self._get_lstm_inputs(root, inputs)
        # self.lstm returns a sequence of tensors (presumably one per step); stack into a time axis
        lstm_outputs = self.lstm(lstm_inputs, seq_len)
        outputs.encodings = torch.stack(lstm_outputs, dim=1)
        outputs.update(self._dense_decode(inputs, outputs.encodings, seq_len))
        return outputs
Пример #7
0
 def _create_initial_nodes(self, inputs):
     """Create the start and goal nodes that seed the subgoal tree.

     :param inputs: AttrDict providing ``enc_e_0``, ``enc_e_g``, ``I_0`` and ``I_g``
     :return: ``(start_node, end_node)``, AttrDicts with ``e_g_prime`` and ``images``.
         ``match_timesteps`` is added when attention is forced, timestep-conditioned or
         supervised. ``hidden_state`` is set to None when a tree LSTM is used.
     """
     start_node = AttrDict(e_g_prime=inputs.enc_e_0, images=inputs.I_0)
     end_node = AttrDict(e_g_prime=inputs.enc_e_g, images=inputs.I_g)

     hp = self._hp
     wants_match_steps = (hp.forced_attention
                          or hp.timestep_cond_attention
                          or hp.supervise_attn_weight > 0.0)
     if wants_match_steps:
         first_step, last_step = self.one_step_planner.matcher.get_init_inds(inputs)
         start_node.update(match_timesteps=first_step)
         end_node.update(match_timesteps=last_step)

     if hp.tree_lstm:
         # no LSTM state yet; presumably initialized later by the subgoal predictor
         start_node.hidden_state = None
         end_node.hidden_state = None

     return start_node, end_node
Пример #8
0
 def _create_initial_nodes(self, inputs):
     """Create the start and goal nodes that seed the subgoal tree.

     :param inputs: AttrDict providing ``e_0``, ``e_g``, ``I_0`` and ``I_g``
     :return: ``(start_node, end_node)``, AttrDicts with ``e_g_prime`` and ``images``.
         ``match_timesteps`` is added unless attentive inference is used.
         ``hidden_state`` is set to None when a tree LSTM is used.
     """
     start_node = AttrDict(e_g_prime=inputs.e_0, images=inputs.I_0)
     end_node = AttrDict(e_g_prime=inputs.e_g, images=inputs.I_g)

     hp = self._hp
     if hp.attentive_inference is False or not hp.attentive_inference:
         first_step, last_step = self.tree_module.binding.get_init_inds(inputs)
         start_node.update(match_timesteps=first_step)
         end_node.update(match_timesteps=last_step)

     if hp.tree_lstm:
         # no LSTM state yet; presumably initialized later by the subgoal predictor
         start_node.hidden_state = None
         end_node.hidden_state = None

     return start_node, end_node
Пример #9
0
    def loss(self, inputs, model_output, log_error_arr=False):
        """Aggregate the model's training losses.

        The dense reconstruction loss is always included. The following are added
        when enabled by hyperparameters:

        - length prediction
        - inverse model
        - cost model
        - state regression

        When both ``dense_img_rec`` and ``kl`` are present, an ``nll`` entry (their
        sum, an upper bound on the negative log-likelihood) is added with weight 0.0.

        :param log_error_arr: forwarded to the dense reconstruction loss
        :return: AttrDict of loss entries
        """
        hp = self._hp
        losses = AttrDict()

        if hp.regress_length:
            losses.update(self.length_pred.loss(inputs, model_output))

        losses.update(self.dense_rec.loss(inputs, model_output.dense_rec, log_error_arr))

        if hp.attach_inv_mdl:
            losses.update(self.inv_mdl.loss(inputs, model_output, add_total=False))

        if hp.attach_cost_mdl and hp.run_cost_mdl:
            losses.update(self.cost_mdl.loss(inputs, model_output))

        if hp.attach_state_regressor:
            predicted = model_output.regressed_state
            n_steps = predicted.shape[1]
            # supervise only the steps that were actually regressed
            target = inputs.demo_seq_states[:, :n_steps]
            mask = inputs.pad_mask[:, :n_steps][:, :, None]
            losses.state_regression = L2Loss(1.0)(predicted, target, weights=mask)

        has_nll_terms = 'dense_img_rec' in losses and 'kl' in losses
        if has_nll_terms:
            nll_value = losses.dense_img_rec.value + 1.0 * losses.kl.value
            losses.nll = AttrDict(value=nll_value, weight=0.0)

        return losses
Пример #10
0
    def loss(self, inputs, outputs):
        """Compute the tree losses; a depth-0 tree yields an empty dict.

        Combines the per-node losses, the binding ("explaining") loss and an entropy
        penalty on ``outputs.entropy`` weighted by ``entropy_weight``.
        """
        if outputs.tree.depth == 0:
            return {}

        losses = AttrDict(self.get_node_loss(inputs, outputs))
        losses.update(self.binding.loss(inputs, outputs))  # explaining loss

        entropy_penalty = PenaltyLoss(weight=self._hp.entropy_weight)
        losses.entropy = entropy_penalty(outputs.entropy)
        return losses
Пример #11
0
    def get_node_loss(self, inputs, outputs):
        """Reconstruction and KL divergence loss of the tree nodes.

        Without ``equal_weight_layer``, both losses use a scalar weight of 1.

        With ``equal_weight_layer``, layer ``l`` holds ``2 ** l`` elements, and each
        element is weighted by ``2 ** (hierarchy_levels - 1) / 2 ** l``. Every layer
        therefore sums to the same total weight. The weight vector is shaped
        ``[None, :, None]`` for reconstruction, with two extra trailing axes for the KL.
        """
        hp = self._hp
        node_tree = outputs.tree

        weights, kl_weights = 1, 1
        if hp.equal_weight_layer:
            n_levels = hp.hierarchy_levels
            top = 2 ** (n_levels - 1)
            per_layer = []
            for level in range(n_levels):
                n_nodes = 2 ** level
                per_layer.append(np.full((n_nodes,), top / n_nodes))
            flat = torch.from_numpy(np.concatenate(per_layer)).to(hp.device).float()
            weights = flat[None, :, None]
            kl_weights = weights[..., None, None]

        losses = AttrDict()
        losses.update(self.matcher.reconstruction_loss(inputs, outputs, weights))
        losses.update(self.inference.loss(node_tree.bf.q_z, node_tree.bf.p_z, weights=kl_weights))
        return losses
Пример #12
0
    def forward(self, inputs, full_seq=None):
        """Encode a start frame and sampled target frames, then reconstruct the targets.

        Mutates ``inputs``:

        - ``I_0`` is repeated 32 times along dim 0 when its batch size still equals
          that of ``demo_seq``.
        - The sampled targets, with their first two axes merged, are stored in
          ``I_target``.

        :param full_seq: unused
        :return: AttrDict with the outputs of ``self.net`` (``mu`` renamed to
            ``e_target_prime``) and the decoded ``I_target_prime``
        """
        outputs = AttrDict()
        repeats = 32

        start_images = inputs.I_0
        if start_images.shape[0] == inputs.demo_seq.shape[0]:
            # Repeat the I_0 so that the visualization is correct
            # This should only be done once per batch!!
            inputs.I_0 = start_images.repeat_interleave(repeats, 0)

        sampled = self.sample_target(inputs.demo_seq, inputs.end_ind, repeats)
        # merge the first two axes (presumably batch x repeats) into one batch axis
        flat_target = sampled.reshape((-1,) + sampled.shape[2:])
        inputs.I_target = flat_target

        # Note, the decoder supports skips and pixel copying!
        e_0, _start_skips = self.encoder(inputs.I_0)
        e_g, _goal_skips = self.encoder(flat_target)

        outputs.update(self.net(e_g, e_0))
        e_target_prime = outputs.pop('mu')
        outputs.e_target_prime = e_target_prime
        outputs.I_target_prime = self.decoder(e_target_prime)
        return outputs
Пример #13
0
    def _default_hparams(self):
        """Return this class's default hyperparameters as an ``HParams`` object.

        Covers data dimensions, network parameters and miscellaneous dataset
        settings. A value of ``-1`` presumably means "must be set by the caller";
        TODO confirm.
        """
        defaults = [
            # Data dimensions
            ('batch_size', -1),
            ('max_seq_len', -1),
            ('n_actions', -1),
            ('state_dim', -1),
            ('img_sz', 32),  # image resolution
            ('input_nc', 3),  # number of input feature maps
            ('n_conv_layers', None),  # Number of conv layers. Can be of format 'n-<int>' for any int for relative spec
            # Network params
            ('use_convs', True),
            ('use_batchnorm', True),  # TODO deprecate
            ('normalization', 'batch'),
            ('predictor_normalization', 'group'),
            # Misc params
            ('filter_repeated_tail', False),  # whether to remove repeated states from the dataset
            ('rep_tail', False),
            ('dataset_class', None),
            ('standardize', None),
            ('split', None),
            ('subsampler', None),
            ('subsample_args', None),
            ('checkpt_path', None),
        ]

        params = HParams()
        for name, value in defaults:
            params.add_hparam(name, value)
        return params
Пример #14
0
    def forward(self, x, output_length, conditioning_length, context=None):
        """Run the LSTM predictor seeded with the first ``conditioning_length`` frames.

        :param x: the modelled sequence, batch x time x x_dim
        :param output_length: the desired length of the output sequence. Note, this includes all
            conditioning frames except 1
        :param conditioning_length: the length on which the prediction will be conditioned. Ground
            truth data are observed for this length
        :param context: a context sequence. Prediction is conditioned on all context up to and
            including this moment
        :return: AttrDict of LSTM outputs with the first ``conditioning_length - 1`` steps
            removed along dim 1
        """
        step_inputs = AttrDict(x_prime=x[:, 1:])
        if context is not None:
            context = pad(context, pad_front=1, dim=1)
            step_inputs.update(more_context=context[:, 1:])

        seed = AttrDict(x=x[:, :conditioning_length])
        n_steps = output_length + conditioning_length - 1

        # the cell state sees the padded context (or None)
        self.lstm.cell.init_state(x[:, 0], more_context=context)
        raw = self.lstm(inputs=step_inputs,
                        initial_seq_inputs=seed,
                        length=n_steps)

        first_kept = conditioning_length - 1
        return rmap(lambda ten: ten[:, first_kept:], raw)
Пример #15
0
    def loss(self, inputs, model_output):
        """Compute the tree losses; a depth-0 tree yields an empty dict.

        First ensures the ground-truth matching is on ``model_output``. It is computed
        via ``self.compute_matching`` unless the forward pass already stored it. The
        method then combines:

        - the per-node losses
        - the extra losses
        - the matcher's explaining loss
        - an entropy penalty weighted by ``entropy_weight``
        """
        if model_output.tree.depth == 0:
            return {}

        losses = AttrDict()

        # Skip recomputation if the forward pass already stored the matching.
        # NOTE(review): the reconstruction loss reads ``gt_match_dists``; the guard previously
        # tested ``gt_matching_dists`` and so never skipped. Confirm compute_matching writes
        # ``gt_match_dists``.
        if 'gt_match_dists' not in model_output:
            self.compute_matching(inputs, model_output)

        losses.update(self.get_node_loss(inputs, model_output))

        losses.update(self.get_extra_losses(inputs, model_output))

        # Explaining loss
        losses.update(self.matcher.loss(inputs, model_output))

        # entropy penalty
        losses.entropy = PenaltyLoss(weight=self._hp.entropy_weight)(model_output.entropy)

        return losses
Пример #16
0
    def produce_subgoal(self,
                        inputs,
                        layerwise_inputs,
                        start_ind,
                        end_ind,
                        left_parent,
                        right_parent,
                        depth=None):
        """
        Divides the subsequence by producing a subgoal inside it.
         This function represents one step of recursion of the model

        The latent ``z`` comes from one of three sources:

        - ``layerwise_inputs.z``, reparametrized through the prior when
          ``prior_type == 'learned'``;
        - a prior sample, when sampling from the prior;
        - otherwise, a sample of the inferred posterior ``q_z``.

        :return: ``(subgoal, left_parent, right_parent)``. With a tree LSTM, the
            parents' ``hidden_state`` may have been initialized here.
        """
        subgoal = AttrDict()
        left_enc, right_enc = left_parent.e_g_prime, right_parent.e_g_prime
        subgoal.p_z = self.prior(left_enc, right_enc)

        if 'z' in layerwise_inputs:
            z = layerwise_inputs.z
            if self._hp.prior_type == 'learned':  # reparametrize if learned prior is used
                z = subgoal.p_z.reparametrize(z)
        elif self._sample_prior:
            z = subgoal.p_z.sample()
        else:
            ## Inference
            if self._hp.attentive_inference:
                posterior = self.inference(inputs, left_enc, right_enc, start_ind, end_ind)
            else:
                subgoal.match_timesteps = self.binding.comp_timestep(
                    left_parent.match_timesteps, right_parent.match_timesteps)
                posterior = self.inference(inputs, left_enc, right_enc, start_ind, end_ind,
                                           subgoal.match_timesteps.float())
            subgoal.update(posterior)
            z = subgoal.q_z.sample()

        ## Predict the next node
        pred_input = [left_enc, right_enc, z]
        if self._hp.context_every_step:
            # z's batch may be a multiple of the input batch; tile the start/goal context to match
            mult = int(z.shape[0] / inputs.e_0.shape[0])
            pred_input.extend(ctx.repeat_interleave(mult, 0) for ctx in (inputs.e_0, inputs.e_g))

        if not self._hp.tree_lstm:
            subgoal.e_g_prime_preact = self.subgoal_pred(*pred_input)
            subgoal.e_g_prime = torch.tanh(subgoal.e_g_prime_preact)
        else:
            if left_parent.hidden_state is None and right_parent.hidden_state is None:
                left_parent.hidden_state, right_parent.hidden_state = \
                    self.lstm_initializer(left_enc, right_enc, z)
            subgoal.hidden_state, subgoal.e_g_prime = self.subgoal_pred(
                left_parent.hidden_state, right_parent.hidden_state, *pred_input)

        # gets overwritten w/ argmax of matching at training time (in loss)
        subgoal.ind = (start_ind + end_ind) / 2
        return subgoal, left_parent, right_parent
Пример #17
0
def get_default_gcp_hyperparameters():
    """Return the default GCP model hyperparameters as a single ordered AttrDict."""
    return AttrDict(
        # Params that actually should be elsewhere
        randomize_length=False,
        randomize_start=False,

        # Network size
        ngf=4,  # number of feature maps in shallowest level
        nz_enc=32,  # number of dimensions in encoder-latent space
        nz_vae=32,  # number of dimensions in vae-latent space
        nz_vae2=256,  # number of dimensions in 2nd level vae-latent space (if used)
        nz_mid=32,  # number of dimensions for internal feature spaces
        nz_mid_lstm=32,
        n_lstm_layers=1,
        n_processing_layers=3,  # Number of layers in MLPs
        conv_inf_enc_kernel_size=3,  # kernel size of convolutional inference encoder
        conv_inf_enc_layers=1,  # number of layers in convolutional inference encoder
        n_attention_heads=1,  # number of attention heads (needs to divide nz_enc evenly)
        n_attention_layers=1,  # number of layers in attention network
        nz_attn_key=32,  # dimensionality of the attention key
        init_mlp_layers=3,  # number of layers in the LSTM initialization MLP (if used)
        init_mlp_mid_sz=32,  # size of hidden layers inside LSTM initialization MLP (if used)
        n_conv_layers=None,  # Number of conv layers. Can be of format 'n-<int>' for any int for relative spec

        # Network params
        action_activation=None,
        device=None,
        context_every_step=True,

        # Loss weights
        kl_weight=1.,
        kl_weight_burn_in=None,
        entropy_weight=.0,
        length_pred_weight=1.,
        dense_img_rec_weight=1.,
        dense_action_rec_weight=1.,
        free_nats=0,

        # Architecture params
        use_skips=True,  # only works with conv encoder/decoder
        skips_stride=2,
        add_weighted_pixel_copy=False,  # if True, adds pixel copying stream for decoder
        pixel_shift_decoder=False,
        skip_from_parents=False,  # If True, parents are added to the pixel copy/shift sources
        seq_enc='none',  # Manner of sequence encoding. ['none', 'conv', 'lstm']
        regress_actions=False,  # whether or not to regress actions
        learn_attn_temp=True,  # if True, makes attention temperature a trainable variable
        attention_temperature=1.0,  # temperature param used in attention softmax
        attach_inv_mdl=False,  # if True, attaches an inverse model to output that computes actions
        attach_cost_mdl=False,  # if True, attaches a cost function MLP that estimates cost from pairs of states
        run_cost_mdl=True,  # if False, does not run cost model (but might still build it)
        attach_state_regressor=False,  # if True, attaches network that regresses states from pre-decoding-latents
        action_conditioned_pred=False,  # if True, conditions prediction on actions
        learn_beta=True,  # if True, learns beta
        initial_sigma=1.0,  # initial sigma value -- TODO confirm which sigma this initializes
        separate_cnn_start_goal_encoder=False,  # if True, builds separate conv encoder for start/goal image
        decoder_distribution='gaussian',  # [gaussian, categorical]

        # RNN params
        use_conv_lstm=False,

        # Variational inference parameters
        prior_type='learned',  # type of prior to be used ['fixed', 'learned']
        var_inf='standard',  # type of variation inference ['standard', '2layer', 'deterministic']

        # RecPlan params
        hierarchy_levels=3,  # number of levels in the subgoal tree
        one_hot_attn_time_cond=False,  # if True, conditions attention on one-hot time step index
        attentive_inference=False,  # if True, forces attention to single matching frame
        non_goal_conditioned=False,  # if True, does not condition prediction on goal frame
        tree_lstm='',  # ['', 'sum' or 'linear']
        lstm_init='zero',  # defines how treeLSTM is initialized, ['zero', 'mlp'], #, 'warmup']
        matching_temp=1.0,  # temperature used in TAP-style matching softmax
        matching_temp_tenthlife=-1,
        matching_temp_min=1e-3,
        matching_type='latent',  # evidence binding procedure ['image', 'latent', 'fraction', 'balanced', 'tap']
        leaves_bias=0.0,
        top_bias=1.0,
        n_top_bias_nodes=1,
        supervise_match_weight=0.0,
        regress_index=False,
        regress_length=False,
        inv_mdl_params={},  # params for the inverse model, if attached
        train_inv_mdl_full_seq=False,  # if True, omits sampling for inverse model and trains on full seq
        cost_mdl_params={},  # cost model parameters
        act_cond_inference=False,  # if True, conditions inference on actions
        train_on_action_seqs=False,  # if True, trains the predictive network on action sequences
        learned_pruning_threshold=0.5,  # confidence thresh for learned pruning network after which node gets pruned
        untied_layers=False,
        supervised_decoder=False,
        states_inference=False,

        # Outdated GCP params
        dense_rec_type='none',  # ['none', 'discrete', 'softmax', 'linear', 'node_prob', action_prob].
        one_step_planner='discrete',  # ['discrete', 'continuous', 'sh_pred']. Always 'sh_pred' for HEDGE models
        mask_inf_attention=False,  # if True, masks out inference attention outside the current subsegment
        binding='frames',  # Matching loss form ['loss', 'frames', 'exp', 'lf']. Always 'loss'.

        # Matching params
        learn_matching_temp=True,  # If true, the matching temperature is learned

        # Logging params
        dump_encodings='',  # Specifies the directory where to dump the encodings
        dump_encodings_inv_model='',  # like dump_encodings, presumably for the inverse model -- confirm
        log_states_2d=False,  # if True, logs 2D plot of first two state dimensions
        log_cartgripper=False,  # if True, logs sawyer from states
        data_dir='',  # necessary for sawyer logging

        # Hyperparameters that shouldn't exist
        log_d2b_3x3maze=0,  # integer flag (default 0); purpose not visible here -- TODO confirm
    )
Пример #18
0
from blox import AttrDict
from experiments.prediction.base_configs import base_tree as base_conf

configuration = AttrDict(base_conf.configuration)

# Base tree model with DTW matching on images, a fixed matching temperature
# and attentive inference.
model_config = AttrDict(base_conf.model_config)
model_config.update(
    matching_type='dtw_image',
    learn_matching_temp=False,
    attentive_inference=True,
)
Пример #19
0
from blox import AttrDict
from experiments.prediction.base_configs import base_tree as base_conf

configuration = AttrDict(base_conf.configuration)
configuration.metric_pruning_scheme = 'pruned_dtw'  # evaluation pruning scheme

# Base tree model with balanced matching and forced attention.
model_config = AttrDict(base_conf.model_config)
model_config.update(
    matching_type='balanced',
    forced_attention=True,
)
Пример #20
0
import imp
import os

from blox import AttrDict
from gcp.datasets.data_loader import MazeTopRenderedGlobalSplitVarLenVideoDataset
from gcp.planning.cem_policy.utils.cost_fcn import EuclideanPathLength

current_dir = os.path.dirname(os.path.realpath(__file__))
from experiments.prediction.base_configs import gcp_sequential as base_conf

configuration = AttrDict(base_conf.configuration)
configuration.update({
    'dataset_name': 'nav_9rooms',
    'batch_size': 16,
    'lr': 2e-4,
    'epoch_cycles_train': 2,
    'n_rooms': 9,
    'metric_pruning_scheme': 'basic',
})

model_config = AttrDict(base_conf.model_config)
model_config.update({
    'ngf': 16,
    'nz_mid_lstm': 512,
    'n_lstm_layers': 3,
    'nz_mid': 128,
    'nz_enc': 128,
    'nz_vae': 256,
    'regress_length': True,
    'attach_state_regressor': True,
    'attach_cost_mdl': True,
Пример #21
0
from blox import AttrDict
from experiments.prediction.base_configs import base_tree as base_conf

configuration = AttrDict(base_conf.configuration)

# Base tree model with DTW matching on images and a fixed matching temperature.
model_config = AttrDict(base_conf.model_config)
model_config.update(
    matching_type='dtw_image',
    learn_matching_temp=False,
)
Пример #22
0
    def produce_subgoal(self, inputs, layerwise_inputs, start_ind, end_ind, left_parent, right_parent, depth=None):
        """
        Divides the subsequence by producing a subgoal inside it.
         This function represents one step of recursion of the model

        The latent ``z`` comes from one of three sources:

        - ``layerwise_inputs.z``, reparametrized through the prior when
          ``prior_type == 'learned'``;
        - a prior sample, when sampling from the prior;
        - otherwise, a sample of the inferred posterior ``q_z``.

        With timestep-conditioned or forced attention, inference is conditioned on
        ``match_timesteps`` computed by the matcher. Otherwise it uses
        ``layerwise_inputs.safe.attention_weights``.

        :return: ``(subgoal, left_parent, right_parent)``. With a tree LSTM, the
            parents' ``hidden_state`` may have been initialized (and warmed up) here.
        """
        subgoal = AttrDict()

        e_l = left_parent.e_g_prime
        e_r = right_parent.e_g_prime

        subgoal.p_z = self.prior(e_l, e_r)

        if 'z' in layerwise_inputs:
            z = layerwise_inputs.z
            if self._hp.prior_type == 'learned':    # reparametrize if learned prior is used
                z = subgoal.p_z.reparametrize(z)
        elif self._sample_prior:
            z = subgoal.p_z.sample()
        else:
            ## Inference
            if self._hp.timestep_cond_attention or self._hp.forced_attention:
                subgoal.fraction = self.fraction_pred(e_l, e_r)[..., -1] if self.predict_fraction else None
                frac = None if subgoal.fraction is None else subgoal.fraction[:, None]
                subgoal.match_timesteps = self.matcher.comp_timestep(left_parent.match_timesteps,
                                                                     right_parent.match_timesteps,
                                                                     frac)
                subgoal.update(self.inference(inputs, e_l, e_r, start_ind, end_ind, subgoal.match_timesteps.float()))
            else:
                subgoal.update(self.inference(
                    inputs, e_l, e_r, start_ind, end_ind, attention_weights=layerwise_inputs.safe.attention_weights))

            z = subgoal.q_z.sample()

        ## Predict the next node
        pred_input = [e_l, e_r, z]
        if self._hp.context_every_step:
            # z's batch may be a multiple of the input batch; tile the start/goal context to match
            mult = int(z.shape[0] / inputs.enc_e_0.shape[0])
            pred_input += [inputs.enc_e_0.repeat_interleave(mult, 0),
                           inputs.enc_e_g.repeat_interleave(mult, 0)]

        if self._hp.tree_lstm:
            if left_parent.hidden_state is None and right_parent.hidden_state is None:
                left_parent.hidden_state, right_parent.hidden_state = self.lstm_initializer(e_l, e_r, z)
                if self._hp.lstm_warmup_cycles > 0:
                    # warm-up: repeatedly run the predictor (without the extra context inputs)
                    # and copy the resulting left state to the right parent
                    for _ in range(self._hp.lstm_warmup_cycles):
                        left_parent.hidden_state, __ = \
                            self.subgoal_pred(left_parent.hidden_state, right_parent.hidden_state, e_l, e_r, z)
                        right_parent.hidden_state = left_parent.hidden_state.clone()

            subgoal.hidden_state, subgoal.e_g_prime = \
                self.subgoal_pred(left_parent.hidden_state, right_parent.hidden_state, *pred_input)
        else:
            subgoal.e_g_prime_preact = self.subgoal_pred(*pred_input)
            subgoal.e_g_prime = torch.tanh(subgoal.e_g_prime_preact)

        ## Additional predicted values
        if self.predict_fraction and not self._hp.timestep_cond_attention:
            subgoal.fraction = self.fraction_pred(e_l, e_r, subgoal.e_g_prime)[..., -1]     # remove unnecessary dim

        # add attention target if trained with attention supervision
        if self._hp.supervise_attn_weight > 0.0:
            fraction = subgoal.get('fraction')
            frac = fraction[:, None] if fraction is not None else None
            subgoal.match_timesteps = self.matcher.comp_timestep(left_parent.match_timesteps,
                                                                 right_parent.match_timesteps,
                                                                 frac)

        subgoal.ind = (start_ind + end_ind) / 2     # gets overwritten w/ argmax of matching at training time (in loss)
        return subgoal, left_parent, right_parent
Пример #23
0
    def forward(self, inputs, phase='train'):
        """
        forward pass at training time

        Steps, in order:

        - Optionally blanks out goal information (``non_goal_conditioned``).
        - Optionally swaps in the action sequence as ``demo_seq``
          (``train_on_action_seqs``). Both of these mutate ``inputs`` in place.
        - Runs the encoder.
        - Optionally predicts the sequence length.
        - Runs sequence prediction.
        - When ``self.prune_sequences`` is set, optionally runs the inverse model,
          the state regressor and the cost model on the pruned encoding sequences.

        :param inputs: AttrDict of batch tensors. Per the original note:
            images shape = batch x time x height x width x channel
            pad mask shape = batch x time, 1 indicates actual image 0 is padded
        :param phase: 'train' selects matched pruned sequences and enables the inverse and
            cost models; any other value selects predicted pruned sequences
        :return: AttrDict model_output
        """
        if self._hp.non_goal_conditioned:
            if 'demo_seq' in inputs:
                # zero out the frame at end_ind (the goal) of every batch element
                inputs.demo_seq[torch.arange(inputs.demo_seq.shape[0]), inputs.end_ind] = 0.0
                inputs.demo_seq_images[torch.arange(inputs.demo_seq.shape[0]), inputs.end_ind] = 0.0
            inputs.I_g = torch.zeros_like(inputs.I_g)
            if "I_g_image" in inputs:
                inputs.I_g_image = torch.zeros_like(inputs.I_g_image)
            if inputs.I_0.shape[-1] == 5: # special hack for maze
                # zero the last two channels of the start frame (and of demo_seq if present)
                inputs.I_0[..., -2:] = 0.0
                if "demo_seq" in inputs:
                    inputs.demo_seq[..., -2:] = 0.0

        # swap in actions if we want to train action sequence decoder
        # (one all-zero action step is appended along dim 1)
        if self._hp.train_on_action_seqs:
            inputs.demo_seq = torch.cat([inputs.actions, torch.zeros_like(inputs.actions[:, :1])], dim=1)

        model_output = AttrDict()
        inputs.reference_tensor = find_tensor(inputs)

        if 'start_ind' not in inputs:
            # NOTE(review): sized by hp.batch_size rather than the actual batch -- confirm they always match
            start_ind = torch.zeros(self._hp.batch_size, dtype=torch.long, device=inputs.reference_tensor.device)
        else:
            start_ind = inputs.start_ind
    
        self.run_encoder(inputs, start_ind)
    
        end_ind = inputs.end_ind if 'end_ind' in inputs else None
        if self._hp.regress_length:
            # predict total sequence length
            model_output.update(self.length_pred(inputs.enc_e_0, inputs.enc_e_g))
            if self._use_pred_length and (self._hp.length_pred_weight > 0 or end_ind is None):
                # argmax over dim 1 of a sampled length prediction (presumably one-hot over lengths)
                end_ind = torch.argmax(model_output.seq_len_pred.sample().long(), dim=1)
                if self._hp.action_conditioned_pred or self._hp.non_goal_conditioned:
                    # don't use predicted length when action conditioned (or non-goal-conditioned):
                    # fall back to the maximum sequence length
                    end_ind = torch.ones_like(end_ind) * (self._hp.max_seq_len - 1)
        # TODO clean this up. model_output.end_ind is not currently used anywhere
        model_output.end_ind = end_ind
    
        # Run the model to generate sequences
        model_output.update(self.predict_sequence(inputs, model_output, start_ind, end_ind, phase))
    
        if self.prune_sequences:
            if phase == 'train':
                inputs.model_enc_seq = self.get_matched_pruned_seqs(inputs, model_output)
            else:
                inputs.model_enc_seq = self.get_predicted_pruned_seqs(inputs, model_output)
            # pruned sequences may differ in length; pad them into one batch-first tensor
            inputs.model_enc_seq = pad_sequence(inputs.model_enc_seq, batch_first=True)
            if len(inputs.model_enc_seq.shape) == 5:
                # drop the two trailing axes by taking index 0 (presumably 1x1 spatial dims)
                inputs.model_enc_seq = inputs.model_enc_seq[..., 0, 0]
                
            if self._hp.attach_inv_mdl and phase == 'train':
                model_output.update(self.inv_mdl(inputs, full_seq=self._inv_mdl_full_seq or self._hp.train_inv_mdl_full_seq))
            if self._hp.attach_state_regressor:
                regressor_inputs = inputs.model_enc_seq
                if not self._hp.supervised_decoder:
                    # block state-regression gradients from flowing into the predictor
                    regressor_inputs = regressor_inputs.detach()
                model_output.regressed_state = batch_apply(regressor_inputs, self.state_regressor)
            if self._hp.attach_cost_mdl and self._hp.run_cost_mdl and phase == 'train':
                # There is an issue here since SVG doesn't output a latent for the first image
                # Beyond conceptual problems, this breaks if end_ind = 199
                model_output.update(self.cost_mdl(inputs))
    
        return model_output
Пример #24
0
    image_width=32,
    start_goal_confs=os.environ['GCP_DATA_DIR'] + '/nav_25rooms/start_goal_configs/raw',
)

h_config = AttrDict(base_conf.model_config)
h_config.update({
    'state_dim': 2,
    'ngf': 16,
    'max_seq_len': 200,
    'hierarchy_levels': 8,
    'nz_mid_lstm': 512,
    'n_lstm_layers': 3,
    'nz_mid': 128,
    'nz_enc': 128,
    'nz_vae': 256,
    'regress_length': True,
    'attach_state_regressor': True,
    'attach_inv_mdl': True,
    'inv_mdl_params': AttrDict(
        n_actions=2,
        use_convs=False,
        build_encoder=False,
    ),
    'untied_layers': True,
    'decoder_distribution': 'discrete_logistic_mixture',
})
h_config.pop("add_weighted_pixel_copy")

cem_params = AttrDict(
    prune_final=True,
    horizon=200,
Пример #25
0
from blox import AttrDict
from experiments.prediction.base_configs import base_tree as base_conf

configuration = AttrDict(base_conf.configuration)
configuration.metric_pruning_scheme = 'pruned_dtw'  # evaluation pruning scheme

# Base tree model with balanced matching.
model_config = AttrDict(base_conf.model_config)
model_config.update(matching_type='balanced')