Exemplo n.º 1
0
    def two_scale_forward(self, inputs):
        """
        Fuse a low-resolution pass and a full-resolution pass using attention
        predicted from the features of both passes.

        The low-resolution features are resized onto the full-resolution
        feature grid and concatenated (low first) along channels before
        ``self.scale_attn``. The resulting attention weights the
        low-resolution prediction, and ``1 - attention`` weights the
        full-resolution prediction.

        Args:
            inputs: dict with key 'images', plus 'gts' when training.

        Returns:
            Training: the value of ``self.criterion(joint_pred, gts)``.
            Eval: tuple ``(joint_pred, attention resized to the 1x prediction)``.
        """
        assert 'images' in inputs

        img_full = inputs['images']
        img_low = ResizeX(img_full, cfg.MODEL.MSCALE_LO_SCALE)

        pred_low, feat_low = self._fwd(img_low)
        pred_full, feat_full = self._fwd(img_full)

        # Attention is predicted on the full-resolution feature grid.
        feat_low = scale_as(feat_low, feat_full)
        attn = self.scale_attn(torch.cat([feat_low, feat_full], 1))
        attn_for_low = scale_as(attn, pred_low)
        attn_for_full = scale_as(attn, pred_full)

        weighted_low = scale_as(attn_for_low * pred_low, pred_full)
        joint_pred = weighted_low + (1 - attn_for_full) * pred_full

        if not self.training:
            return joint_pred, attn_for_full

        assert 'gts' in inputs
        return self.criterion(joint_pred, inputs['gts'])
Exemplo n.º 2
0
    def _fwd(self, x, aspp_lo=None, aspp_attn=None, scale_float=None):
        """
        Single-scale forward pass returning the prediction, attention maps and
        the (possibly fused) ASPP features.

        When ``self.fuse_aspp`` is set and both ``aspp_lo`` and ``aspp_attn``
        are given, this pass's ASPP output is blended with ``aspp_lo`` using
        ``aspp_attn`` as the mixing weight.

        Args:
            x: input image batch; its spatial size sets the output size.
            aspp_lo: ASPP features from another scale, or None.
            aspp_attn: weight used to mix ``aspp_lo`` in, or None.
            scale_float: accepted for interface compatibility; unused here.

        Returns:
            ``(out, logit_attn, aspp_attn, aspp)``. ``out`` and both attention
            outputs are upsampled to the spatial size of ``x``. With
            ``self.attn_2b`` the attention head's channel 0 is returned as
            ``logit_attn`` and the remaining channels as ``aspp_attn``.
            Otherwise the same tensor is returned for both.
        """
        out_hw = x.size()[2:]
        s2_features, _, final_features = self.backbone(x)
        aspp = self.aspp(final_features)

        have_lo_inputs = aspp_lo is not None and aspp_attn is not None
        if self.fuse_aspp and have_lo_inputs:
            mix = scale_as(aspp_attn, aspp)
            lo = scale_as(aspp_lo, aspp)
            aspp = mix * lo + (1 - mix) * aspp

        coarse = self.bot_aspp(aspp)
        fine = self.bot_fine(s2_features)
        coarse = Upsample(coarse, s2_features.size()[2:])

        # Separate concatenations: the prediction head and the attention head
        # each receive their own tensor.
        seg_in = torch.cat([fine, coarse], 1)
        attn_in = torch.cat([fine, coarse], 1)

        seg = self.final(seg_in)
        attn_maps = self.scale_attn(attn_in)

        out = Upsample(seg, out_hw)
        attn_maps = Upsample(attn_maps, out_hw)

        if not self.attn_2b:
            return out, attn_maps, attn_maps, aspp
        return out, attn_maps[:, 0:1, :, :], attn_maps[:, 1:, :, :], aspp
Exemplo n.º 3
0
    def two_scale_forward(self, inputs):
        """
        Two-scale fusion where the attention is predicted on the
        low-resolution feature grid.

        Full-resolution features are resized onto the low-resolution feature
        grid and concatenated after the low-resolution features. The attention
        weights the low-resolution prediction, and ``1 - attention`` weights
        the full-resolution prediction.

        Args:
            inputs: dict with key 'images', plus 'gts' when training.

        Returns:
            Training: the value of ``self.criterion(joint_pred, gts)``.
            Eval: ``{'pred': joint_pred, 'attn_10x': attention}`` where the
            attention has been resized to the 1x prediction.
        """
        assert 'images' in inputs

        full_img = inputs['images']
        low_img = ResizeX(full_img, cfg.MODEL.MSCALE_LO_SCALE)

        low_pred, low_feats = self._fwd(low_img)
        full_pred, full_feats = self._fwd(full_img)

        # Bring the full-resolution features down to the low-resolution grid.
        full_feats = scale_as(full_feats, low_feats)
        joint_feats = torch.cat([low_feats, full_feats], 1)
        attn = scale_as(self.scale_attn(joint_feats), low_pred)

        low_contrib = scale_as(attn * low_pred, full_pred)
        attn = scale_as(attn, full_pred)
        joint_pred = low_contrib + (1 - attn) * full_pred

        if self.training:
            assert 'gts' in inputs
            return self.criterion(joint_pred, inputs['gts'])

        # FIXME: should add multi-scale values for pred and attn
        return {'pred': joint_pred, 'attn_10x': attn}
