Example #1
    def execute(self, in_feat):
        # Encode the input, then flatten to (batch, -1)
        z_img = self.model(in_feat)
        z = jt.reshape(z_img, [z_img.shape[0], -1])
        # Split the code into a continuous part and categorical logits
        zn = z[:, 0:self.latent_dim]
        zc_logits = z[:, self.latent_dim:]
        zc = nn.softmax(zc_logits, dim=1)  # class probabilities
        return zn, zc, zc_logits
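
All of these examples use Jittor's nn.softmax. As a quick reference, here is a minimal sketch (made-up values) of the normalization Example #1 applies to its categorical head zc_logits:

import jittor as jt
from jittor import nn

logits = jt.array([[2.0, 1.0, 0.1]])  # hypothetical zc_logits for one sample
probs = nn.softmax(logits, dim=1)     # each row becomes a probability distribution
print(probs.sum(dim=1))               # -> [1.]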
Example #2
    def ohem_conf_loss(self, conf_data, conf_t, pos, num):
        # Compute max conf across batch for hard negative mining
        batch_conf = conf_data.view(-1, self.num_classes)
        if cfg.ohem_use_most_confident:
            # i.e. max(softmax) along classes > 0
            batch_conf = nn.softmax(batch_conf, dim=1)
            loss_c = batch_conf[:, 1:].max(dim=1)
        else:
            # i.e. -softmax(class 0 confidence)
            loss_c = log_sum_exp(batch_conf) - batch_conf[:, 0]

        # Hard Negative Mining
        loss_c = loss_c.view(num, -1)
        loss_c[pos] = 0  # filter out pos boxes
        loss_c[conf_t < 0] = 0  # filter out neutrals (conf_t = -1)
        loss_idx, _ = loss_c.argsort(1, descending=True)
        idx_rank, _ = loss_idx.argsort(1)
        num_pos = pos.int32().sum(1, keepdims=True)
        num_neg = jt.clamp(self.negpos_ratio * num_pos, max_v=pos.shape[1] - 1)
        neg = idx_rank < num_neg.expand_as(idx_rank)
        neg = neg.int()

        # Just in case there aren't enough negatives, don't start using positives as negatives
        neg[pos] = 0
        neg[conf_t < 0] = 0  # Filter out neutrals
        neg = neg.bool()
        # Confidence Loss Including Positive and Negative Examples
        pos_idx = pos.unsqueeze(2).expand_as(conf_data)
        neg_idx = neg.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[(pos_idx.int() + neg_idx.int()) > 0].view(
            -1, self.num_classes)
        targets_weighted = conf_t[(pos.int() + neg.int()) > 0]
        loss_c = cross_entropy_loss(conf_p, targets_weighted, reduction='none')

        if cfg.use_class_balanced_conf:
            # Lazy initialization
            if self.class_instances is None:
                # Jittor manages device placement globally, so no device argument
                self.class_instances = jt.zeros(self.num_classes)

            classes, counts = targets_weighted.unique(return_counts=True)

            for _cls, _cnt in zip(classes.numpy(), counts.numpy()):
                self.class_instances[_cls] += _cnt

            self.total_instances += targets_weighted.shape[0]

            weighting = 1 - (self.class_instances[targets_weighted] /
                             self.total_instances)
            weighting = jt.clamp(weighting, min_v=1 / self.num_classes)

            # If you do the math, the average weight of self.class_instances is this
            avg_weight = (self.num_classes - 1) / self.num_classes

            loss_c = (loss_c * weighting).sum() / avg_weight
        else:
            loss_c = loss_c.sum()

        return cfg.conf_alpha * loss_c
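
The two stacked argsort calls above are the standard OHEM rank trick: sorting the sort indices turns them into per-element ranks, so idx_rank < num_neg keeps exactly the num_neg highest-loss negatives per image. A minimal sketch with made-up losses (Jittor's argsort returns an (index, value) pair, which the code above relies on):

import jittor as jt

loss = jt.array([[0.1, 0.9, 0.4, 0.7]])    # hypothetical per-anchor losses
idx, _ = loss.argsort(1, descending=True)  # anchor positions sorted by loss
rank, _ = idx.argsort(1)                   # rank of each anchor by its loss
print(rank)                                # [[3, 0, 2, 1]]
print(rank < 2)                            # keeps the two hardest anchors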
Example #3
    def execute(self, x, boxes):
        """
        Arguments:
            x (tuple[tensor, tensor]): x contains the class logits
                and the box_regression from the model.
            boxes (list[BoxList]): bounding boxes that are used as
                reference, one for each image

        Returns:
            results (list[BoxList]): one BoxList for each image, containing
                the extra fields labels and scores
        """
        class_logits, box_regression = x
        class_prob = nn.softmax(class_logits, -1)

        # TODO think about a representation of batch of boxes
        image_shapes = [box.size for box in boxes]
        boxes_per_image = [len(box) for box in boxes]
        concat_boxes = jt.contrib.concat([a.bbox for a in boxes], dim=0)

        if self.cls_agnostic_bbox_reg:
            box_regression = box_regression[:, -4:]
        proposals = self.box_coder.decode(
            box_regression.reshape(sum(boxes_per_image), -1), concat_boxes
        )
        if self.cls_agnostic_bbox_reg:
            proposals = proposals.repeat(1, class_prob.shape[1])

        num_classes = class_prob.shape[1]

        proposals = proposals.split(boxes_per_image, dim=0)
        class_prob = class_prob.split(boxes_per_image, dim=0)
        
        results = []
        for prob, boxes_per_img, image_shape in zip(
            class_prob, proposals, image_shapes
        ):
            boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape)
            boxlist = boxlist.clip_to_image(remove_empty=False)

            # print("boxlist",boxlist.bbox.mean(),boxlist.bbox.shape)

            if not self.bbox_aug_enabled:  # If bbox aug is enabled, we will do it later
                boxlist = self.filter_results(boxlist, num_classes)
                # boxlist = self.filter_results_v2(boxlist, num_classes)
                # boxlist = self.select_over_all_levels(boxlist)
            results.append(boxlist)
        return results
Example #4
    def predict(self, images, score_thresh=0.7, nms_thresh=0.3):
        N = images.shape[0]
        img_size = (images.shape[-1], images.shape[-2])
        rpn_locs, rpn_scores, roi_cls_locs, roi_scores, rois, roi_indices = self.execute(images)
        roi_cls_locs = roi_cls_locs.reshape(roi_cls_locs.shape[0], -1, 4)
        probs = nn.softmax(roi_scores, dim=-1)
        rois = rois.unsqueeze(1).repeat(1, self.n_class, 1)
        cls_bbox = loc2bbox(rois.reshape(-1, 4), roi_cls_locs.reshape(-1, 4))
        # Clip the decoded boxes to the image boundary
        cls_bbox[:, 0::2] = jt.clamp(cls_bbox[:, 0::2], min_v=0, max_v=img_size[0])
        cls_bbox[:, 1::2] = jt.clamp(cls_bbox[:, 1::2], min_v=0, max_v=img_size[1])

        cls_bbox = cls_bbox.reshape(roi_cls_locs.shape)

        results = []
        for i in range(N):
            index = jt.where(roi_indices == i)[0]
            score = probs[index, :]
            bbox = cls_bbox[index, :, :]
            boxes = []
            scores = []
            labels = []
            # Class 0 is the background class, so start from 1
            for j in range(1, self.n_class):
                bbox_j = bbox[:, j, :]
                score_j = score[:, j]
                mask = jt.where(score_j > score_thresh)[0]
                bbox_j = bbox_j[mask, :]
                score_j = score_j[mask]
                dets = jt.contrib.concat([bbox_j, score_j.unsqueeze(1)], dim=1)
                keep = jt.nms(dets, nms_thresh)
                bbox_j = bbox_j[keep]
                score_j = score_j[keep]
                label_j = jt.ones_like(score_j).int32() * j
                boxes.append(bbox_j)
                scores.append(score_j)
                labels.append(label_j)

            boxes = jt.contrib.concat(boxes, dim=0)
            scores = jt.contrib.concat(scores, dim=0)
            labels = jt.contrib.concat(labels, dim=0)
            results.append((boxes, scores, labels))

        return results
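
jt.nms, as used above, takes a (K, 5) array whose rows are a box plus a trailing score column (the dets built by the concat) and returns the indices of the boxes that survive suppression. A tiny made-up example:

import jittor as jt

dets = jt.array([[0., 0., 10., 10., 0.9],
                 [1., 1., 10., 10., 0.8],    # heavily overlaps the first box
                 [20., 20., 30., 30., 0.7]])
keep = jt.nms(dets, 0.3)                     # IoU threshold of 0.3
print(keep)                                  # expect boxes 0 and 2 to survive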
Example #5
    def execute(self, x):
        idn = x
        x = self.conv1(x)

        b, c, h, w = x.size()
        n = h*w
        x = x.view(b, c, h*w)   # b * c * n 

        attn = self.linear_0(x) # b, k, n
        attn = nn.softmax(attn, dim=-1) # b, k, n

        attn = attn / (1e-9 + attn.sum(dim=1, keepdims=True))  # b, k, n
        x = self.linear_1(attn) # b, c, n

        x = x.view(b, c, h, w)
        x = self.conv2(x)
        x = x + idn
        x = self.relu(x)
        return x
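
The two normalizations above are the double-normalization step of external attention: the softmax over dim=-1 normalizes each of the k attention maps over the n positions, and the following division renormalizes across the k maps. A standalone sketch of just that step, with hypothetical sizes:

import jittor as jt
from jittor import nn

attn = jt.random((2, 8, 64))                           # hypothetical (b, k, n) scores
attn = nn.softmax(attn, dim=-1)                        # normalize over the n positions
attn = attn / (1e-9 + attn.sum(dim=1, keepdims=True))  # then across the k maps
print(attn.shape)                                      # (2, 8, 64)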
Example #6
    def _forward(self):
        #########################################################################################################
        ## If we use `pytorch` pretrained model, the input should be RGB, and normalized by the following code:
        ##      normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
        ##                                       std=[0.229, 0.224, 0.225])
        ## Note: input[channel] = (input[channel] - mean[channel]) / std[channel], input is (0,1), not (0,255)
        #########################################################################################################
        inputs = (jt.array(self.inputs) / 255.0 - self.mean) / self.std
        [p1, p2, p3, p4] = self.backbone(inputs)
        feature = p1

        alignHs = np.vstack(self.featAlignMatrixs)
        indexs = np.hstack([
            idx * np.ones(len(m), )
            for idx, m in enumerate(self.featAlignMatrixs)
        ])

        rois = affine_align_gpu(feature, indexs,
                                (self.size_align, self.size_align), alignHs)

        if self.cat_skeleton:
            skeletons = np.vstack(self.skeletonFeats)
            skeletons = jt.array(skeletons).float()
            rois = jt.contrib.concat([rois, skeletons], 1)
        netOutput = self.segnet(rois)

        if self.is_training():
            loss = self._calcLoss(netOutput)
            return loss
        else:
            netOutput = nn.softmax(netOutput, 1)
            netOutput = jt.detach(netOutput)
            #output = self._getMaskOutput(netOutput)
            if not self.benchmark:
                output = self._getMaskOutput(netOutput)
            else:
                output = netOutput
                #output.sync()
            if self.visCount < 0:
                self._visualizeOutput(netOutput)
                self.visCount += 1

            return output
Example #7
    def execute(self, x):
        b, n, c = x.shape
        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads,
                                  c // self.num_heads).permute(2, 0, 3, 1, 4)

        q, k, v = qkv[0], qkv[1], qkv[2]

        # attn = nn.bmm(q,k.transpose(0,1,3,2))*self.scale
        attn = nn.bmm_transpose(q, k) * self.scale

        attn = nn.softmax(attn, dim=-1)

        attn = self.attn_drop(attn)

        out = nn.bmm(attn, v)
        out = out.transpose(0, 2, 1, 3).reshape(b, n, c)
        out = self.proj(out)
        out = self.proj_drop(out)

        return out
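
The commented-out line above documents the intent: nn.bmm_transpose(q, k) is used as a fused equivalent of nn.bmm(q, k.transpose(0, 1, 3, 2)). A quick sanity-check sketch on random data with hypothetical shapes:

import jittor as jt
from jittor import nn

q = jt.random((2, 4, 16, 8))            # hypothetical (b, heads, n, d)
k = jt.random((2, 4, 16, 8))
a = nn.bmm_transpose(q, k)              # q @ k^T over the last two dims
b = nn.bmm(q, k.transpose(0, 1, 3, 2))
print(jt.abs(a - b).max())              # expect a value near 0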
Example #8
    def execute(self, x, mask=None):
        b, n, _ = x.shape
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        q = q.reshape(b, n, h, -1)
        q = q.transpose(0, 2, 1, 3)

        k = k.reshape(b, n, h, -1)
        k = k.transpose(0, 2, 1, 3)

        v = v.reshape(b, n, h, -1)
        v = v.transpose(0, 2, 1, 3)

        # b, h, n, d
        d = q.shape[-1]
        q = q.reshape(b * h, n, d)
        k = k.reshape(b * h, n, d).transpose(0, 2, 1)

        dots = nn.bmm(q, k).reshape(b, h, n, n)
        dots = dots * self.scale

        if mask is not None:
            mask = nn.pad(mask.flatten(1), (1, 0), value=1)
            assert mask.shape[-1] == dots.shape[
                -1], 'mask has incorrect shapes'
            mask = mask.unsqueeze(1) * mask.unsqueeze(2)
            dots.masked_fill_(~mask, float('-inf'))
            del mask

        attn = nn.softmax(dots, dim=-1)

        out = nn.bmm(attn.reshape(b * h, n, n),
                     v.reshape(b * h, n, d)).reshape(b, h, n, d)
        out = out.transpose(0, 2, 1, 3).reshape(b, n, h * d)
        out = self.to_out(out)
        return out
Example #9
    def execute(self, x, pos):
        """
        Args:
            x: Tensor, (B, c, 2048)
            pos: Tensor, (B, 2048, 3)
        """
        identity = x
        x_bcn = self.linear_start(x)
        b, dim, n = x_bcn.shape
        pos_bcn = pos.transpose(0, 2, 1)
        _, idx_knn = knn(pos, pos, self.n_knn)
        # idx_knn = knn(pos_bcn, self.n_knn)

        key = self.conv_key(x_bcn)
        value = self.conv_value(x_bcn)
        query = self.conv_query(x_bcn)

        # key = index_points(key.transpose(0, 2, 1), idx_knn).transpose(0, 3, 1, 2)  # (b, c, n, n_knn)
        key = grouping_operation(key, idx_knn)
        # print('key.shape', key.shape)
        qk_rel = query.reshape((b, -1, n, 1)) - key

        pos_rel = pos_bcn.reshape((b, -1, n, 1)) - \
                  grouping_operation(pos_bcn, idx_knn)
        # index_points(pos, idx_knn).transpose(0, 3, 1, 2)
        pos_embedding = self.pos_mlp(pos_rel)

        attention = self.attn_mlp(qk_rel + pos_embedding)
        attention = nn.softmax(attention, dim=-1)

        value = value.reshape((b, -1, n, 1)) + pos_embedding

        agg = (value * attention).sum(dim=-1)
        y = self.linear_end(agg)

        return y + identity
Example #10
    def execute(self, x, img_size):
        """Forward Region Proposal Network.

        Here are notations.

        * :math:`N` is batch size.
        * :math:`C` channel size of the input.
        * :math:`H` and :math:`W` are height and width of the input feature.
        * :math:`A` is number of anchors assigned to each pixel.

        Args:
            x : The Features extracted from images.
                Its shape is :math:`(N, C, H, W)`.
            img_size (tuple of ints): A tuple :obj:`height, width`,
                which contains image size after scaling.

        Returns:
            This is a tuple of five following values.

            * **rpn_locs**: Predicted bounding box offsets and scales for \
                anchors. Its shape is :math:`(N, H W A, 4)`.
            * **rpn_scores**:  Predicted foreground scores for \
                anchors. Its shape is :math:`(N, H W A, 2)`.
            * **rois**: A bounding box array containing coordinates of \
                proposal boxes.  This is a concatenation of bounding box \
                arrays from multiple images in the batch. \
                Its shape is :math:`(R', 4)`. Given :math:`R_i` predicted \
                bounding boxes from the :math:`i` th image, \
                :math:`R' = \\sum _{i=1} ^ N R_i`.
            * **roi_indices**: An array containing indices of images to \
                which RoIs correspond to. Its shape is :math:`(R',)`.
            * **anchor**: Coordinates of enumerated shifted anchors. \
                Its shape is :math:`(H W A, 4)`.

        """
        n, _, hh, ww = x.shape
        anchor = _enumerate_shifted_anchor(self.anchor_base, self.feat_stride,
                                           hh, ww)
        anchor = jt.array(anchor)

        n_anchor = anchor.shape[0] // (hh * ww)
        h = nn.relu(self.conv1(x))

        rpn_locs = self.loc(h)

        rpn_locs = rpn_locs.permute(0, 2, 3, 1).view(n, -1, 4)
        rpn_scores = self.score(h)
        rpn_scores = rpn_scores.permute(0, 2, 3, 1)
        rpn_softmax_scores = nn.softmax(
            rpn_scores.view(n, hh, ww, n_anchor, 2), dim=4)
        rpn_fg_scores = rpn_softmax_scores[:, :, :, :, 1]
        rpn_fg_scores = rpn_fg_scores.view(n, -1)
        rpn_scores = rpn_scores.view(n, -1, 2)
        rois = []
        roi_indices = []
        for i in range(n):
            roi = self.proposal_layer(rpn_locs[i], rpn_fg_scores[i], anchor,
                                      img_size, 1.0 / self.feat_stride)
            batch_index = i * jt.ones((len(roi), ), dtype='int32')
            rois.append(roi)
            roi_indices.append(batch_index)

        rois = jt.contrib.concat(rois, dim=0)
        roi_indices = jt.contrib.concat(roi_indices, dim=0)
        return rpn_locs, rpn_scores, rois, roi_indices, anchor
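
The reshape-then-softmax pattern above is what turns the (N, H, W, A*2) score map into per-anchor foreground probabilities. A minimal sketch with hypothetical sizes:

import jittor as jt
from jittor import nn

n, hh, ww, A = 1, 2, 3, 9                   # hypothetical sizes
rpn_scores = jt.random((n, hh, ww, A * 2))  # raw scores, channels last
scores = nn.softmax(rpn_scores.view(n, hh, ww, A, 2), dim=4)
fg = scores[:, :, :, :, 1]                  # foreground probability per anchor
print(fg.view(n, -1).shape)                 # (1, 54)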
Example #11
    def detect_objects(self, predicted_locs, predicted_scores, min_score,
                       max_overlap, top_k):
        """ Decipher the 8732 locations and class scores (output of the SSD300) to detect objects.
        For each class, perform Non-Maximum Suppression (NMS) on boxes that are above a minimum threshold.

        Args:
            predicted_locs: predicted locations/boxes w.r.t the 8732 prior boxes, a tensor of dimensions (N, 8732, 4)
            predicted_scores: class scores for each of the encoded locations/boxes, a tensor of dimensions (N, 8732, n_classes)
            min_score: minimum threshold for a box to be considered a match for a certain class
            max_overlap: maximum overlap two boxes can have so that the one with the lower score is not suppressed via NMS
            top_k: if there are a lot of resulting detections across all classes, keep only the top 'k'

        Return: detections (boxes, labels, and scores), lists of length batch_size
        """
        batch_size = predicted_locs.shape[0]
        n_priors = self.priors_cxcy.shape[0]
        predicted_scores = nn.softmax(predicted_scores, dim=2)
        all_images_boxes = list()
        all_images_labels = list()
        all_images_scores = list()
        predicted_locs = predicted_locs.data
        predicted_scores = predicted_scores.data
        assert (n_priors == predicted_locs.shape[1])
        for i in range(batch_size):
            decoded_locs = cxcy_to_xy(
                gcxgcy_to_cxcy(predicted_locs[i], self.priors_cxcy))
            image_boxes = list()
            image_labels = list()
            image_scores = list()
            for c in range(1, self.n_classes):
                class_scores = predicted_scores[i][:, c]
                score_above_min_score = (class_scores >= min_score)
                n_above_min_score = score_above_min_score.sum()
                if (n_above_min_score == 0):
                    continue
                class_scores = class_scores[score_above_min_score]
                class_decoded_locs = decoded_locs[score_above_min_score]
                sort_ind = np.argsort(-class_scores, axis=0)
                class_scores = class_scores[sort_ind]
                class_decoded_locs = class_decoded_locs[sort_ind]
                overlap = find_jaccard_overlap(class_decoded_locs,
                                               class_decoded_locs)
                suppress = np.zeros((n_above_min_score)).astype('int')
                for box in range(class_decoded_locs.shape[0]):
                    if (suppress[box] == 1):
                        continue
                    suppress = np.maximum(suppress,
                                          (overlap[box] > max_overlap))
                    suppress[box] = 0
                image_boxes.append(
                    class_decoded_locs[(1 - suppress).astype('bool')])
                image_labels.append(int((1 - suppress).sum()) * [c])
                image_scores.append(class_scores[(1 -
                                                  suppress).astype('bool')])
            if (len(image_boxes) == 0):
                image_boxes.append(np.array([[0.0, 0.0, 1.0, 1.0]]))
                image_labels.append(np.array([0]))
                image_scores.append(np.array([0.0]))
            image_boxes = np.concatenate(image_boxes, 0)
            image_labels = np.concatenate(image_labels, 0)
            image_scores = np.concatenate(image_scores, 0)
            n_objects = image_scores.shape[0]
            if (n_objects > top_k):
                sort_ind = np.argsort(-image_scores, axis=0)
                image_scores = image_scores[sort_ind][:top_k]
                image_boxes = image_boxes[sort_ind][:top_k]
                image_labels = image_labels[sort_ind][:top_k]
            all_images_boxes.append(image_boxes)
            all_images_labels.append(image_labels)
            all_images_scores.append(image_scores)
        return (all_images_boxes, all_images_labels, all_images_scores)
Example #12
    def execute(self, input, target):
        # Gather -log(softmax(input)) at each sample's target class
        bs_idx = jt.array(range(input.shape[0]))
        ret = (-jt.log(nn.softmax(input, dim=1)))[bs_idx, target]
        if self.reduction is not None:
            ret = jt.mean(ret) if self.reduction == 'mean' else jt.sum(ret)
        return ret
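
This hand-rolled loss gathers -log(softmax(input)) at each sample's target class, which is exactly softmax cross-entropy. Assuming Jittor's built-in nn.cross_entropy_loss (used elsewhere on this page) with its default mean reduction, a quick equivalence check might look like:

import jittor as jt
from jittor import nn

logits = jt.random((4, 10))      # hypothetical batch of logits
target = jt.array([1, 0, 3, 7])
bs_idx = jt.array(range(4))
manual = (-jt.log(nn.softmax(logits, dim=1)))[bs_idx, target].mean()
builtin = nn.cross_entropy_loss(logits, target)
print(jt.abs(manual - builtin))  # expect a value near 0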
Example #13
    def execute(self, x):
        """ The input should be of size [batch_size, 3, img_h, img_w] """
        _, _, img_h, img_w = x.shape
        cfg._tmp_img_h = img_h
        cfg._tmp_img_w = img_w

        with timer.env('backbone'):
            outs = self.backbone(x)

        if cfg.fpn is not None:
            with timer.env('fpn'):
                # Use backbone.selected_layers because we overwrote self.selected_layers
                outs = [outs[i] for i in cfg.backbone.selected_layers]
                outs = self.fpn(outs)
        proto_out = None
        if cfg.mask_type == mask_type.lincomb and cfg.eval_mask_branch:
            with timer.env('proto'):
                proto_x = x if self.proto_src is None else outs[self.proto_src]

                if self.num_grids > 0:
                    grids = self.grid.repeat(proto_x.shape[0], 1, 1, 1)
                    proto_x = jt.contrib.concat([proto_x, grids], dim=1)

                proto_out = self.proto_net(proto_x)
                proto_out = cfg.mask_proto_prototype_activation(proto_out)

                if cfg.mask_proto_prototypes_as_features:
                    # Clone here because we don't want to permute this, though idk if contiguous makes this unnecessary
                    proto_downsampled = proto_out.clone()

                    if cfg.mask_proto_prototypes_as_features_no_grad:
                        proto_downsampled = proto_out.detach()

                # Move the features last so the multiplication is easy
                proto_out = proto_out.permute(0, 2, 3, 1)

                if cfg.mask_proto_bias:
                    bias_shape = list(proto_out.shape)  # avoid shadowing the input x
                    bias_shape[-1] = 1
                    proto_out = jt.contrib.concat(
                        [proto_out, jt.ones(bias_shape)], -1)

        with timer.env('pred_heads'):
            pred_outs = {'loc': [], 'conf': [], 'mask': [], 'priors': []}

            if cfg.use_mask_scoring:
                pred_outs['score'] = []

            if cfg.use_instance_coeff:
                pred_outs['inst'] = []

            for idx, pred_layer in zip(self.selected_layers,
                                       self.prediction_layers):
                pred_x = outs[idx]

                if cfg.mask_type == mask_type.lincomb and cfg.mask_proto_prototypes_as_features:
                    # Scale the prototypes down to the current prediction layer's size and add it as inputs
                    proto_downsampled = nn.interpolate(
                        proto_downsampled,
                        size=outs[idx].shape[2:],
                        mode='bilinear',
                        align_corners=False)
                    # proto_downsampled = interpolate(proto_downsampled, size=outs[idx].shape[2:], mode='bilinear', align_corners=False)

                    pred_x = jt.contrib.concat([pred_x, proto_downsampled],
                                               dim=1)

                # A hack for the way dataparallel works
                if cfg.share_prediction_module and pred_layer is not self.prediction_layers[
                        0]:
                    pred_layer.parent = [self.prediction_layers[0]]

                p = pred_layer(pred_x)

                for k, v in p.items():
                    pred_outs[k].append(v)

        for k, v in pred_outs.items():
            pred_outs[k] = jt.contrib.concat(v, -2)

        if proto_out is not None:
            pred_outs['proto'] = proto_out

        #print('hh',pred_outs)
        #print()
        if self.is_training():
            # For the extra loss functions
            if cfg.use_class_existence_loss:
                pred_outs['classes'] = self.class_existence_fc(
                    outs[-1].mean(dim=(2, 3)))

            if cfg.use_semantic_segmentation_loss:
                pred_outs['segm'] = self.semantic_seg_conv(outs[0])

            return pred_outs
        else:
            if cfg.use_mask_scoring:
                pred_outs['score'] = jt.sigmoid(pred_outs['score'])
            if cfg.use_focal_loss:
                if cfg.use_sigmoid_focal_loss:
                    # Note: even though conf[0] exists, this mode doesn't train it so don't use it
                    pred_outs['conf'] = jt.sigmoid(pred_outs['conf'])
                    if cfg.use_mask_scoring:
                        pred_outs['conf'] *= pred_outs['score']
                elif cfg.use_objectness_score:
                    # See focal_loss_sigmoid in multibox_loss.py for details
                    objectness = jt.sigmoid(pred_outs['conf'][:, :, 0])
                    pred_outs['conf'][:, :, 1:] = objectness.unsqueeze(
                        2) * nn.softmax(pred_outs['conf'][:, :, 1:], -1)
                    pred_outs['conf'][:, :, 0] = 1 - objectness
                else:
                    pred_outs['conf'] = nn.softmax(pred_outs['conf'], -1)
            else:

                if cfg.use_objectness_score:
                    objectness = jt.sigmoid(pred_outs['conf'][:, :, 0])

                    pred_outs['conf'][:, :, 1:] = (objectness > 0.10).unsqueeze(-1) \
                        * nn.softmax(pred_outs['conf'][:, :, 1:], dim=-1)

                else:
                    pred_outs['conf'] = nn.softmax(pred_outs['conf'], -1)
            return self.detect(pred_outs, self)
Example #14
    def execute(
        self,
        query,
        key=None,
        value=None,
        key_padding_mask=None,
        incremental_state=None,
        need_weights=True,
        static_kv=False,
        attn_mask=None,
        before_softmax=False,
        need_head_weights=False,
    ):
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.shape
        assert embed_dim == self.embed_dim
        assert list(query.shape) == [tgt_len, bsz, embed_dim]

        assert incremental_state is None, "TODO: incremental_state is not None"
        saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q = q * self.scaling

        assert self.bias_k is None, "TODO: self.bias_k is not None:"

        q = q.view(tgt_len, bsz * self.num_heads,
                   self.head_dim).transpose(1, 0, 2)
        if k is not None:
            k = k.view(-1, bsz * self.num_heads,
                       self.head_dim).transpose(1, 0, 2)
        if v is not None:
            v = v.view(-1, bsz * self.num_heads,
                       self.head_dim).transpose(1, 0, 2)

        assert saved_state is None, "TODO: saved_state is not None"
        assert k is not None
        src_len = k.shape[1]

        assert key_padding_mask is None, "TODO: key_padding_mask is not None"
        assert not self.add_zero_attn, "TODO: self.add_zero_attn=True"

        attn_weights = nn.bmm(q, k.transpose(0, 2, 1))

        assert list(
            attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]

        assert attn_mask is None, "TODO: attn_mask is not None"
        assert key_padding_mask is None, "TODO: key_padding_mask is not None"

        if before_softmax:
            return attn_weights, v

        attn_weights_float = nn.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)

        assert v is not None
        attn = nn.bmm(attn_weights, v)
        assert list(
            attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.shape[1] == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(1, 0, 2).view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights = None
        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads,
                                                   tgt_len, src_len).transpose(
                                                       1, 0, 2, 3)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dims=[0])

        return attn, attn_weights