Beispiel #1
0
    def loss(self,
             inputs,
             outputs,
             extra_action=True,
             first_image=True,
             log_error_arr=False):
        """Compute the trajectory and (optionally) dense action reconstruction losses.

        Args:
            inputs: batch exposing ``traj_seq`` and ``pad_mask``; ``actions`` is
                also read when ``self._hp.regress_actions`` is set.
            outputs: model outputs exposing ``distr``; ``actions`` is also read
                when ``self._hp.regress_actions`` is set.
            extra_action: if True, the last entry along dim 1 of
                ``outputs.actions`` is dropped before comparing with
                ``inputs.actions``.
            first_image: if False, the first time step (dim 1) of ``traj_seq``
                and ``pad_mask`` is removed before the loss is computed.
            log_error_arr: forwarded to ``self.nll`` and to the action ``NLL``.

        Returns:
            The loss container produced by ``self.nll``; when actions are
            regressed it additionally carries ``dense_action_rec``.
        """
        target_seq, target_mask = inputs.traj_seq, inputs.pad_mask
        if not first_image:
            target_seq, target_mask = target_seq[:, 1:], target_mask[:, 1:]

        # Mask is broadcast against the full trajectory tensor (rank match).
        frame_weights = broadcast_final(target_mask, inputs.traj_seq)
        # Skip the first (remaining) frame in the reconstruction loss.
        losses = self.nll(outputs.distr, target_seq[:, 1:], frame_weights[:, 1:],
                          log_error_arr)

        if not self._hp.regress_actions:
            return losses

        predicted_actions = outputs.actions[:, :-1] if extra_action else outputs.actions
        # Action weights use the pad mask without its last time step.
        action_weights = broadcast_final(inputs.pad_mask[:, :-1], inputs.actions)
        action_nll = NLL(self._hp.dense_action_rec_weight)
        losses.dense_action_rec = action_nll(predicted_actions,
                                             inputs.actions,
                                             weights=action_weights,
                                             reduction=[-1, -2],
                                             log_error_arr=log_error_arr)
        return losses
Beispiel #2
0
    def attention(self,
                  q,
                  k,
                  v,
                  nz_k,
                  start_ind,
                  end_ind,
                  dropout=None,
                  forced_attention_step=None,
                  attention_weights=None):
        """Scaled dot-product attention, normalised over dimension 1 of the keys.

        Args:
            q: queries; a singleton dim is inserted at position 1 so they
                broadcast against ``k``.
            k: keys; the query-key product is summed over every dim >= 3.
            v: values; weighted by the scores and summed over dim 1.
            nz_k: the raw scores are divided by ``sqrt(nz_k)``.
            start_ind, end_ind: forwarded to ``MultiheadAttention.mask_out``.
            dropout: optional dropout module, applied only if ``dropout.p > 0``.
            forced_attention_step: if given, the softmax scores are replaced by
                zeros with 1.0 written at ``forced_attention_step[:, 0]`` via
                ``batchwise_assign``.
            attention_weights: if given, replaces the scores (after any forced
                step), repeated along dim 2 to the scores' size on that dim.

        Returns:
            Tuple ``(attended_values, scores)``.
        """
        # Query/key dot product: multiply and reduce over dims 3 .. k.dim()-1.
        reduce_dims = list(range(3, k.dim()))
        logits = (q[:, None] * k).sum(dim=reduce_dims)
        logits = logits / math.sqrt(nz_k) * self.temperature
        logits = MultiheadAttention.mask_out(logits, start_ind, end_ind)
        attn = F.softmax(logits, dim=1)

        if forced_attention_step is not None:
            attn = torch.zeros_like(attn)
            batchwise_assign(attn, forced_attention_step[:, 0].long(), 1.0)

        if attention_weights is not None:
            n_repeat = attn.shape[2]
            attn = attention_weights[..., None].repeat_interleave(n_repeat, 2)

        if dropout is not None and dropout.p > 0.0:
            attn = dropout(attn)

        attended = (broadcast_final(attn, v) * v).sum(dim=1)
        return attended, attn
Beispiel #3
0
 def forward(self, input, actions):
     """Run ``self.net`` on ``input`` and apply ``self.ac_net`` to its output
     joined with the (time-aligned) actions.

     ``actions`` is padded at the end of its second-to-last dimension by
     ``net_outputs.shape[1] - actions.shape[1]`` entries (the time axis if
     actions is (batch, time, action_dim) -- TODO confirm), expanded with
     ``broadcast_final`` to the rank of ``input``, concatenated with the
     network output along dim 2, and passed through ``batch_apply``.
     """
     features = self.net(input)
     n_missing = features.shape[1] - actions.shape[1]
     aligned_actions = torch.nn.functional.pad(actions, (0, 0, 0, n_missing, 0, 0))
     # TODO quite sure the concatenation is automatic
     joint = torch.cat([features, broadcast_final(aligned_actions, input)], dim=2)
     return batch_apply(self.ac_net, joint)