Exemplo n.º 4
0
    def _fwd(self, x, aspp_lo=None, aspp_attn=None):
        """
        Single-scale forward pass through a progressive-upsampling decoder.

        The ASPP output is optionally blended with ``aspp_lo`` (weighted by
        ``aspp_attn``) when ``self.fuse_aspp`` is set and both are given. The
        decoder applies ``conv_up1``, ``Upsample2`` (presumably a 2x upsample),
        merges ``s4_features``, applies ``conv_up2`` and ``Upsample2``, merges
        ``s2_features``, then applies ``conv_up3``. The prediction
        (``conv_up5``) and attention (``scale_attn``) heads both read that
        last decoder output.

        Returns:
            ``(out, logit_attn, aspp_attn, aspp)``. With ``self.attn_2b`` the
            attention head's channel 0 is ``logit_attn`` and the remaining
            channels are ``aspp_attn``. Otherwise both are the same tensor.
        """
        s2_features, s4_features, final_features = self.backbone(x)
        s2_features = self.convs2(s2_features)
        s4_features = self.convs4(s4_features)
        aspp = self.aspp(final_features)

        if self.fuse_aspp and aspp_lo is not None and aspp_attn is not None:
            weight = scale_as(aspp_attn, aspp)
            lo = scale_as(aspp_lo, aspp)
            aspp = weight * lo + (1 - weight) * aspp

        dec = Upsample2(self.conv_up1(aspp))
        dec = self.conv_up2(torch.cat([dec, s4_features], 1))
        dec = Upsample2(dec)
        up3 = self.conv_up3(torch.cat([dec, s2_features], 1))

        out = Upsample2(self.conv_up5(up3))
        attn_maps = Upsample2(self.scale_attn(up3))

        if not self.attn_2b:
            return out, attn_maps, attn_maps, aspp
        return out, attn_maps[:, 0:1, :, :], attn_maps[:, 1:, :, :], aspp
Exemplo n.º 5
0
    def recurse_fuse_fwd(self, x, scales, aspp_lo=None, attn_lo=None):
        """
        Recursively evaluate an arbitrary number of scales with ASPP fusion.

        The output is always produced at the resolution of ``x`` (scale 1.0).
        ``scales`` is consumed from its END with ``pop()``, so the list passed
        in is mutated and its last entry is evaluated first. Each level hands
        its ASPP features and ASPP attention to the next level through
        ``aspp_lo`` / ``attn_lo``.

        Example, when the pop order is 0.5 then 1.0:
            p_0.5, attn_0.5, aspp_0.5 = fwd(x_0.5)
            p_1.0 = recurse(remaining scales, aspp_0.5, attn_0.5)
                  = fwd(x_1.0, aspp_0.5, attn_0.5)
            output = attn_0.5 * p_0.5 + (1 - attn_0.5) * p_1.0

        Returns:
            ``(output, attn_1x)``, both resized to the resolution of ``x``.
            ``attn_1x`` is the attention of the scale evaluated at this level.
        """
        cur_scale = scales.pop()
        x_in = x if cur_scale == 1.0 else ResizeX(x, cur_scale)
        pred, logit_attn, next_attn, next_aspp = self._fwd(
            x_in, aspp_lo=aspp_lo, aspp_attn=attn_lo)

        if cur_scale == 1.0:
            pred_1x, attn_1x = pred, logit_attn
        else:
            pred_1x = scale_as(pred, x)
            attn_1x = scale_as(logit_attn, x)

        if not scales:
            return pred_1x, attn_1x

        output = attn_1x * pred_1x
        deeper_pred, _ = self.recurse_fuse_fwd(
            x, scales, attn_lo=next_attn, aspp_lo=next_aspp)
        output += (1 - attn_1x) * deeper_pred
        return output, attn_1x
Exemplo n.º 6
0
    def _fwd(self, x, aspp_lo=None, aspp_attn=None, scale_float=None):
        """
        Single-scale pass: the deepest backbone output feeds both the
        attention head and the classification head, and both outputs are
        resized to the spatial size of ``x``.

        ``aspp_lo``, ``aspp_attn`` and ``scale_float`` are accepted for
        interface compatibility and ignored.

        Returns:
            ``(pred, attn, None, None)``.
        """
        _, _, deep_feats = self.backbone(x)
        attn_map = self.scale_attn(deep_feats)
        seg_map = self.cls_head(deep_feats)

        attn_full = scale_as(attn_map, x)
        seg_full = scale_as(seg_map, x)
        return seg_full, attn_full, None, None
Exemplo n.º 7
0
    def forward(self, inputs):
        """
        Single-scale OCR forward pass (inference only).

        Args:
            inputs: dict containing 'images'.

        Returns:
            ``{'pred': cls_out}`` with ``cls_out`` resized to the spatial size
            of the input. No loss is computed, regardless of ``self.training``.
        """
        assert 'images' in inputs
        x = inputs['images']

        _, _, high_level_features = self.backbone(x)
        # The auxiliary OCR output is not used here, so it is not resized.
        cls_out, _aux_out, _ = self.ocr(high_level_features)
        cls_out = scale_as(cls_out, x)

        return {'pred': cls_out}
