コード例 #1
0
    def init_encoder(self):
        """Build the encoder named by ``self.hparams.encoder_name``.

        ``"cpc_encoder"`` selects ``cpc_resnet101``, which is called with a
        zero-filled dummy batch of shape ``(2, 3, patch_size, patch_size)``.
        Any other name is forwarded to ``torchvision_ssl_encoder``, which is
        asked for all feature maps only when ``self.hparams.task == "amdim"``.

        Returns:
            The encoder produced by ``cpc_resnet101`` or ``torchvision_ssl_encoder``.
        """
        encoder_name = self.hparams.encoder_name
        if encoder_name == "cpc_encoder":
            # Only the CPC encoder consumes the dummy batch, so allocate it here
            # instead of creating a throwaway tensor for every encoder type.
            patch_size = self.hparams.patch_size
            dummy_batch = torch.zeros((2, 3, patch_size, patch_size))
            return cpc_resnet101(dummy_batch)
        return torchvision_ssl_encoder(encoder_name, return_all_feature_maps=self.hparams.task == "amdim")
コード例 #2
0
    def init_encoder(self):
        """Build the encoder named by ``self.hparams.encoder``.

        ``'cpc_encoder'`` selects ``CPCResNet101``, which is called with a
        zero-filled dummy batch of shape ``(2, 3, patch_size, patch_size)``.
        Any other name is forwarded to ``torchvision_ssl_encoder``, which is
        asked for all feature maps only when ``self.hparams.task == 'amdim'``.

        Returns:
            The encoder produced by ``CPCResNet101`` or ``torchvision_ssl_encoder``.
        """
        encoder_name = self.hparams.encoder
        if encoder_name == 'cpc_encoder':
            # The dummy batch is only needed by the CPC encoder; build it lazily.
            patch_size = self.hparams.patch_size
            dummy_batch = torch.zeros((2, 3, patch_size, patch_size))
            return CPCResNet101(dummy_batch)
        return torchvision_ssl_encoder(encoder_name, return_all_feature_maps=self.hparams.task == 'amdim')
コード例 #3
0
    def __init__(self, encoder=None):
        """Create the encoder, projector and predictor sub-modules.

        Args:
            encoder: backbone to store as ``self.encoder``; when ``None`` one is
                built with ``torchvision_ssl_encoder('resnet50')``.
        """
        super().__init__()

        self.encoder = torchvision_ssl_encoder('resnet50') if encoder is None else encoder
        self.projector = MLP()
        # NOTE(review): input_dim=256 presumably matches the projector's output
        # size under MLP's defaults — confirm against MLP's signature.
        self.predictor = MLP(input_dim=256)
コード例 #4
0
    def __init__(self, encoder='resnet50', encoder_out_dim=2048, projector_hidden_size=4096, projector_out_dim=256):
        """Create the encoder, projector and predictor sub-modules.

        Args:
            encoder: a name handed to ``torchvision_ssl_encoder``, or any
                non-string object, which is stored as the encoder unchanged.
            encoder_out_dim: first argument of the projector ``MLP``
                (presumably the encoder's feature size — confirm).
            projector_hidden_size: second argument of both the projector and
                predictor ``MLP``s.
            projector_out_dim: third argument of the projector ``MLP``; also the
                first and third argument of the predictor ``MLP``.
        """
        super().__init__()

        backbone = torchvision_ssl_encoder(encoder) if isinstance(encoder, str) else encoder
        self.encoder = backbone

        hidden, out_dim = projector_hidden_size, projector_out_dim
        self.projector = MLP(encoder_out_dim, hidden, out_dim)
        self.predictor = MLP(out_dim, hidden, out_dim)
コード例 #5
0
    def __init__(
        self,
        encoder: Optional[nn.Module] = None,
        input_dim: int = 2048,
        hidden_size: int = 4096,
        output_dim: int = 256,
    ) -> None:
        """Create the encoder, projector and predictor sub-modules.

        Args:
            encoder: backbone to store as ``self.encoder``; when ``None`` one is
                built with ``torchvision_ssl_encoder('resnet50')``.
            input_dim: first argument of the projector ``MLP``.
            hidden_size: second argument of both the projector and predictor ``MLP``s.
            output_dim: third argument of the projector ``MLP``; also the first
                and third argument of the predictor ``MLP``.
        """
        super().__init__()

        self.encoder = encoder if encoder is not None else torchvision_ssl_encoder('resnet50')

        projector_args = (input_dim, hidden_size, output_dim)
        predictor_args = (output_dim, hidden_size, output_dim)
        self.projector = MLP(*projector_args)
        self.predictor = MLP(*predictor_args)
コード例 #6
0
    def init_encoder(self):
        """Build the encoder selected by ``self.hparams.encoder``.

        ``'amdim_encoder'`` builds an ``AMDIMEncoder`` configured from the
        hyper-parameters (given a zero-filled dummy batch), calls its
        ``init_weights()`` and returns it. Any other name is forwarded to
        ``torchvision_ssl_encoder`` with ``return_all_feature_maps=True``.

        Returns:
            The constructed encoder.
        """
        encoder_name = self.hparams.encoder

        if encoder_name != 'amdim_encoder':
            return torchvision_ssl_encoder(encoder_name, return_all_feature_maps=True)

        # Only the AMDIM encoder consumes the dummy batch, so allocate it here
        # rather than for every encoder type.
        # NOTE(review): image_height is used for both spatial dimensions (and as
        # encoder_size), i.e. square inputs are assumed — confirm there is no
        # separate width hyper-parameter.
        side = self.hparams.image_height
        dummy_batch = torch.zeros((2, self.hparams.image_channels, side, side))

        encoder = AMDIMEncoder(
            dummy_batch,
            num_channels=self.hparams.image_channels,
            encoder_feature_dim=self.hparams.encoder_feature_dim,
            embedding_fx_dim=self.hparams.embedding_fx_dim,
            conv_block_depth=self.hparams.conv_block_depth,
            encoder_size=side,
            use_bn=self.hparams.use_bn
        )
        encoder.init_weights()
        return encoder
コード例 #7
0
 def init_encoder(self):
     """Build the encoder named by ``self.hparams.encoder`` via ``torchvision_ssl_encoder``.

     All feature maps are requested only when ``self.hparams.task`` is ``"amdim"``.
     """
     name = self.hparams.encoder
     all_maps = self.hparams.task == "amdim"
     return torchvision_ssl_encoder(name, return_all_feature_maps=all_maps)