def __init__(self, config):
    """Build a linear-evaluation system on a frozen pretrained encoder + viewmaker.

    Loads the pretrained encoder, viewmaker, system and pretraining config via
    ``self.load_pretrained_model()``, determines the size of the feature vector
    the encoder will emit, builds the train/val image datasets, freezes the
    pretrained networks and finally builds the head with ``self.create_model()``.

    Args:
        config: experiment config. This method reads
            ``optim_params.batch_size``, ``model_params.use_prepool`` and
            ``data_params.dataset``.

    Raises:
        Exception: if the pretrained ``resnet_version`` is not ``'resnet18'``.
    """
    super().__init__()
    self.config = config
    self.batch_size = config.optim_params.batch_size

    self.encoder, self.viewmaker, self.system, self.pretrain_config = self.load_pretrained_model()

    resnet = self.pretrain_config.model_params.resnet_version
    if resnet == 'resnet18':
        if self.config.model_params.use_prepool:
            if self.pretrain_config.model_params.resnet_small:
                num_features = 512 * 4 * 4
            else:
                num_features = 512 * 7 * 7
        else:
            num_features = 512
    else:
        raise Exception(f'resnet {resnet} not supported.')

    self.train_dataset, self.val_dataset = datasets.get_image_datasets(
        config.data_params.dataset,
        default_augmentations=self.pretrain_config.data_params.default_augmentations or False,
    )

    if not self.pretrain_config.model_params.resnet_small:
        # Always drop the final (fc) child. When pre-pool features are
        # requested, also drop the pooling child so the encoder emits the
        # un-pooled map that the `512 * 7 * 7` count above assumes; otherwise
        # keep the pooling layer so it emits the pooled 512-d vector.
        cut_ix = -2 if self.config.model_params.use_prepool else -1
        self.encoder = nn.Sequential(*list(self.encoder.children())[:cut_ix])

    self.encoder = self.encoder.eval()
    self.viewmaker = self.viewmaker.eval()
    # linear evaluation freezes pretrained weights
    utils.frozen_params(self.encoder)
    utils.frozen_params(self.viewmaker)

    self.num_features = num_features
    self.model = self.create_model()
def __init__(self, config):
    """Build a linear-evaluation system on top of a frozen pretrained encoder.

    Loads the pretrained encoder and its pretraining config via
    ``self.load_pretrained_model()``, determines the size of the feature vector
    the encoder will emit, freezes the encoder, builds the train/val image
    datasets and finally builds the head with ``self.create_model()``.

    Args:
        config: experiment config. This method reads
            ``optim_params.batch_size``, ``model_params.use_prepool``,
            ``data_params.force_default_views`` and ``data_params.dataset``.

    Raises:
        Exception: if the pretrained ``resnet_version`` is not ``'resnet18'``.
    """
    super(TransferViewMakerSystem, self).__init__()
    self.config = config
    self.batch_size = config.optim_params.batch_size

    self.encoder, self.pretrain_config = self.load_pretrained_model()

    resnet = self.pretrain_config.model_params.resnet_version
    if resnet == 'resnet18':
        if self.config.model_params.use_prepool:
            if self.pretrain_config.model_params.resnet_small:
                num_features = 512 * 4 * 4
            else:
                num_features = 512 * 7 * 7
        else:
            num_features = 512
    else:
        raise Exception(f'resnet {resnet} not supported.')

    if not self.pretrain_config.model_params.resnet_small:
        # Always drop the final (fc) child. When pre-pool features are
        # requested, also drop the pooling child so the encoder emits the
        # un-pooled map that the `512 * 7 * 7` count above assumes; otherwise
        # keep the pooling layer so it emits the pooled 512-d vector.
        cut_ix = -2 if self.config.model_params.use_prepool else -1
        self.encoder = nn.Sequential(*list(self.encoder.children())[:cut_ix])

    # Freeze encoder for linear evaluation.
    self.encoder = self.encoder.eval()
    utils.frozen_params(self.encoder)

    default_augmentations = self.pretrain_config.data_params.default_augmentations
    # Fall back to 'all' when forced by config, or when the pretraining config
    # has no value here (presumably an unset DotMap key, which compares equal
    # to an empty DotMap() -- confirm against how configs are loaded).
    if self.config.data_params.force_default_views or default_augmentations == DotMap():
        default_augmentations = 'all'
    self.train_dataset, self.val_dataset = datasets.get_image_datasets(
        config.data_params.dataset,
        default_augmentations=default_augmentations,
    )

    self.num_features = num_features
    self.model = self.create_model()