Exemplo n.º 8
0
    def _forward_fused(self, inputs):
        """
        Combine predictions from all of ``self.scales`` with attention that is
        predicted jointly from the features of every scale.

        Scale 1.0 (required) is run first. Every other scale's prediction is
        resized to the input resolution, and its features are resized to the
        1.0 feature resolution. The concatenated features go through
        ``self.scale_attn``. Attention channel ``i`` weights the prediction of
        ``self.scales[i]``, and the weighted predictions are summed.

        NOTE(review): features are concatenated with 1.0 first and the other
        scales after it, while attention channels are indexed by position in
        ``self.scales``. The two orders coincide only when 1.0 is the first
        entry of ``self.scales``.

        Returns:
            Training: ``criterion(output, gts)``, plus, when
            ``cfg.LOSS.SUPERVISED_MSCALE_WT`` is truthy, that weight times the
            per-scale losses computed without RMI.
            Eval: ``(output, attn)`` where ``attn`` is the attention channel
            taken for the last entry of ``self.scales`` (not resized).
        """
        x_1x = inputs['images']

        # run 1x scale
        assert 1.0 in self.scales, 'expected one of scales to be 1.0'
        preds = {}
        preds[1.0], base_feats = self._fwd(x_1x)
        feat_list = [base_feats]

        # run all other scales
        for s in self.scales:
            if s == 1.0:
                continue
            p, f = self._fwd(ResizeX(x_1x, s))
            preds[s] = scale_as(p, x_1x)
            feat_list.append(scale_as(f, base_feats))

        attn_tensor = self.scale_attn(torch.cat(feat_list, 1))

        output = None
        for idx, s in enumerate(self.scales):
            attn = attn_tensor[:, idx:idx + 1, :, :]
            weighted = preds[s] * scale_as(attn, x_1x)
            if output is None:
                output = weighted
            else:
                output += weighted

        if not self.training:
            return output, attn

        assert 'gts' in inputs
        gts = inputs['gts']
        loss = self.criterion(output, gts)

        if cfg.LOSS.SUPERVISED_MSCALE_WT:
            for s in self.scales:
                per_scale = self.criterion(preds[s], gts, do_rmi=False)
                loss += cfg.LOSS.SUPERVISED_MSCALE_WT * per_scale
        return loss
Exemplo n.º 9
0
    def two_scale_forward(self, inputs):
        """
        Bidirectional two-scale fusion followed by a shared classifier.

        Two channel groups of the pre-classifier predictions are fused in
        opposite directions:
          * channels ``[0, attn_ch)``: low-scale attention (``_fwd_attn``)
            weights the low-scale prediction, and ``1 - attention`` weights
            the 1x prediction;
          * channels ``[attn_ch, 2*attn_ch)``: 1x attention
            (``_fwd_attn_rev``) weights the 1x prediction, and
            ``1 - attention`` weights the upsampled low-scale prediction.
        The two fused groups are concatenated and passed to
        ``self.classifier``, which is also applied to each single-scale
        prediction.

        Returns:
            Training: main loss, plus ``cfg.LOSS.SUPERVISED_MSCALE_WT`` times
            the per-scale losses (RMI disabled) when that weight is truthy.
            Eval: dict with 'pred', 'pred_05x', 'pred_10x', 'attn_05x'.
        """
        assert 'images' in inputs
        x_1x = inputs['images']
        x_lo = ResizeX(x_1x, cfg.MODEL.MSCALE_LO_SCALE)

        pred_05x, _aspp_lo, cat_s4_attn_lo = self._fwd_feature(x_lo)
        p_1x, _aspp_1x, cat_s4_attn_1x = self._fwd_feature(x_1x)
        attn_lo, _aspp_attn_lo = self._fwd_attn(x_lo, cat_s4_attn_lo)
        attn_1x, _aspp_attn_1x = self._fwd_attn_rev(x_1x, cat_s4_attn_1x)

        n_ch = self.attn_ch

        # low -> high
        lo_weighted = scale_as(attn_lo * pred_05x.narrow(1, 0, n_ch), p_1x)
        attn_lo_up = scale_as(attn_lo, p_1x)
        joint_pred1 = lo_weighted + (1 - attn_lo_up) * p_1x.narrow(1, 0, n_ch)

        # high -> low
        hi_weighted = attn_1x * p_1x.narrow(1, n_ch, n_ch)
        lo_up = scale_as(pred_05x, p_1x)
        joint_pred2 = hi_weighted + (1 - attn_1x) * lo_up.narrow(1, n_ch, n_ch)

        joint_pred = self.classifier(
            torch.cat([joint_pred1, joint_pred2], dim=1))
        pred_05x = self.classifier(pred_05x)
        p_1x = self.classifier(p_1x)

        if not self.training:
            return {
                'pred': joint_pred,
                'pred_05x': pred_05x,
                'pred_10x': p_1x,
                'attn_05x': attn_lo,
            }

        assert 'gts' in inputs
        gts = inputs['gts']
        loss = self.criterion(joint_pred, gts)
        # Optional direct supervision of each scale; RMI is off to keep it
        # lightweight.
        if cfg.LOSS.SUPERVISED_MSCALE_WT:
            loss_lo = self.criterion(scale_as(pred_05x, p_1x), gts,
                                     do_rmi=False)
            loss_hi = self.criterion(p_1x, gts, do_rmi=False)
            loss += cfg.LOSS.SUPERVISED_MSCALE_WT * loss_lo
            loss += cfg.LOSS.SUPERVISED_MSCALE_WT * loss_hi
        return loss
