def forward(self, forest, features, num_obj, labels=None, boxes_for_nms=None, batch_size=0):
        # generate a recurrent dropout mask for the tree LSTM
        if self.dropout > 0.0:
            dropout_mask = get_dropout_mask(self.dropout, self.hidden_size)
        else:
            dropout_mask = None

        # I/O container that accumulates the tree LSTM outputs (hidden states,
        # visit order, class distributions and commitments) during traversal
        out_h = None
        out_dists = None
        out_commitments = None
        h_order = Variable(torch.LongTensor(num_obj).zero_().cuda())
        order_idx = 0
        lstm_io = tree_utils.TreeLSTM_IO(out_h, h_order, order_idx, out_dists, out_commitments, dropout_mask)

        for idx in range(len(forest)):
            self.decoderLSTM(forest[idx], features, lstm_io)

        # restore the original object ordering; dists and commitments also drop
        # the trailing batch_size entries
        out_h = torch.index_select(lstm_io.hidden, 0, lstm_io.order.long())
        out_dists = torch.index_select(lstm_io.dists, 0, lstm_io.order.long())[:-batch_size]
        out_commitments = torch.index_select(lstm_io.commitments, 0, lstm_io.order.long())[:-batch_size]

        # Do NMS here as a post-processing step
        if boxes_for_nms is not None and not self.training and self.not_rl:
            is_overlap = nms_overlaps(boxes_for_nms.data).view(
                boxes_for_nms.size(0), boxes_for_nms.size(0), boxes_for_nms.size(1)
            ).cpu().numpy() >= self.nms_thresh
            # is_overlap[np.arange(boxes_for_nms.size(0)), np.arange(boxes_for_nms.size(0))] = False

            out_dists_sampled = F.softmax(out_dists, dim=1).data.cpu().numpy()
            out_dists_sampled[:, 0] = 0  # never select the background class

            out_commitments = out_commitments.data.new(out_commitments.shape[0]).fill_(0)

            for i in range(out_commitments.size(0)):
                box_ind, cls_ind = np.unravel_index(out_dists_sampled.argmax(), out_dists_sampled.shape)
                out_commitments[int(box_ind)] = int(cls_ind)
                out_dists_sampled[is_overlap[box_ind,:,cls_ind], cls_ind] = 0.0
                out_dists_sampled[box_ind] = -1.0 # This way we won't re-sample

            out_commitments = Variable(out_commitments.view(-1))
        else:
            out_commitments = out_commitments.view(-1)

        if self.training and self.not_rl and (labels is not None):
            out_commitments = labels.clone()
        else:
            # append batch_size background (0) commitments
            out_commitments = torch.cat(
                (out_commitments, Variable(torch.LongTensor(batch_size).zero_().cuda()).view(-1)), 0)
            
        return out_dists, out_commitments

    def get_max_preds(self, obj_dists, obj_labels, boxes_for_nms):
        """
        Get max non-background prediction
        :param obj_dists: [num_obj, num_classes] new probability distribution: O4
        :param obj_labels: [num_obj] the GT labels of the image
        :param boxes_for_nms: [num_obj, 4] boxes. We'll use this for NMS
        :return: obj_preds: [num_obj] argmax of that distribution: O4'
        """
        if self.training:
            # Whenever a label is 0 (background), fall back to the max non-background prediction
            obj_preds = obj_labels.clone()  # clone so the GT labels are not modified in place
            nonzero_pred = obj_dists[:, 1:].max(1)[1] + 1
            is_bg = (obj_preds.data == 0).nonzero()
            if is_bg.dim() > 0:
                obj_preds[is_bg.squeeze(1)] = nonzero_pred[is_bg.squeeze(1)]
        else:
            # Greedily take the max here amongst non-bgs
            obj_preds = obj_dists[:, 1:].max(1)[1] + 1

        # when sgdet is testing, do NMS as a post-processing step
        if boxes_for_nms is not None and not self.training:
            nms_thresh = 0.3
            is_overlap = nms_overlaps(boxes_for_nms.data).view(
                boxes_for_nms.size(0), boxes_for_nms.size(0),
                boxes_for_nms.size(1)).cpu().numpy() >= nms_thresh

            obj_preds = obj_preds[0].data.new(len(obj_preds)).fill_(0)
            out_dists_sampled = F.softmax(obj_dists, dim=1).data.cpu().numpy()
            out_dists_sampled[:, 0] = 0

            for i in range(obj_preds.size(0)):
                box_ind, cls_ind = np.unravel_index(out_dists_sampled.argmax(),
                                                    out_dists_sampled.shape)
                obj_preds[int(box_ind)] = int(cls_ind)
                out_dists_sampled[is_overlap[box_ind, :, cls_ind],
                                  cls_ind] = 0.0
                out_dists_sampled[
                    box_ind] = -1.0  # This way we won't re-sample
            obj_preds = Variable(obj_preds)

        return obj_preds
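
Both NMS branches above implement the same greedy assignment: repeatedly take the highest-scoring (box, class) cell, commit that class to that box, zero that class for every box overlapping the winner above the threshold, and mark the winner so it is never picked again. Below is a minimal standalone sketch of that loop in NumPy; the function name greedy_nms_assign and the precomputed is_overlap argument are illustrative, not part of the original code.

import numpy as np

def greedy_nms_assign(probs, is_overlap):
    """Greedy per-box class assignment with NMS suppression.

    probs:      [num_box, num_class] softmax scores; column 0 is background.
    is_overlap: [num_box, num_box, num_class] bool, True where two boxes
                overlap above the NMS threshold for that class.
    Returns a [num_box] array of assigned class indices.
    """
    probs = probs.copy()
    probs[:, 0] = 0.0                    # never assign the background class
    preds = np.zeros(probs.shape[0], dtype=np.int64)
    for _ in range(probs.shape[0]):
        # highest remaining (box, class) score wins this round
        box_ind, cls_ind = np.unravel_index(probs.argmax(), probs.shape)
        preds[box_ind] = cls_ind
        # suppress that class for every box overlapping the winner ...
        probs[is_overlap[box_ind, :, cls_ind], cls_ind] = 0.0
        # ... and make sure the winner itself is never selected again
        probs[box_ind] = -1.0
    return preds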
Example #3
    def forward(
            self,  # pylint: disable=arguments-differ
            inputs: PackedSequence,
            initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
            labels=None,
            boxes_for_nms=None):
        """
        Parameters
        ----------
        inputs : PackedSequence, required.
            A PackedSequence (packed from a tensor of shape
            (batch_size, num_timesteps, input_size)) to apply the LSTM over.

        initial_state : Tuple[torch.Tensor, torch.Tensor], optional, (default = None)
            A tuple (state, memory) representing the initial hidden state and memory
            of the LSTM. Each tensor has shape (1, batch_size, output_dimension).

        Returns
        -------
        out_dists : [num_obj, num_classes] predicted class distribution for
            every object, concatenated over timesteps.
        out_commitments : [num_obj] the class each object is committed to
            (ground-truth-based during training, greedy/NMS-filtered at test time).
        """
        if not isinstance(inputs, PackedSequence):
            raise ValueError('inputs must be a PackedSequence but got %s' %
                             type(inputs))

        sequence_tensor, batch_lengths = inputs
        batch_size = batch_lengths[0]  # number of sequences active at the first timestep

        # We're just doing an LSTM decoder here so ignore states, etc
        if initial_state is None:
            previous_memory = Variable(sequence_tensor.data.new().resize_(
                batch_size, self.hidden_size).fill_(0))
            previous_state = Variable(sequence_tensor.data.new().resize_(
                batch_size, self.hidden_size).fill_(0))
        else:
            assert len(initial_state) == 2
            previous_state = initial_state[0].squeeze(0)
            previous_memory = initial_state[1].squeeze(0)

        # start decoding from the class-0 embedding (the embedding dim, 100, is hard-coded)
        previous_embed = self.obj_embed.weight[0, None].expand(batch_size, 100)

        if self.recurrent_dropout_probability > 0.0:
            dropout_mask = get_dropout_mask(self.recurrent_dropout_probability,
                                            previous_memory)
        else:
            dropout_mask = None

        # Only accumulating label predictions here, discarding everything else
        out_dists = []
        out_commitments = []

        end_ind = 0
        for i, l_batch in enumerate(batch_lengths):
            start_ind = end_ind
            end_ind = end_ind + l_batch

            if previous_memory.size(0) != l_batch:
                previous_memory = previous_memory[:l_batch]
                previous_state = previous_state[:l_batch]
                previous_embed = previous_embed[:l_batch]
                if dropout_mask is not None:
                    dropout_mask = dropout_mask[:l_batch]

            timestep_input = torch.cat(
                (sequence_tensor[start_ind:end_ind], previous_embed), 1)

            previous_state, previous_memory = self.lstm_equations(
                timestep_input,
                previous_state,
                previous_memory,
                dropout_mask=dropout_mask)

            pred_dist = self.out(previous_state)
            out_dists.append(pred_dist)

            if self.training:
                labels_to_embed = labels[start_ind:end_ind].clone()
                # Whenever labels are 0 set input to be our max prediction
                nonzero_pred = pred_dist[:, 1:].max(
                    1
                )[1] + 1  # +1: because the index is in 150-d but truth is 151-d
                is_bg = (labels_to_embed.data == 0).nonzero()
                if is_bg.dim() > 0:
                    # a label of 0 means the box overlaps every GT box below the
                    # assignment threshold, so it was marked background; for these
                    # entries feed the highest-scoring non-background class instead
                    labels_to_embed[is_bg.squeeze(1)] = nonzero_pred[
                        is_bg.squeeze(1)]
                out_commitments.append(labels_to_embed)
                previous_embed = self.obj_embed(labels_to_embed + 1)
            else:
                assert l_batch == 1
                out_dist_sample = F.softmax(pred_dist, dim=1)
                # if boxes_for_nms is not None:
                #     out_dist_sample[domains_allowed[i] == 0] = 0.0

                # Greedily take the max here amongst non-bgs
                best_ind = out_dist_sample[:, 1:].max(1)[1] + 1

                # if boxes_for_nms is not None and i < boxes_for_nms.size(0):
                #     best_int = int(best_ind.data[0])
                #     domains_allowed[i:, best_int] *= (1 - is_overlap[i, i:, best_int])
                out_commitments.append(best_ind)
                previous_embed = self.obj_embed(best_ind + 1)

        # Do NMS here as a post-processing step
        if boxes_for_nms is not None and not self.training:
            is_overlap = nms_overlaps(boxes_for_nms.data).view(
                boxes_for_nms.size(0), boxes_for_nms.size(0),
                boxes_for_nms.size(1)).cpu().numpy() >= self.nms_thresh
            # is_overlap[np.arange(boxes_for_nms.size(0)), np.arange(boxes_for_nms.size(0))] = False

            out_dists_sampled = F.softmax(torch.cat(out_dists, 0),
                                          dim=1).data.cpu().numpy()
            out_dists_sampled[:, 0] = 0  # never select the background class

            out_commitments = out_commitments[0].data.new(
                len(out_commitments)).fill_(0)

            for i in range(out_commitments.size(0)):
                box_ind, cls_ind = np.unravel_index(out_dists_sampled.argmax(),
                                                    out_dists_sampled.shape)
                out_commitments[int(box_ind)] = int(cls_ind)
                out_dists_sampled[is_overlap[box_ind, :, cls_ind],
                                  cls_ind] = 0.0
                out_dists_sampled[
                    box_ind] = -1.0  # This way we won't re-sample

            out_commitments = Variable(out_commitments)
        else:
            out_commitments = torch.cat(out_commitments, 0)

        return torch.cat(out_dists, 0), out_commitments
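
The per-timestep loop in the forward above relies on the PackedSequence layout: the packed data concatenates timestep 0 of every sequence, then timestep 1 of the still-active sequences, and so on, while batch_lengths (PyTorch's batch_sizes) gives the number of sequences still active at each timestep. A small sketch of how such an input might be built with standard PyTorch; the shapes and lengths are made up for illustration.

import torch
from torch.nn.utils.rnn import pack_padded_sequence

# three sequences of lengths 4, 2 and 1, feature size 8, sorted longest first
padded = torch.randn(3, 4, 8)
lengths = [4, 2, 1]
packed = pack_padded_sequence(padded, lengths, batch_first=True)

# packed.data has shape [4 + 2 + 1, 8] = [7, 8]; packed.batch_sizes is
# [3, 2, 1, 1]: three sequences are active at timestep 0, two at timestep 1,
# one at timesteps 2 and 3.  forward() walks this layout directly, shrinking
# previous_state / previous_memory / previous_embed as sequences finish.
print(packed.data.shape, packed.batch_sizes.tolist())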
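
get_dropout_mask is not defined in this snippet. The recurrent ("variational") dropout it is used for samples one Bernoulli mask per sequence, scales the kept units by 1/(1 - p), and reuses the same mask at every timestep (note the two call sites above pass different second arguments: a hidden size in the first forward, a memory tensor in the second). A minimal sketch of the tensor-shaped variant, under those assumptions; the real helper may differ.

import torch
from torch.autograd import Variable

def get_dropout_mask(dropout_probability, tensor_for_masking):
    """Sample a dropout mask once so it can be reused across timesteps."""
    keep_prob = 1.0 - dropout_probability
    # Bernoulli(keep_prob) mask on the same device/type as the masked tensor.
    binary_mask = tensor_for_masking.data.new(tensor_for_masking.size()).bernoulli_(keep_prob)
    # Scale kept units by 1/keep_prob so the expected activation is unchanged.
    return Variable(binary_mask / keep_prob)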