def __init__(self, config):
    """Set up a transfer system around a pretrained ResNet encoder.

    The pretrained encoder, system and pretraining config come from
    ``self.load_pretrained_model()``. The width of the encoder's output
    features is derived from the pretrained ResNet version and
    ``config.model_params.use_prepool``. The encoder is put in eval mode and
    passed to ``frozen_params`` unless ``config.optim_params.supervised`` is
    truthy.

    Args:
        config: experiment config. This method reads
            ``optim_params.batch_size``, ``optim_params.supervised`` and
            ``model_params.use_prepool``.

    Raises:
        Exception: if ``resnet_version`` is neither ``'resnet18'`` nor
            ``'resnet50'``.
        NotImplementedError: if the pretrained encoder is not a
            ``resnet_small`` variant.
    """
    super().__init__()
    self.config = config
    self.batch_size = config.optim_params.batch_size
    self.encoder, self.system, self.pretrain_config = self.load_pretrained_model()

    pretrained_params = self.pretrain_config.model_params
    arch = pretrained_params.resnet_version
    if arch == 'resnet50':
        feature_dim = 2048
    elif arch == 'resnet18':
        if not self.config.model_params.use_prepool:
            feature_dim = 512
        elif pretrained_params.resnet_small:
            feature_dim = 512 * 4 * 4
        else:
            feature_dim = 512 * 2 * 2
    else:
        raise Exception(f'resnet {arch} not supported.')

    # Only the small-ResNet variant is handled by this system.
    if not pretrained_params.resnet_small:
        raise NotImplementedError()

    # Without supervised fine-tuning the encoder is kept fixed
    # (eval mode + frozen_params, presumably disabling its gradients).
    if not self.config.optim_params.supervised:
        self.encoder = self.encoder.eval()
        frozen_params(self.encoder)

    self.num_features = feature_dim
    self.train_dataset, self.val_dataset = self.create_datasets()
    self.model = self.create_model()
def __init__(self, config):
    """Load datasets and a frozen pretrained encoder, then build the model.

    Args:
        config: experiment config. This method reads ``cuda``,
            ``gpu_device`` and ``data_params.dataset``.
    """
    super().__init__()
    self.config = config

    if config.cuda:
        self.device = f'cuda:{config.gpu_device}'
    else:
        self.device = 'cpu'

    dataset_name = config.data_params.dataset
    self.train_dataset, self.val_dataset = datasets.get_datasets(dataset_name)

    # The pretrained encoder is frozen before the model is created
    # (presumably only the model built below is trained).
    self.encoder = self.load_pretrained_model()
    utils.frozen_params(self.encoder)
    self.model = self.create_model()
def __init__(self, config):
    """Build a transfer system on a frozen pretrained encoder + viewmaker.

    Loads the pretrained encoder, viewmaker, system and pretraining config via
    ``self.load_pretrained_model()``, determines the size of the feature vector
    the encoder will emit, truncates non-small encoders accordingly, freezes
    the pretrained networks, then builds the datasets and the head.

    Args:
        config: experiment config. This method reads
            ``optim_params.batch_size`` and ``model_params.use_prepool``.

    Raises:
        Exception: if ``resnet_version`` is neither ``'resnet18'`` nor
            ``'resnet50'``.
    """
    super().__init__()
    self.config = config
    self.batch_size = config.optim_params.batch_size

    self.encoder, self.viewmaker, self.system, self.pretrain_config = self.load_pretrained_model()

    resnet = self.pretrain_config.model_params.resnet_version
    if resnet == 'resnet18':
        if self.config.model_params.use_prepool:
            if self.pretrain_config.model_params.resnet_small:
                num_features = 512 * 4 * 4
            else:
                num_features = 512 * 2 * 2
        else:
            num_features = 512
    elif resnet == 'resnet50':
        if self.config.model_params.use_prepool:
            # NOTE(review): unlike resnet18, this count does not depend on
            # resnet_small; confirm it matches the non-small truncated
            # encoder's output map.
            num_features = 2048 * 4 * 4
        else:
            num_features = 2048
    else:
        raise Exception(f'resnet {resnet} not supported.')

    if not self.pretrain_config.model_params.resnet_small:
        # Drop the final (fc) child; with pre-pool features also drop the
        # pooling child so the un-pooled map is emitted.
        if self.config.model_params.use_prepool:
            cut_ix = -2
        else:
            cut_ix = -1
        self.encoder = nn.Sequential(
            *list(self.encoder.children())[:cut_ix])

    # Pretrained networks stay fixed: eval mode + frozen parameters for both,
    # matching how the encoder is treated (the viewmaker was previously frozen
    # but left in train mode).
    self.encoder = self.encoder.eval()
    self.viewmaker = self.viewmaker.eval()
    frozen_params(self.encoder)
    frozen_params(self.viewmaker)

    self.num_features = num_features
    self.train_dataset, self.val_dataset = self.create_datasets()
    self.model = self.create_model()