Example #1
    def forward(self, inputs, phase='train'):
        """Forward pass at training time."""
        if 'enc_traj_seq' not in inputs:
            # encode either just the first frame or the full observation sequence
            enc_traj_seq, _ = self.encoder(inputs.traj_seq[:, 0]) if self._hp.train_first_action_only \
                else batch_apply(self.encoder, inputs.traj_seq)
            if self._hp.train_first_action_only:
                enc_traj_seq = enc_traj_seq[:, None]    # re-add the time dimension
            # optionally stop gradients from flowing back into the encoder
            enc_traj_seq = enc_traj_seq.detach() if self.detach_enc else enc_traj_seq
        else:
            enc_traj_seq = inputs.enc_traj_seq    # reuse precomputed encodings

        # encode the goal image and broadcast it along the time dimension before fusing
        enc_goal, _ = self.encoder(inputs.I_g)
        n_dim = len(enc_goal.shape)
        fused_enc = torch.cat((enc_traj_seq,
                               enc_goal[:, None].repeat(1, enc_traj_seq.shape[1], *([1] * (n_dim - 1)))),
                              dim=2)

        if self._hp.reactive:
            # reactive policy: map each fused encoding to an action independently
            actions_pred = batch_apply(self.policy, fused_enc)
        else:
            # recurrent policy: process the fused sequence as a whole
            actions_pred = self.policy(fused_enc)

        # remove the last time step to match the ground truth when training on the full sequence
        actions_pred = actions_pred if self._hp.train_first_action_only else actions_pred[:, :-1]

        output = AttrDict()
        # collapse spatial dimensions if the policy output still carries them
        output.actions = remove_spatial(actions_pred) if len(actions_pred.shape) > 3 else actions_pred
        return output
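
The goal-broadcast fusion above is the core trick: the single goal encoding is repeated across the time dimension and concatenated onto every per-step encoding. A minimal self-contained sketch of just that step, with all shapes chosen as illustrative assumptions:

    import torch

    enc_traj_seq = torch.randn(4, 10, 128)    # [batch, time, nz_enc]
    enc_goal = torch.randn(4, 128)            # [batch, nz_enc]
    n_dim = len(enc_goal.shape)
    fused_enc = torch.cat((enc_traj_seq,
                           enc_goal[:, None].repeat(1, enc_traj_seq.shape[1], *([1] * (n_dim - 1)))),
                          dim=2)
    print(fused_enc.shape)    # torch.Size([4, 10, 256])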
Example #2
    def full_seq_forward(self, inputs):
        """Predicts actions from pairs of consecutive frame encodings over the full sequence."""
        if 'model_enc_seq' in inputs:
            enc_seq_1 = inputs.model_enc_seq[:, 1:]    # encodings of the successor frames
            if self._hp.train_im0_enc and 'enc_traj_seq' in inputs:
                # use the precomputed image encodings for the predecessor frames
                enc_seq_0 = inputs.enc_traj_seq.reshape(inputs.enc_traj_seq.shape[:2] + (self._hp.nz_enc,))[:, :-1]
                enc_seq_0 = enc_seq_0[:, :enc_seq_1.shape[1]]    # truncate to the successor length
            else:
                enc_seq_0 = inputs.model_enc_seq[:, :-1]
        else:
            # no precomputed encodings: encode the raw trajectory and pair consecutive frames
            enc_seq = batch_apply(self.encoder, inputs.traj_seq)
            enc_seq_0, enc_seq_1 = enc_seq[:, :-1], enc_seq[:, 1:]

        if self.detach_enc:
            # stop gradients from flowing back into the encoder
            enc_seq_0 = enc_seq_0.detach()
            enc_seq_1 = enc_seq_1.detach()

        # predict the action between each pair of consecutive encodings
        actions_pred = batch_apply(self.action_pred,
                                   torch.cat([enc_seq_0, enc_seq_1], dim=2))

        output = AttrDict()
        output.actions = actions_pred
        if 'actions' in inputs:
            # attach ground-truth targets and padding mask for loss computation
            output.action_targets = inputs.actions
            output.pad_mask = inputs.pad_mask
        return output
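
This is an inverse-model setup: the action between steps t and t+1 is regressed from the concatenated encodings of both frames. Below is a self-contained sketch of that pairing, including a hypothetical stand-in for batch_apply (which is not defined in these excerpts) that folds time into the batch dimension; module sizes and shapes are assumptions:

    import torch
    import torch.nn as nn

    def batch_apply_sketch(fn, x):
        """Hypothetical stand-in for batch_apply: folds the time dimension
        into the batch dimension, applies fn, and unfolds again."""
        b, t = x.shape[:2]
        out = fn(x.reshape(b * t, *x.shape[2:]))
        return out.reshape(b, t, *out.shape[1:])

    nz_enc, action_dim = 32, 4
    enc_seq = torch.randn(8, 10, nz_enc)                     # stand-in for encoder outputs
    enc_seq_0, enc_seq_1 = enc_seq[:, :-1], enc_seq[:, 1:]   # (frame_t, frame_t+1) pairs
    action_pred = nn.Linear(2 * nz_enc, action_dim)          # stand-in for self.action_pred
    actions = batch_apply_sketch(action_pred, torch.cat([enc_seq_0, enc_seq_1], dim=2))
    print(actions.shape)                                     # torch.Size([8, 9, 4])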
Example #3
    def apply_fn(self, inputs, fn, left_parents, right_parents):
        """ Recursively applies fn to every node of the tree.

        :param inputs: batched inputs, passed unchanged to every node
        :param fn: a function that takes in (inputs, subgoal, left_parent, right_parent) and outputs a dict
        :param left_parents: nodes bounding each subgoal on the left
        :param right_parents: nodes bounding each subgoal on the right
        :return:
        """
        if self.depth == 0:
            return
        assert self.subgoals is not None      # need subgoal info to match to ground truth sequence

        # update this layer's subgoals, then recurse: each child interval is bounded by
        # one original parent and this layer's subgoal
        self.subgoals.update(batch_apply(fn, inputs, self.subgoals, left_parents, right_parents, unshape_inputs=True))
        self.child_layer.apply_fn(rec_interleave([inputs, inputs]),
                                  fn,
                                  rec_interleave([left_parents, self.subgoals]),
                                  rec_interleave([self.subgoals, right_parents]))
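
rec_interleave is not defined in these excerpts; from the call pattern, a plausible reading is that it interleaves two node-batched structures so each parent node expands into its two children. A hypothetical stand-in for plain tensors, written purely as an assumption:

    import torch

    def interleave(a, b, dim=1):
        """Hypothetical stand-in for rec_interleave on plain tensors:
        alternates nodes as a[0], b[0], a[1], b[1], ... along dim."""
        stacked = torch.stack([a, b], dim=dim + 1)
        shape = list(a.shape)
        shape[dim] = shape[dim] * 2
        return stacked.reshape(shape)

    left = torch.tensor([[1, 2]])      # [batch=1, nodes=2]
    right = torch.tensor([[3, 4]])
    print(interleave(left, right))     # tensor([[1, 3, 2, 4]])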
Example #4
    def produce_tree(self, inputs, layerwise_inputs, start_inds, end_inds, left_parents, right_parents, producer, depth):
        """No done-mask checks; assumes start_ind is never None.
            All input tensors are of shape [batch, num_parent_nodes, ...].
        """
        self.depth = depth
        if depth == 0:
            return

        # slice out the inputs that belong to this layer of the tree
        layer_inputs = rmap(lambda x: depthfirst2layers(reduce_dim(x, dim=1))[-depth].contiguous(), layerwise_inputs)

        # predict one subgoal per parent interval (start_ind, end_ind)
        out = batch_apply(lambda x: producer.produce_subgoal(inputs, *x, depth=depth),
                          [layer_inputs, start_inds.float(), end_inds.float(), left_parents, right_parents])
        self.subgoals, left_parents, right_parents = out

        # recurse: each subgoal splits its parent interval into a left and a right child interval
        self.child_layer = SubgoalTreeLayer(self)
        self.child_layer.produce_tree(inputs,
                                      layerwise_inputs,
                                      rec_interleave([start_inds.float(), self.subgoals.ind.clone()]),
                                      rec_interleave([self.subgoals.ind.clone(), end_inds.float()]),
                                      rec_interleave([left_parents, self.subgoals]),
                                      rec_interleave([self.subgoals, right_parents]),
                                      producer, depth - 1)
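
The recursion builds a binary tree of subgoals by repeatedly bisecting (start, end) index intervals until depth reaches zero. A stripped-down, self-contained sketch of that control flow, with the learned subgoal prediction replaced by a simple midpoint purely for illustration:

    def produce_tree_sketch(start, end, depth):
        """Hypothetical illustration: each node picks a subgoal index inside
        (start, end), then recurses into the left and right sub-intervals."""
        if depth == 0:
            return None
        subgoal = (start + end) / 2.0          # stand-in for producer.produce_subgoal
        return {'ind': subgoal,
                'left': produce_tree_sketch(start, subgoal, depth - 1),
                'right': produce_tree_sketch(subgoal, end, depth - 1)}

    tree = produce_tree_sketch(0.0, 16.0, depth=3)
    print(tree['ind'], tree['left']['ind'], tree['right']['ind'])    # 8.0 4.0 12.0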