Exemplo n.º 10
0
    def forward(self, inputs):
        """
        Single-scale pass: backbone -> ASPP -> OCR head.

        Both OCR outputs (main and auxiliary) are resized to the input size.

        Returns:
            Training: ``cfg.LOSS.OCR_ALPHA * criterion(aux) + criterion(cls)``.
            Eval: ``{'pred': cls_out}``.
        """
        assert 'images' in inputs
        images = inputs['images']

        _, _, deep_feats = self.backbone(images)
        cls_out, aux_out, _ = self.ocr(self.aspp(deep_feats))
        aux_out = scale_as(aux_out, images)
        cls_out = scale_as(cls_out, images)

        if not self.training:
            return {'pred': cls_out}

        gts = inputs['gts']
        aux_loss = self.criterion(aux_out, gts)
        main_loss = self.criterion(cls_out, gts)
        return cfg.LOSS.OCR_ALPHA * aux_loss + main_loss
Exemplo n.º 11
0
    def _fwd(self, x, aspp_lo=None, aspp_attn=None, scale_float=None):
        """
        Single-scale forward pass with an ``adnb`` module applied to the
        projected ASPP features.

        Steps:
          1. Backbone, then ASPP. When ``self.fuse_aspp`` is set and both
             ``aspp_lo`` and ``aspp_attn`` are given, ASPP is optionally
             blended with ``aspp_lo`` weighted by ``aspp_attn``.
          2. ``bot_aspp`` projection.
          3. ``adnb`` refinement.
          4. Upsample to the ``s2_features`` grid.
          5. Concatenate with ``bot_fine(s2_features)``.
          6. Prediction head (``final``) and attention head (``scale_attn``),
             both upsampled to the size of ``x``.

        Args:
            x: input batch; ``x.size()[2:]`` sets the output spatial size.
            aspp_lo: ASPP features from another scale, or None.
            aspp_attn: mixing weight for ``aspp_lo``, or None.
            scale_float: accepted for interface compatibility; unused.

        Returns:
            ``(out, logit_attn, aspp_attn, aspp)``. With ``self.attn_2b`` the
            attention head's channel 0 is ``logit_attn`` and the remaining
            channels are ``aspp_attn``. Otherwise both are the same tensor.
        """
        x_size = x.size()
        s2_features, _, final_features = self.backbone(x)

        aspp = self.aspp(final_features)

        if self.fuse_aspp and \
           aspp_lo is not None and aspp_attn is not None:
            aspp_attn = scale_as(aspp_attn, aspp)
            aspp_lo = scale_as(aspp_lo, aspp)
            aspp = aspp_attn * aspp_lo + (1 - aspp_attn) * aspp

        conv_aspp_ = self.bot_aspp(aspp)
        conv_s2 = self.bot_fine(s2_features)
        # spatial attention here.
        #conv_aspp_ = self.asnb(conv_s2, conv_aspp_)
        # NOTE(review): this resizes conv_aspp_ to its own spatial size, which
        # looks like a no-op (possibly a leftover from the disabled asnb call
        # above) -- confirm before removing.
        conv_aspp_ = Upsample(conv_aspp_, conv_aspp_.size()[2:])
        # Saved so the adnb output can be reshaped back to this layout
        # (assumes conv_aspp_ is (N, C, H, W) -- TODO confirm).
        conv_aspp_shape = conv_aspp_.shape
        # adnb takes lists of feature maps, masks and positional embeddings.
        # The mask is an all-False bool tensor of shape (N, H, W), presumably
        # meaning "no positions masked" -- TODO confirm against adnb.
        conv_aspp_ = self.adnb([conv_aspp_],
                              masks=[conv_aspp_.new_zeros((conv_aspp_.shape[0], conv_aspp_.shape[2], conv_aspp_.shape[3]), dtype=torch.bool)],
                              pos_embeds=[None])
        # Assumes adnb returns (N, H*W, C): swapping the last two dims and
        # viewing with the saved shape restores (N, C, H, W) -- TODO confirm.
        conv_aspp_ = conv_aspp_.transpose(-1, -2).view(conv_aspp_shape)

        conv_aspp = Upsample(conv_aspp_, s2_features.size()[2:])

        cat_s4 = [conv_s2, conv_aspp]
        cat_s4_attn = [conv_s2, conv_aspp]
        cat_s4 = torch.cat(cat_s4, 1)
        cat_s4_attn = torch.cat(cat_s4_attn, 1)

        final = self.final(cat_s4)
        scale_attn = self.scale_attn(cat_s4_attn)

        out = Upsample(final, x_size[2:])
        scale_attn = Upsample(scale_attn, x_size[2:])

        if self.attn_2b:
            logit_attn = scale_attn[:, 0:1, :, :]
            aspp_attn = scale_attn[:, 1:, :, :]
        else:
            logit_attn = scale_attn
            aspp_attn = scale_attn

        return out, logit_attn, aspp_attn, aspp
