Example No. 1
0
 def forward(self, weight):
     """Apply the binary sparsity mask to ``weight``.

     When a JIT trace is active, the stored mask is multiplied in
     outside of tracing and the mutated weight is returned directly.
     Otherwise the training-time mask is applied via the shared helper.
     """
     if is_tracing_state():
         with no_jit_trace():
             # In-place multiply so the traced graph sees the masked weight.
             return weight.mul_(self.binary_mask)
     # Order preserved deliberately: the training mask is derived before
     # the stored binary mask is refreshed from the current weight.
     training_mask = self._calc_training_binary_mask(weight)
     self.binary_mask = self._calc_binary_mask(weight)
     return apply_binary_mask_impl(training_mask, weight)
    def forward(self, inputs):
        """Forward pass over the three ICNet resolution branches.

        In training mode, returns an OrderedDict of label scores at the
        ds4/ds8/ds16 output strides for deep supervision. In eval mode,
        returns full-resolution label scores (softmaxed into class
        probabilities when exporting a traced graph on torch >= 1.1.0).
        """
        # High-resolution branch consumes the input directly.
        feats_hr = self.highres_branch(inputs)

        # Medium-resolution branch runs on a 2x-downsampled input.
        inputs_med = F.interpolate(inputs, self._input_size_hw_ds2,
                                   **self.sampling_params)
        feats_med = self.mediumres_branch(inputs_med)

        # NOTE: contrary to the ICNet paper Fig. 2, the low-resolution
        # branch does not receive a separate 4x-downsampled image input;
        # it reuses feature maps from the medium-resolution branch.
        feats_med_small = F.interpolate(feats_med, self._input_size_hw_ds32,
                                        **self.sampling_params)
        feats_lr = self.lowres_branch(feats_med_small)

        if self.training:
            # Cascade feature fusion also yields intermediate label
            # scores used for auxiliary (deep-supervision) losses.
            fused_42, scores_ds16 = self.cff42(feats_lr, feats_med)
            fused_421, scores_ds8 = self.cff421(fused_42, feats_hr)
            upsampled = F.interpolate(fused_421, self._input_size_hw_ds4,
                                      **self.sampling_params)
            scores_ds4 = self.conv6_cls(upsampled)
            return OrderedDict([("ds4", scores_ds4),
                                ("ds8", scores_ds8),
                                ("ds16", scores_ds16)])

        # Inference path: the fusion modules return fused features only.
        fused_42 = self.cff42(feats_lr, feats_med)
        fused_421 = self.cff421(fused_42, feats_hr)
        upsampled = F.interpolate(fused_421, self._input_size_hw_ds4,
                                  **self.sampling_params)
        scores_ds4 = self.conv6_cls(upsampled)
        scores = F.interpolate(scores_ds4, self._input_size_hw,
                               **self.sampling_params)
        if is_tracing_state() and parse_version(
                torch.__version__) >= parse_version("1.1.0"):
            # While exporting, add extra post-processing layers into the graph
            # so that the model outputs class probabilities instead of class scores
            return F.softmax(scores, dim=1)
        return scores
    def forward(self, x):
        """U-Net forward pass: encoder with skip connections, then decoder."""
        skips = []
        bottleneck = len(self.down_path) - 1
        for stage, encode in enumerate(self.down_path):
            x = encode(x)
            # Every stage except the bottleneck stores a skip connection
            # and halves the spatial resolution before the next stage.
            if stage != bottleneck:
                skips.append(x)
                x = F.max_pool2d(x, 2)

        # Decoder consumes the stored skip connections in reverse order.
        for stage, decode in enumerate(self.up_path):
            x = decode(x, skips[-(stage + 1)])

        x = self.last(x)
        if is_tracing_state() and parse_version(torch.__version__) >= parse_version("1.1.0"):
            # While exporting, add extra post-processing layers into the graph
            # so that the model outputs class probabilities instead of class scores
            return F.softmax(x, dim=1)
        return x
Example No. 4
0
 def forward(self, conv_weight):
     """Return ``conv_weight`` unchanged.

     NOTE(review): both branches return the weight as-is; the tracing
     branch only wraps the return in ``no_jit_trace()``, presumably to
     keep this no-op region out of the traced graph — confirm against
     the exporter's handling of ``no_jit_trace``.
     """
     if is_tracing_state():
         with no_jit_trace():
             return conv_weight
     return conv_weight