Beispiel #4
0
    def loss(self, inputs, outputs, log_error_arr=False):
        """Decoder reconstruction losses plus a KL term between ``q_z`` and ``p_z``.

        Args:
            inputs: batch exposing ``pad_mask`` (plus whatever
                ``self.decoder.loss`` reads).
            outputs: model outputs exposing ``q_z`` and ``p_z``; ``p_z.mu``
                determines the rank the KL weights are broadcast to.
            log_error_arr: forwarded to the decoder loss and to the KL loss.

        Returns:
            The decoder's loss container with an added ``kl`` entry.
        """
        losses = self.decoder.loss(inputs, outputs, extra_action=False,
                                   log_error_arr=log_error_arr)

        # KL weights use the pad mask without its first time step.
        # TODO don't place loss on the final image
        kl_weights = broadcast_final(inputs.pad_mask[:, 1:], outputs.p_z.mu)
        kl_loss_fn = KLDivLoss2(self._hp.kl_weight,
                                breakdown=1,
                                free_nats_per_dim=self._hp.free_nats)
        losses.kl = kl_loss_fn(outputs.q_z, outputs.p_z,
                               weights=kl_weights, log_error_arr=log_error_arr)
        return losses
Beispiel #5
0
    def loss(self, inputs, model_output):
        """Compute the action-prediction loss.

        Only the first ``n_actions`` target actions and pad-mask entries are
        used, where ``n_actions`` is ``model_output.actions.shape[1]``.

        Args:
            inputs: batch exposing ``actions`` and ``pad_mask``.
            model_output: model outputs exposing predicted ``actions``.

        Returns:
            AttrDict with a single ``action_reconst`` entry (L2 loss, weight 1.0).
        """
        losses = AttrDict()

        # Action prediction loss, restricted to the predicted horizon.
        n_actions = model_output.actions.shape[1]
        weights = broadcast_final(inputs.pad_mask[:, :n_actions], inputs.actions)
        losses.action_reconst = L2Loss(1.0)(model_output.actions,
                                            inputs.actions[:, :n_actions],
                                            weights=weights)
        return losses
Beispiel #6
0
    def _get_lstm_inputs(self, root, inputs):
        """Assemble per-time-step LSTM inputs from the leaf segments of ``root``.

        Every leaf (``depth == 0``) of ``root.full_tree()`` covers the integer
        time steps ``ceil(start_ind) .. floor(end_ind)`` (inclusive) of each
        batch element. Those steps receive the leaf's ``e_0``, ``e_g``,
        ``start_ind`` and ``end_ind``. Later leaves overwrite earlier ones.

        Args:
            root: tree root; its leaves expose ``start_ind``, ``end_ind``,
                ``e_0`` and ``e_g``, indexable by batch element.
            inputs: exposes ``reference_tensor`` (supplies the device) and
                ``enc_e_0`` (its shape after dim 0 is the per-step feature shape).

        Returns:
            AttrDict with
              ``reset_indicator``: uint8 (batch, time) tensor, set to 1 at each
                  leaf's ceiled start index via ``batchwise_assign``;
              ``cell_input``: per-step linear interpolation from ``e_0`` to
                  ``e_g`` by the step's relative position between the covering
                  leaf's start and end;
              ``reset_input``: ``e_g`` and ``e_0`` concatenated along dim 2.
        """
        device = inputs.reference_tensor.device
        n_batch = self._hp.batch_size
        n_steps = self._hp.max_seq_len
        seq_shape = [n_batch, n_steps] + list(inputs.enc_e_0.shape[1:])
        lstm_inputs = AttrDict()

        # Per-step buffers, filled leaf by leaf below.
        e_0s = torch.zeros(seq_shape, dtype=torch.float32, device=device)
        e_gs = torch.zeros(seq_shape, dtype=torch.float32, device=device)
        start_inds = torch.zeros((n_batch, n_steps), dtype=torch.float32, device=device)
        end_inds = torch.zeros((n_batch, n_steps), dtype=torch.float32, device=device)
        reset_indicator = torch.zeros((n_batch, n_steps), dtype=torch.uint8, device=device)

        # Original note: full_tree() traverses the tree in breadth-first order.
        for leaf in root.full_tree():
            if leaf.depth != 0:  # only leaf nodes contribute
                continue
            # First/last integer steps lying inside [start_ind, end_ind].
            first = torch.ceil(leaf.start_ind).type(torch.LongTensor)
            last = torch.floor(leaf.end_ind).type(torch.LongTensor)
            batchwise_assign(reset_indicator, first, 1)

            # TODO iterating over batch must be gone
            for b in range(n_batch):
                if first[b] > last[b]:
                    # The two float bounds enclose no integer step.
                    continue
                steps = slice(first[b], last[b] + 1)  # +1 includes the `last` frame
                e_0s[b, steps] = leaf.e_0[b]
                e_gs[b, steps] = leaf.e_g[b]
                start_inds[b, steps] = leaf.start_ind[b]
                end_inds[b, steps] = leaf.end_ind[b]

        # Linear interpolation; 1e-7 avoids division by zero when end == start.
        step_ids = torch.arange(n_steps, dtype=torch.float, device=device)
        frac = (step_ids - start_inds) / (end_inds - start_inds + 1e-7)

        lstm_inputs.reset_indicator = reset_indicator
        lstm_inputs.cell_input = (e_gs - e_0s) * broadcast_final(frac, e_gs) + e_0s
        lstm_inputs.reset_input = torch.cat([e_gs, e_0s], dim=2)
        return lstm_inputs