Exemplo n.º 12
0
    def two_scale_forward(self, inputs):
        """
        Two-scale OCR forward pass (inference only).

        The input is evaluated at ``cfg.MODEL.MSCALE_LO_SCALE`` and at 1x. The
        attention from the low-scale pass weights the low-scale class
        prediction, and ``1 - attention`` weights the 1x prediction.

        No loss is computed here, and the auxiliary OCR outputs are not
        combined because nothing in this method uses them.

        Args:
            inputs: dict containing 'images'.

        Returns:
            dict with 'pred' (fused prediction), 'pred_05x', 'pred_10x' and
            'attn_05x' (the low-scale attention, not resized).
        """
        assert 'images' in inputs
        x_1x = inputs['images']

        x_lo = ResizeX(x_1x, cfg.MODEL.MSCALE_LO_SCALE)
        lo_outs = self._fwd(x_lo)
        pred_05x = lo_outs['cls_out']
        attn_05x = lo_outs['logit_attn']

        hi_outs = self._fwd(x_1x)
        pred_10x = hi_outs['cls_out']

        # Weight the low-scale prediction before resizing it to 1x.
        p_lo = scale_as(attn_05x * pred_05x, pred_10x)
        logit_attn = scale_as(attn_05x, pred_10x)

        # combine lo and hi predictions with attention
        joint_pred = p_lo + (1 - logit_attn) * pred_10x

        return {
            'pred': joint_pred,
            'pred_05x': pred_05x,
            'pred_10x': pred_10x,
            'attn_05x': attn_05x,
        }
Exemplo n.º 13
0
    def forward(self, inputs):
        """
        Single-scale OCR pass.

        Both OCR outputs (main and auxiliary) are resized to the input size.

        Returns:
            Training: ``cfg.LOSS.OCR_ALPHA * aux_loss + main_loss``. RMI on the
            auxiliary loss is controlled by ``cfg.LOSS.OCR_AUX_RMI``.
            Eval: ``{'pred': cls_out}``.
        """
        assert 'images' in inputs
        images = inputs['images']

        _, _, deep_feats = self.backbone(images)
        cls_out, aux_out, _ = self.ocr(deep_feats)
        aux_out = scale_as(aux_out, images)
        cls_out = scale_as(cls_out, images)

        if not self.training:
            return {'pred': cls_out}

        gts = inputs['gts']
        aux_loss = self.criterion(aux_out, gts, do_rmi=cfg.LOSS.OCR_AUX_RMI)
        main_loss = self.criterion(cls_out, gts)
        return cfg.LOSS.OCR_ALPHA * aux_loss + main_loss
Exemplo n.º 14
0
    def two_scale_forward(self, inputs):
        """
        Two-scale fusion where the low-scale pass also feeds its ASPP features
        and ASPP attention into the 1x pass.

        The low-scale attention weights the low-scale prediction, and
        ``1 - attention`` weights the 1x prediction. ``self.classifier`` is
        then applied to the fused and to each single-scale prediction.

        Returns:
            Training: main loss, plus ``cfg.LOSS.SUPERVISED_MSCALE_WT`` times
            the per-scale losses (RMI disabled) when that weight is truthy.
            Eval: dict with 'pred', 'pred_05x', 'pred_10x', 'attn_05x'.
        """
        assert 'images' in inputs

        x_1x = inputs['images']
        x_lo = ResizeX(x_1x, cfg.MODEL.MSCALE_LO_SCALE)

        # The low scale runs first; its ASPP output and ASPP attention are
        # handed to the 1x pass as aspp_lo / aspp_attn.
        pred_05x, attn_05x, aspp_attn, aspp_lo = self._fwd(x_lo)
        p_1x, _, _, _ = self._fwd(x_1x, aspp_lo=aspp_lo, aspp_attn=aspp_attn)

        weighted_lo = scale_as(attn_05x * pred_05x, p_1x)
        attn_up = scale_as(attn_05x, p_1x)
        fused = weighted_lo + (1 - attn_up) * p_1x

        joint_pred = self.classifier(fused)
        pred_05x = self.classifier(pred_05x)
        p_1x = self.classifier(p_1x)

        if not self.training:
            return {
                'pred': joint_pred,
                'pred_05x': pred_05x,
                'pred_10x': p_1x,
                'attn_05x': attn_05x,
            }

        assert 'gts' in inputs
        gts = inputs['gts']
        loss = self.criterion(joint_pred, gts)

        # Optional direct supervision of each scale; RMI is off to keep it
        # lightweight.
        if cfg.LOSS.SUPERVISED_MSCALE_WT:
            loss_lo = self.criterion(scale_as(pred_05x, p_1x), gts,
                                     do_rmi=False)
            loss_hi = self.criterion(p_1x, gts, do_rmi=False)
            loss += cfg.LOSS.SUPERVISED_MSCALE_WT * loss_lo
            loss += cfg.LOSS.SUPERVISED_MSCALE_WT * loss_hi
        return loss
Exemplo n.º 15
0
    def _fwd(self, x, aspp_lo=None, aspp_attn=None):
        """
        Run the network at a single scale.

        When ``self.fuse_aspp`` is set and both ``aspp_lo`` and ``aspp_attn``
        are given, the ASPP output is blended with ``aspp_lo`` weighted by
        ``aspp_attn``.

        Returns:
            ``(out, cat_s4)``. ``out`` is the prediction upsampled to the
            spatial size of ``x``. ``cat_s4`` is the concatenated decoder
            feature tensor the prediction was computed from.
        """
        target_hw = x.size()[2:]
        s2_features, _, final_features = self.backbone(x)
        aspp = self.aspp(final_features)

        if self.fuse_aspp and aspp_lo is not None and aspp_attn is not None:
            gate = scale_as(aspp_attn, aspp)
            lo = scale_as(aspp_lo, aspp)
            aspp = gate * lo + (1 - gate) * aspp

        aspp_proj = self.bot_aspp(aspp)
        fine_proj = self.bot_fine(s2_features)
        aspp_proj = Upsample(aspp_proj, s2_features.size()[2:])
        decoder_in = torch.cat([fine_proj, aspp_proj], 1)
        out = Upsample(self.final(decoder_in), target_hw)
        return out, decoder_in
