def init_encoder(self):
    """Build the encoder backbone selected by ``self.hparams.encoder_name``.

    ``"cpc_encoder"`` builds ``cpc_resnet101`` from a zero-filled dummy batch of
    shape ``(2, 3, patch_size, patch_size)``. Any other name is forwarded to
    ``torchvision_ssl_encoder``, which is asked for all feature maps only when
    ``self.hparams.task == "amdim"``.

    Returns:
        Whatever the selected constructor returns.
    """
    encoder_name = self.hparams.encoder_name
    if encoder_name == "cpc_encoder":
        # Only the CPC constructor consumes a sample input (presumably to infer
        # its layer shapes -- confirm), so the tensor is allocated here rather
        # than on every call regardless of the chosen encoder.
        dummy_batch = torch.zeros(
            (2, 3, self.hparams.patch_size, self.hparams.patch_size)
        )
        return cpc_resnet101(dummy_batch)
    return torchvision_ssl_encoder(
        encoder_name,
        return_all_feature_maps=self.hparams.task == "amdim",
    )
def init_encoder(self):
    """Build the encoder backbone selected by ``self.hparams.encoder``.

    ``'cpc_encoder'`` builds ``CPCResNet101`` from a zero-filled dummy batch of
    shape ``(2, 3, patch_size, patch_size)``. Any other name is forwarded to
    ``torchvision_ssl_encoder``, which is asked for all feature maps only when
    ``self.hparams.task == 'amdim'``.

    Returns:
        Whatever the selected constructor returns.
    """
    encoder_name = self.hparams.encoder
    if encoder_name == 'cpc_encoder':
        # The sample input is needed only by the CPC constructor (presumably
        # for shape inference -- confirm); avoid allocating it otherwise.
        dummy_batch = torch.zeros(
            (2, 3, self.hparams.patch_size, self.hparams.patch_size)
        )
        return CPCResNet101(dummy_batch)
    return torchvision_ssl_encoder(
        encoder_name,
        return_all_feature_maps=self.hparams.task == 'amdim',
    )
def __init__(self, encoder=None):
    """Assemble the encoder, projector and predictor sub-networks.

    Args:
        encoder: Backbone module. When ``None``, a ``'resnet50'`` backbone is
            created through ``torchvision_ssl_encoder``.
    """
    super().__init__()
    self.encoder = encoder if encoder is not None else torchvision_ssl_encoder('resnet50')
    # Projector uses MLP's default sizes; the predictor's input width is 256
    # (presumably matching the projector's default output width -- confirm).
    self.projector = MLP()
    self.predictor = MLP(input_dim=256)
def __init__(
    self,
    encoder='resnet50',
    encoder_out_dim=2048,
    projector_hidden_size=4096,
    projector_out_dim=256,
):
    """Assemble the encoder, projector and predictor sub-networks.

    Args:
        encoder: Either a backbone module, or a name (``str``) that is passed
            to ``torchvision_ssl_encoder`` to build one.
        encoder_out_dim: First positional size given to the projector MLP
            (presumably the backbone's feature width -- confirm).
        projector_hidden_size: Second positional size given to both MLPs.
        projector_out_dim: Third positional size of the projector; also the
            first and third positional sizes of the predictor.
    """
    super().__init__()
    backbone = torchvision_ssl_encoder(encoder) if isinstance(encoder, str) else encoder
    self.encoder = backbone
    # Built in this order on purpose: encoder, then projector, then predictor.
    self.projector = MLP(encoder_out_dim, projector_hidden_size, projector_out_dim)
    self.predictor = MLP(projector_out_dim, projector_hidden_size, projector_out_dim)
def __init__(
    self,
    encoder: Optional[nn.Module] = None,
    input_dim: int = 2048,
    hidden_size: int = 4096,
    output_dim: int = 256,
) -> None:
    """Assemble the encoder, projector and predictor sub-networks.

    Args:
        encoder: Backbone module; when ``None`` a ``'resnet50'`` backbone is
            built via ``torchvision_ssl_encoder``.
        input_dim: First positional size of the projector MLP.
        hidden_size: Second positional size of both the projector and the
            predictor MLPs.
        output_dim: Third positional size of the projector, and first and
            third positional sizes of the predictor.
    """
    super().__init__()
    self.encoder = torchvision_ssl_encoder('resnet50') if encoder is None else encoder
    # The predictor maps output_dim -> output_dim through the same hidden size
    # as the projector; construct projector first to keep the original order.
    self.projector = MLP(input_dim, hidden_size, output_dim)
    self.predictor = MLP(output_dim, hidden_size, output_dim)
def init_encoder(self):
    """Build the encoder selected by ``self.hparams.encoder``.

    ``'amdim_encoder'`` constructs an ``AMDIMEncoder`` from the hparams, calls
    its ``init_weights()`` and returns it. Any other name is forwarded to
    ``torchvision_ssl_encoder`` with ``return_all_feature_maps=True``.

    Returns:
        The constructed encoder.
    """
    encoder_name = self.hparams.encoder
    if encoder_name == 'amdim_encoder':
        # Sample input for the AMDIM constructor; only this branch uses it, so
        # it is no longer allocated when a torchvision backbone is requested.
        # NOTE(review): image_height fills both spatial dims, i.e. square
        # inputs are assumed -- confirm.
        dummy_batch = torch.zeros(
            (
                2,
                self.hparams.image_channels,
                self.hparams.image_height,
                self.hparams.image_height,
            )
        )
        encoder = AMDIMEncoder(
            dummy_batch,
            num_channels=self.hparams.image_channels,
            encoder_feature_dim=self.hparams.encoder_feature_dim,
            embedding_fx_dim=self.hparams.embedding_fx_dim,
            conv_block_depth=self.hparams.conv_block_depth,
            encoder_size=self.hparams.image_height,
            use_bn=self.hparams.use_bn,
        )
        encoder.init_weights()
        return encoder
    return torchvision_ssl_encoder(encoder_name, return_all_feature_maps=True)
def init_encoder(self):
    """Build the backbone named by ``self.hparams.encoder``.

    Delegates to ``torchvision_ssl_encoder`` and requests all feature maps only
    when ``self.hparams.task == "amdim"``.
    """
    backbone_name = self.hparams.encoder
    wants_all_maps = self.hparams.task == "amdim"
    return torchvision_ssl_encoder(backbone_name, return_all_feature_maps=wants_all_maps)