Beispiel #7
0
    def loss(self, inputs, outputs, add_total=True):
        """Compute action (and optionally state) reconstruction losses.

        Args:
            inputs: batch; only ``inputs.actions`` is read, and only to give
                the pad-mask weights the matching rank.
            outputs: model outputs exposing ``actions`` and ``action_targets``.
                ``pad_mask`` is optional; without it the action weights are 1.
                ``states`` and ``state_targets`` are read when
                ``self._hp.pred_states`` is set.
            add_total: accepted for interface compatibility; unused here.

        Returns:
            AttrDict with ``action_reconst`` and, if ``self._hp.pred_states``,
            ``state_reconst``.
        """
        losses = AttrDict()

        # Action reconstruction loss over the predicted horizon only.
        horizon = outputs.actions.shape[1]
        if 'pad_mask' in outputs:
            action_weights = broadcast_final(outputs.pad_mask[:, :horizon], inputs.actions)
        else:
            action_weights = 1
        action_loss_fn = L2Loss(self._hp.action_rec_weight)
        losses.action_reconst = action_loss_fn(outputs.actions,
                                               outputs.action_targets[:, :horizon],
                                               weights=action_weights)

        if self._hp.pred_states:
            state_loss_fn = L2Loss(self._hp.state_rec_weight)
            losses.state_reconst = state_loss_fn(outputs.states, outputs.state_targets)

        return losses
Beispiel #8
0
 def mask_extra(seq):
     """Multiply ``seq`` by the mask ``get_pad_mask(model_output.end_ind, N)``.

     ``model_output`` and ``N`` are free names resolved outside this function.
     The mask is expanded to ``seq``'s rank with ``broadcast_final``.
     Presumably this zeroes steps past each sequence's end; confirm against
     ``get_pad_mask``.
     """
     valid = get_pad_mask(model_output.end_ind, N)
     return seq * broadcast_final(valid, seq)
Beispiel #9
0
 def mean(self):
     """Weighted bit sum of ``self.p`` along dim 1, rescaled to [-1, 1].

     Dim 1 of ``self.p`` is weighted by the bit values ``[128, 64, ..., 1]``
     (most significant first) and summed. With entries of ``p`` in [0, 1]
     (presumably per-bit probabilities; TODO confirm) the sum lies in
     [0, 255], and ``x / 127.5 - 1`` maps that range onto [-1, 1].
     """
     p = self.p
     # Build the float32 bit weights directly on p's device. This avoids a
     # CPU allocation, a host-to-device copy and two casts on every call.
     bit_values = torch.tensor([128, 64, 32, 16, 8, 4, 2, 1],
                               dtype=torch.float32, device=p.device)
     value = broadcast_final(bit_values[None], p)
     return (p * value).sum(1) / 127.5 - 1
Beispiel #10
0
 def mean(self):
     """Expected level under ``self.p`` over 256 levels, rescaled to [-1, 1].

     Dim 1 of ``self.p`` is weighted by the levels ``0 .. 255`` and summed.
     With ``p`` a distribution over those levels (presumably; TODO confirm)
     the sum lies in [0, 255], and ``x / 127.5 - 1`` maps it onto [-1, 1].
     """
     p = self.p
     # Create the float32 level grid directly on p's device. This avoids a
     # CPU allocation, a host-to-device copy and two casts on every call.
     levels = torch.arange(256, dtype=torch.float32, device=p.device)
     value = broadcast_final(levels[None], p)
     return (p * value).sum(1) / 127.5 - 1