Example #1
 def forward(self,
             x,
             out_feat_keys: List[str] = None) -> List[torch.Tensor]:
     if isinstance(x, MultiDimensionalTensor):
         out = get_tunk_forward_interpolated_outputs(
             input_type=self.model_config.INPUT_TYPE,
             interpolate_out_feat_key_name="res5",
             remove_padding_before_feat_key_name="avgpool",
             feat=x,
             feature_blocks=self._feature_blocks,
             # FSDP has its own activation checkpoint method: disable vissl's method here.
             use_checkpointing=False,
             checkpointing_splits=0,
         )
     else:
         model_input = transform_model_input_data_type(
             x, self.model_config.INPUT_TYPE)
         out = get_trunk_forward_outputs(
             feat=model_input,
             out_feat_keys=out_feat_keys,
             feature_blocks=self._feature_blocks,
             # FSDP has its own activation checkpoint method: disable vissl's method here.
             use_checkpointing=False,
             checkpointing_splits=0,
         )
     return out
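In this FSDP variant, VISSL's own activation checkpointing is turned off (use_checkpointing=False) because FSDP applies its own checkpointing, as the inline comment notes. For reference, below is a minimal sketch of calling get_trunk_forward_outputs directly on a toy trunk; the import path and the two stand-in blocks are assumptions for illustration (and require VISSL to be installed), not part of the example above.

import torch
import torch.nn as nn
# Assumed import path for VISSL's model helpers.
from vissl.models.model_helpers import get_trunk_forward_outputs

# Toy stand-in trunk: two named blocks instead of a full ResNet/RegNet.
feature_blocks = nn.ModuleDict({
    "conv1": nn.Conv2d(3, 8, kernel_size=3, padding=1),
    "avgpool": nn.AdaptiveAvgPool2d((1, 1)),
})

x = torch.randn(2, 3, 32, 32)
out = get_trunk_forward_outputs(
    feat=x,
    out_feat_keys=["conv1"],       # request the intermediate feature map
    feature_blocks=feature_blocks,
    use_checkpointing=False,       # mirrors the FSDP example above
    checkpointing_splits=0,
)
print([o.shape for o in out])      # e.g. [torch.Size([2, 8, 32, 32])]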
Example #2
 def forward(
     self, x: torch.Tensor, out_feat_keys: List[str] = None
 ) -> List[torch.Tensor]:
     if isinstance(x, MultiDimensionalTensor):
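         # MultiDimensionalTensor input: interpolate the "res5" outputs and remove padding before "avgpool".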
         out = get_tunk_forward_interpolated_outputs(
             input_type=self.model_config.INPUT_TYPE,
             interpolate_out_feat_key_name="res5",
             remove_padding_before_feat_key_name="avgpool",
             feat=x,
             feature_blocks=self._feature_blocks,
             feature_mapping=self.feat_eval_mapping,
             use_checkpointing=self.use_checkpointing,
             checkpointing_splits=self.num_checkpointing_splits,
         )
     else:
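         # Plain tensor input: convert it to the configured INPUT_TYPE, then run the standard trunk forward.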
         model_input = transform_model_input_data_type(
             x, self.model_config.INPUT_TYPE
         )
         out = get_trunk_forward_outputs(
             feat=model_input,
             out_feat_keys=out_feat_keys,
             feature_blocks=self._feature_blocks,
             feature_mapping=self.feat_eval_mapping,
             use_checkpointing=self.use_checkpointing,
             checkpointing_splits=self.num_checkpointing_splits,
         )
     return out
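Compared with Example #1, this variant also forwards a feature_mapping and reads the checkpointing settings from instance attributes instead of hard-coding them off. The scaffold below is a hypothetical sketch (not VISSL's actual trunk __init__): it only lists the attributes this forward() expects to find on the instance, with illustrative defaults.

import torch.nn as nn

# Hypothetical scaffold, for illustration only: the attribute names come from
# Example #2's forward(); the default values are assumptions.
class TrunkSketch(nn.Module):
    def __init__(self, model_config, feature_blocks: nn.ModuleDict):
        super().__init__()
        self.model_config = model_config        # must expose INPUT_TYPE (e.g. "rgb")
        self._feature_blocks = feature_blocks   # named stages such as "res5", "avgpool"
        self.feat_eval_mapping = None           # optional remapping of evaluation feature names
        self.use_checkpointing = False          # activation-checkpointing toggle (assumed default)
        self.num_checkpointing_splits = 2       # splits used when checkpointing is enabled (assumed)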