Exemplo n.º 16
0
    def forward(self, inputs):
        """
        Single-scale segmentation pass: backbone -> ``seg_head``, resized to
        the input size.

        Returns:
            Training: ``self.criterion(pred, gts)``.
            Eval: ``{'pred': pred}``.
        """
        images = inputs['images']
        _, _, deep_feats = self.backbone(images)
        logits = scale_as(self.seg_head(deep_feats), images)

        if not self.training:
            return {'pred': logits}

        assert 'gts' in inputs
        return self.criterion(logits, inputs['gts'])
Exemplo n.º 17
0
    def _forward_paired(self, inputs, scales):
        """
        Hierarchical attention that predicts attention only for adjacent pairs
        of scales, so that at inference many scales can be chained together.

        ``scales`` is used in the order given. For each adjacent pair
        ``(scales[i], scales[i + 1])``, ``self.scale_attn`` sees the
        concatenated features of both and yields two channels: channel 0 for
        ``scales[i]`` and channel 1 for ``scales[i + 1]``.

        The pairwise weights are then chained. The weight already assigned to
        the scale shared with the previous pair is split between this pair's
        two scales in proportion to their channels. As a result, the weights
        of all scales sum to the first pair's ``attn_lo + attn_hi``.

        Args:
            inputs: dict with 'images', plus 'gts' when training.
            scales: list of scales; must contain 1.0.

        Returns:
            Training: ``self.criterion(output, gts)``.
            Eval: ``(output, attn)`` where ``attn`` is the normalized
            attention of the last entry of ``scales``.
        """
        x_1x = inputs['images']

        # run 1x scale
        assert 1.0 in scales, 'expected one of scales to be 1.0'
        ps = {}
        all_feats = {}
        ps[1.0], all_feats[1.0] = self._fwd(x_1x)

        # run all other scales
        for scale in scales:
            if scale == 1.0:
                continue
            p, feats = self._fwd(ResizeX(x_1x, scale))
            ps[scale] = scale_as(p, x_1x)
            all_feats[scale] = scale_as(feats, all_feats[1.0])

        # Generate pairwise attention outputs, keyed by the lower scale of
        # each pair.
        num_scales = len(scales)
        pair_attn = {}
        for idx in range(num_scales - 1):
            lo_scale, hi_scale = scales[idx], scales[idx + 1]
            concat_feats = torch.cat(
                [all_feats[lo_scale], all_feats[hi_scale]], 1)
            pair_attn[lo_scale] = scale_as(self.scale_attn(concat_feats), x_1x)

        # Normalize attentions along the chain of pairs.
        norm_attn = {}
        last_attn = None
        for idx in range(num_scales - 1):
            lo_scale, hi_scale = scales[idx], scales[idx + 1]
            attn_lo = pair_attn[lo_scale][:, 0:1, :, :]
            attn_hi = pair_attn[lo_scale][:, 1:2, :, :]
            if last_attn is None:
                norm_attn[lo_scale] = attn_lo
                norm_attn[hi_scale] = attn_hi
            else:
                # Split the weight already given to lo_scale (shared with the
                # previous pair) between lo_scale and hi_scale.
                normalize_this_attn = last_attn / (attn_lo + attn_hi)
                norm_attn[lo_scale] = attn_lo * normalize_this_attn
                norm_attn[hi_scale] = attn_hi * normalize_this_attn
            # Carry the *normalized* weight forward. Carrying the raw attn_hi
            # breaks the redistribution once there are four or more scales.
            last_attn = norm_attn[hi_scale]

        # Apply attentions
        output = None
        attn = None
        for scale in scales:
            attn = norm_attn[scale]
            attn_1x_scale = scale_as(attn, x_1x)
            if output is None:
                output = ps[scale] * attn_1x_scale
            else:
                output += ps[scale] * attn_1x_scale

        if self.training:
            assert 'gts' in inputs
            gts = inputs['gts']
            loss = self.criterion(output, gts)
            return loss
        else:
            return output, attn
Exemplo n.º 18
0
    def nscale_forward(self, inputs, scales):
        """
        Hierarchical attention, primarily used to get the best inference
        results.

        Attention is applied at multiple scales, giving priority to the lower
        resolutions. For example, with scales {0.5, 1.0, 1.5, 2.0} evaluation
        proceeds as follows:

              p_joint = attn_1.5 * p_1.5 + (1 - attn_1.5) * down(p_2.0)
              p_joint = attn_1.0 * p_1.0 + (1 - attn_1.0) * down(p_joint)
              p_joint = up(attn_0.5 * p_0.5) + (1 - up(attn_0.5)) * p_joint

        The target scale is always 1.0, and 1.0 must be in ``scales``.
        Predictions made above 1.0 are downsampled before being combined with
        the next lower scale. Each pass's attention comes from ``self._fwd``,
        which also receives the scale as a per-sample tensor
        (``scale_float``).

        Inputs:
          scales - list of scales to evaluate
          inputs - dict containing 'images' and, when training, 'gts'

        Output:
          Training: ``self.criterion(pred, gts)``.
          Eval: dict with per-scale predictions (``fmt_scale('pred', s)``),
          per-scale attentions for every scale except 2.0
          (``fmt_scale('attn', s)``), and the fused prediction under 'pred'.
        """
        x_1x = inputs['images']

        assert 1.0 in scales, 'expected 1.0 to be the target scale'
        # Lower resolutions provide attention for higher-resolution
        # predictions, so evaluate from high to low.
        ordered = sorted(scales, reverse=True)

        pred = None
        output_dict = {}

        for s in ordered:
            x = ResizeX(x_1x, s)
            scale_float = torch.Tensor(x.shape[0]).fill_(s)
            p, attn, _aspp_attn, _aspp = self._fwd(x, scale_float=scale_float)

            output_dict[fmt_scale('pred', s)] = p
            if s != 2.0:
                output_dict[fmt_scale('attn', s)] = attn

            if pred is None:
                pred = p
                continue

            if s >= 1.0:
                # Bring the running prediction down to this scale.
                prev = scale_as(pred, p)
                pred = attn * p + (1 - attn) * prev
            else:
                # Bring this scale's weighted prediction up to the running
                # resolution.
                weighted = scale_as(attn * p, pred)
                attn_up = scale_as(attn, pred)
                pred = weighted + (1 - attn_up) * pred

        if not self.training:
            output_dict['pred'] = pred
            return output_dict

        assert 'gts' in inputs
        return self.criterion(pred, inputs['gts'])
Exemplo n.º 19
0
    def nscale_forward(self, inputs, scales):
        """
        Hierarchical attention, primarily used to get the best inference
        results.

        Attention is applied at multiple scales, giving priority to the lower
        resolutions. For example, with scales {0.5, 1.0, 1.5, 2.0} evaluation
        proceeds as follows:

              p_joint = attn_1.5 * p_1.5 + (1 - attn_1.5) * down(p_2.0)
              p_joint = attn_1.0 * p_1.0 + (1 - attn_1.0) * down(p_joint)
              p_joint = up(attn_0.5 * p_0.5) + (1 - up(attn_0.5)) * p_joint

        The target scale is always 1.0, and 1.0 must be in ``scales``.
        Predictions made above 1.0 are downsampled before being combined with
        the next lower scale. Each scale's attention is predicted by
        ``self.scale_attn`` from the concatenation of that scale's features
        and the previous (higher) scale's features resized to match.

        Inputs:
          scales - list of scales to evaluate
          inputs - dict containing 'images' and, when training, 'gts'

        Output:
          Training: ``self.criterion(pred, gts)``.
          Eval: ``{'pred': pred, 'attn_10x': attn}``. ``attn`` is the
          attention of the last (lowest) scale evaluated, or None when only a
          single scale was given.
        """
        x_1x = inputs['images']

        assert 1.0 in scales, 'expected 1.0 to be the target scale'
        # Lower resolution provides attention for higher rez predictions,
        # so we evaluate in order: high to low
        scales = sorted(scales, reverse=True)
        pred = None
        last_feats = None
        # Initialised so a single-scale call (e.g. scales == [1.0]) returns
        # None here instead of raising UnboundLocalError at the eval return.
        attn = None

        for idx, s in enumerate(scales):
            x = ResizeX(x_1x, s)
            p, feats = self._fwd(x)

            # Generate attention prediction
            if idx > 0:
                assert last_feats is not None
                # downscale feats of the previous (higher) scale
                last_feats = scale_as(last_feats, feats)
                cat_feats = torch.cat([feats, last_feats], 1)
                attn = self.scale_attn(cat_feats)
                attn = scale_as(attn, p)

            if pred is None:
                # This is the top scale prediction
                pred = p
            elif s >= 1.0:
                # downscale previous
                pred = scale_as(pred, p)
                pred = attn * p + (1 - attn) * pred
            else:
                # upscale current
                p = attn * p
                p = scale_as(p, pred)
                attn = scale_as(attn, pred)
                pred = p + (1 - attn) * pred

            last_feats = feats

        if self.training:
            assert 'gts' in inputs
            gts = inputs['gts']
            loss = self.criterion(pred, gts)
            return loss
        else:
            # FIXME: should add multi-scale values for pred and attn
            return {'pred': pred, 'attn_10x': attn}
Exemplo n.º 20
0
    def nscale_forward(self, inputs, scales):
        """
        Hierarchical attention for OCR heads, primarily used to get the best
        inference results.

        Attention is applied at multiple scales, giving priority to the lower
        resolutions. For example, with scales {0.5, 1.0, 1.5, 2.0} evaluation
        proceeds as follows:

              p_joint = attn_1.5 * p_1.5 + (1 - attn_1.5) * down(p_2.0)
              p_joint = attn_1.0 * p_1.0 + (1 - attn_1.0) * down(p_joint)
              p_joint = up(attn_0.5 * p_0.5) + (1 - up(attn_0.5)) * p_joint

        The same attention-weighted combination is applied to the auxiliary
        OCR output. The target scale is always 1.0, and 1.0 must be in
        ``scales``. Predictions made above 1.0 are downsampled before being
        combined with the next lower scale.

        Inputs:
          scales - list of scales to evaluate
          inputs - dict containing 'images' and, when training, 'gts'

        Output:
          Training: ``cfg.LOSS.OCR_ALPHA * criterion(aux) + criterion(pred)``.
          Eval: dict with per-scale predictions, per-scale attentions for
          every scale except 2.0, and the fused prediction under 'pred'.
        """
        x_1x = inputs['images']

        assert 1.0 in scales, 'expected 1.0 to be the target scale'
        # Lower resolution provides attention for higher rez predictions,
        # so we evaluate in order: high to low
        scales = sorted(scales, reverse=True)

        pred = None
        aux = None
        output_dict = {}
        # (Leftover debug prints of the scale list were removed: they wrote to
        # stdout on every forward pass.)

        for s in scales:
            x = ResizeX(x_1x, s)
            outs = self._fwd(x)
            cls_out = outs['cls_out']
            attn_out = outs['logit_attn']
            aux_out = outs['aux_out']

            output_dict[fmt_scale('pred', s)] = cls_out
            if s != 2.0:
                output_dict[fmt_scale('attn', s)] = attn_out

            if pred is None:
                pred = cls_out
                aux = aux_out
            elif s >= 1.0:
                # downscale previous
                pred = scale_as(pred, cls_out)
                pred = attn_out * cls_out + (1 - attn_out) * pred
                aux = scale_as(aux, cls_out)
                aux = attn_out * aux_out + (1 - attn_out) * aux
            else:
                # s < 1.0: upscale current
                cls_out = attn_out * cls_out
                aux_out = attn_out * aux_out

                cls_out = scale_as(cls_out, pred)
                aux_out = scale_as(aux_out, pred)
                attn_out = scale_as(attn_out, pred)

                pred = cls_out + (1 - attn_out) * pred
                aux = aux_out + (1 - attn_out) * aux

        if self.training:
            assert 'gts' in inputs
            gts = inputs['gts']
            loss = cfg.LOSS.OCR_ALPHA * self.criterion(aux, gts) + \
                   self.criterion(pred, gts)
            return loss
        else:
            output_dict['pred'] = pred
            return output_dict
Exemplo n.º 21
0
    def two_scale_forward(self, inputs):
        """
        Two-scale OCR forward pass that fuses both the class output and the
        auxiliary output with the low-scale attention.

        The input is evaluated at ``cfg.MODEL.MSCALE_LO_SCALE`` and at 1x. For
        both the class and auxiliary outputs, the low-scale attention weights
        the low-scale output, and ``1 - attention`` weights the 1x output.
        Only the fused outputs are supervised by default. Each single-scale
        class prediction is additionally supervised when
        ``cfg.LOSS.SUPERVISED_MSCALE_WT`` is truthy.

        Returns:
            Training:
                ``cfg.LOSS.OCR_ALPHA * aux_loss + main_loss``
                plus the optional per-scale terms (RMI disabled).
                RMI on the auxiliary loss follows ``cfg.LOSS.OCR_AUX_RMI``.
            Eval: dict with 'pred', 'pred_05x', 'pred_10x', 'attn_05x'.
        """
        assert 'images' in inputs
        x_1x = inputs['images']

        x_lo = ResizeX(x_1x, cfg.MODEL.MSCALE_LO_SCALE)
        lo_outs = self._fwd(x_lo)
        pred_05x = lo_outs['cls_out']
        aux_lo = lo_outs['aux_out']
        attn_05x = lo_outs['logit_attn']

        hi_outs = self._fwd(x_1x)
        pred_10x = hi_outs['cls_out']
        aux_1x = hi_outs['aux_out']

        # Weight the low-scale outputs by their attention, then resize to 1x.
        p_lo = scale_as(attn_05x * pred_05x, pred_10x)
        aux_lo = scale_as(attn_05x * aux_lo, pred_10x)
        attn_1x = scale_as(attn_05x, pred_10x)

        joint_pred = p_lo + (1 - attn_1x) * pred_10x
        joint_aux = aux_lo + (1 - attn_1x) * aux_1x

        if not self.training:
            return {
                'pred': joint_pred,
                'pred_05x': pred_05x,
                'pred_10x': pred_10x,
                'attn_05x': attn_05x,
            }

        gts = inputs['gts']
        aux_loss = self.criterion(joint_aux, gts,
                                  do_rmi=cfg.LOSS.OCR_AUX_RMI)
        # RMI on the main loss is always enabled. An earlier option gated it
        # on cfg.EPOCH > 0 to work around cholesky (singular matrix) errors.
        main_loss = self.criterion(joint_pred, gts, do_rmi=True)
        loss = cfg.LOSS.OCR_ALPHA * aux_loss + main_loss

        # Optional direct supervision of each scale; RMI is off to keep it
        # lightweight.
        if cfg.LOSS.SUPERVISED_MSCALE_WT:
            loss_lo = self.criterion(scale_as(pred_05x, pred_10x), gts,
                                     do_rmi=False)
            loss_hi = self.criterion(pred_10x, gts, do_rmi=False)
            loss += cfg.LOSS.SUPERVISED_MSCALE_WT * loss_lo
            loss += cfg.LOSS.SUPERVISED_MSCALE_WT * loss_hi
        return loss