def _init_scheduler(self, optimizer):
    scheduler_dict = self._config.get("Scheduler", None)
    if scheduler_dict is None:
        return
    # cosine annealing after a warm-up phase, both read from the "Scheduler" config block
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=self._config["Trainer"]["max_epoch"] - self._config["Scheduler"]["warmup_max"],
        eta_min=1e-7)
    scheduler = GradualWarmupScheduler(optimizer, scheduler_dict["multiplier"],
                                       total_epoch=scheduler_dict["warmup_max"],
                                       after_scheduler=scheduler)
    self._scheduler = scheduler
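# A minimal sketch of the configuration shape `_init_scheduler` expects. The keys are
# taken from the reads above; the numeric values are illustrative only, not project
# defaults. Omitting the "Scheduler" block skips scheduler creation altogether.
example_config = {
    "Trainer": {"max_epoch": 100},
    "Scheduler": {"multiplier": 300, "warmup_max": 10},
}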
class ContrastTrainer(Trainer):
    RUN_PATH = Path(PROJECT_PATH) / "runs"

    def __init__(self, model: nn.Module, pretrain_loader: T_loader, fine_tune_loader: T_loader,
                 val_loader: DataLoader, save_dir: str = "base", max_epoch_train_encoder: int = 100,
                 max_epoch_train_decoder: int = 100, max_epoch_train_finetune: int = 100,
                 num_batches: int = 256, device: str = "cpu", configuration=None,
                 train_encoder: bool = True, train_decoder: bool = True):
        """
        ContrastTraining Trainer
        :param model: nn.Module network to be pretrained
        :param pretrain_loader: all unlabeled data under ContrastiveBatchSampler
        :param fine_tune_loader: a fraction of labeled data for finetuning, with InfiniteSampler
        :param val_loader: validation data
        :param save_dir: main save_dir
        :param max_epoch_train_encoder: max_epoch to be trained for encoder training
        :param max_epoch_train_decoder: max_epoch to be trained for decoder training
        :param max_epoch_train_finetune: max_epoch to be trained for finetuning
        :param num_batches: num_batches used in training
        :param device: cpu or cuda
        :param configuration: configuration dict
        """
        super().__init__(model, save_dir, None, num_batches, device, configuration)  # noqa
        self._pretrain_loader = pretrain_loader
        self._fine_tune_loader = fine_tune_loader
        self._val_loader = val_loader
        self._max_epoch_train_encoder = max_epoch_train_encoder
        self._max_epoch_train_decoder = max_epoch_train_decoder
        self._max_epoch_train_finetune = max_epoch_train_finetune
        self._register_buffer("train_encoder", train_encoder)
        self._register_buffer("train_decoder", train_decoder)
        self._register_buffer("train_encoder_done", False)
        self._register_buffer("train_decoder_done", False)

        self._pretrain_encoder_storage = Storage(
            csv_save_dir=os.path.join(self._save_dir, "pretrain_encoder"), csv_name="encoder.csv")
        self._pretrain_decoder_storage = Storage(
            csv_save_dir=os.path.join(self._save_dir, "pretrain_decoder"), csv_name="decoder.csv")
        self._finetune_storage = Storage(
            csv_save_dir=os.path.join(self._save_dir, "finetune"), csv_name="finetune.csv")

    def pretrain_encoder_init(self, group_option: str, lr=1e-6, weight_decay=1e-5, multiplier=300,
                              warmup_max=10, ptype="mlp", extract_position="Conv5"):
        # adding optimizer and scheduler
        self._extract_position = extract_position
        self._feature_extractor = UNetFeatureExtractor(self._extract_position)
        self._projector = ProjectionHead(
            input_dim=UNet.dimension_dict[self._extract_position], output_dim=256, head_type=ptype)  # noqa
        self._optimizer = torch.optim.Adam(
            itertools.chain(self._model.parameters(), self._projector.parameters()),  # noqa
            lr=lr, weight_decay=weight_decay)  # noqa
        self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self._optimizer, self._max_epoch_train_encoder - warmup_max, 0)  # noqa
        self._scheduler = GradualWarmupScheduler(self._optimizer, multiplier, warmup_max, self._scheduler)  # noqa
        self._group_option = group_option  # noqa

        # set augmentation method as `total_freedom = True`
        assert hasattr(self._pretrain_loader.dataset._transform, "_total_freedom")  # noqa
        self._pretrain_loader.dataset._transform._total_freedom = True  # noqa
        self._pretrain_loader_iter = iter(self._pretrain_loader)  # noqa
        # contrastive loss
        self._contrastive_criterion = SupConLoss()

    def pretrain_encoder_run(self):
        self.to(self._device)
        self._model.disable_grad_all()
        self._model.enable_grad(from_="Conv1", util=self._extract_position)

        for self._cur_epoch in range(self._start_epoch, self._max_epoch_train_encoder):
            pretrain_encoder_dict = PretrainEncoderEpoch(
                model=self._model,
                projection_head=self._projector,
                optimizer=self._optimizer,
                pretrain_encoder_loader=self._pretrain_loader_iter,
                contrastive_criterion=self._contrastive_criterion,
                num_batches=self._num_batches,
                cur_epoch=self._cur_epoch,
                device=self._device,
                group_option=self._group_option,
                feature_extractor=self._feature_extractor).run()
            self._scheduler.step()
            storage_dict = StorageIncomeDict(PRETRAIN_ENCODER=pretrain_encoder_dict)
            self._pretrain_encoder_storage.put_from_dict(storage_dict, epoch=self._cur_epoch)
            self._writer.add_scalar_with_StorageDict(storage_dict, self._cur_epoch)
            self._save_to("last.pth", path=os.path.join(self._save_dir, "pretrain_encoder"))

    def pretrain_decoder_init(self, lr: float = 1e-6, weight_decay: float = 0.0, multiplier: int = 300,
                              warmup_max=10, ptype="mlp", extract_position="Up_conv3",
                              enable_grad_from="Up5"):
        # feature extractor
        self._extract_position = extract_position
        self._feature_extractor = UNetFeatureExtractor(self._extract_position)
        projector_input_dim = UNet.dimension_dict[extract_position]
        # from which block gradients are enabled (everything before stays frozen)
        self._enable_grad_from = enable_grad_from

        self._projector = LocalProjectionHead(projector_input_dim, head_type=ptype, output_size=(4, 4))  # noqa
        self._optimizer = torch.optim.Adam(
            itertools.chain(self._model.parameters(), self._projector.parameters()),
            lr=lr, weight_decay=weight_decay)
        self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self._optimizer, self._max_epoch_train_decoder - warmup_max, 0)
        self._scheduler = GradualWarmupScheduler(self._optimizer, multiplier, warmup_max, self._scheduler)

        # set augmentation method as `total_freedom = False`
        assert hasattr(self._pretrain_loader.dataset._transform, "_total_freedom")  # noqa
        self._pretrain_loader.dataset._transform._total_freedom = False  # noqa
        self._pretrain_loader_iter = iter(self._pretrain_loader)  # noqa
        # contrastive loss
        self._contrastive_criterion = SupConLoss()

    def pretrain_decoder_run(self):
        self._model.disable_grad_all()
        self._model.enable_grad(from_=self._enable_grad_from, util=self._extract_position)
        self.to(self._device)

        for self._cur_epoch in range(self._start_epoch, self._max_epoch_train_decoder):
            pretrain_decoder_dict = PretrainDecoderEpoch(
                model=self._model,
                projection_head=self._projector,
                optimizer=self._optimizer,
                pretrain_decoder_loader=self._pretrain_loader_iter,
                contrastive_criterion=self._contrastive_criterion,
                num_batches=self._num_batches,
                cur_epoch=self._cur_epoch,
                device=self._device,
                feature_extractor=self._feature_extractor).run()
            self._scheduler.step()
            storage_dict = StorageIncomeDict(PRETRAIN_DECODER=pretrain_decoder_dict)
            self._pretrain_decoder_storage.put_from_dict(storage_dict, epoch=self._cur_epoch)
            self._writer.add_scalar_with_StorageDict(storage_dict, self._cur_epoch)
            self._save_to("last.pth", path=os.path.join(self._save_dir, "pretrain_decoder"))

    def finetune_network_init(self, lr: float = 1e-7, weight_decay: float = 1e-5,
                              multiplier: int = 200, warmup_max=10):
        self._optimizer = torch.optim.Adam(self._model.parameters(), lr=lr, weight_decay=weight_decay)
        self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self._optimizer, self._max_epoch_train_finetune - warmup_max, 5e-7)
        self._scheduler = GradualWarmupScheduler(self._optimizer, multiplier, warmup_max, self._scheduler)
        self._sup_criterion = KL_div()

        # set augmentation method as `total_freedom = True`
        assert hasattr(self._fine_tune_loader.dataset._transform, "_total_freedom")  # noqa
        self._fine_tune_loader.dataset._transform._total_freedom = True  # noqa
        self._fine_tune_loader_iter = iter(self._fine_tune_loader)  # noqa
    def finetune_network_run(self, epocher_type=SimpleFineTuneEpoch):
        self.to(self._device)
        self._model.enable_grad_encoder()  # noqa
        self._model.enable_grad_decoder()  # noqa

        for self._cur_epoch in range(self._start_epoch, self._max_epoch_train_finetune):
            finetune_dict = epocher_type.create_from_trainer(self).run()
            val_dict, cur_score = EvalEpoch(self._model, val_loader=self._val_loader,
                                            sup_criterion=self._sup_criterion,
                                            cur_epoch=self._cur_epoch, device=self._device).run()
            self._scheduler.step()
            storage_dict = StorageIncomeDict(finetune=finetune_dict, val=val_dict)
            self._finetune_storage.put_from_dict(storage_dict, epoch=self._cur_epoch)
            self._writer.add_scalar_with_StorageDict(storage_dict, self._cur_epoch)
            self.save(cur_score, os.path.join(self._save_dir, "finetune"))

    def start_training(self, checkpoint: str = None, pretrain_encoder_init_options=None,
                       pretrain_decoder_init_options=None, finetune_network_init_options=None):
        if finetune_network_init_options is None:
            finetune_network_init_options = {}
        if pretrain_decoder_init_options is None:
            pretrain_decoder_init_options = {}
        if pretrain_encoder_init_options is None:
            pretrain_encoder_init_options = {}

        with SummaryWriter(str(self._save_dir)) as self._writer:  # noqa
            if self.train_encoder:
                self.pretrain_encoder_init(**pretrain_encoder_init_options)
                if checkpoint is not None:
                    try:
                        self.load_state_dict_from_path(os.path.join(checkpoint, "pretrain_encoder"))
                    except Exception as e:
                        raise RuntimeError(f"loading pretrain_encoder_checkpoint failed with {e}")
                self.pretrain_encoder_run()

            if self.train_decoder:
                self.pretrain_decoder_init(**pretrain_decoder_init_options)
                if checkpoint is not None:
                    try:
                        self.load_state_dict_from_path(os.path.join(checkpoint, "pretrain_decoder"))
                    except Exception as e:
                        print(f"loading pretrain_decoder_checkpoint failed with {e}")
                self.pretrain_decoder_run()

            self.finetune_network_init(**finetune_network_init_options)
            if checkpoint is not None:
                try:
                    self.load_state_dict_from_path(os.path.join(checkpoint, "finetune"))
                except Exception as e:
                    print(f"loading finetune_checkpoint failed with {e}")
            self.finetune_network_run()
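# A minimal usage sketch, not part of the original module: the `unet` model and the three
# loaders are assumed to be built elsewhere in the project, and the `group_option` value
# "partition" is hypothetical; only the trainer calls mirror the API defined above.
def _example_contrast_training(unet, pretrain_loader, fine_tune_loader, val_loader):
    trainer = ContrastTrainer(
        model=unet,
        pretrain_loader=pretrain_loader,    # unlabeled data under ContrastiveBatchSampler
        fine_tune_loader=fine_tune_loader,  # labeled fraction with InfiniteSampler
        val_loader=val_loader,
        save_dir="contrast_baseline",
        device="cuda",
    )
    trainer.start_training(
        pretrain_encoder_init_options={"group_option": "partition", "extract_position": "Conv5"},
        pretrain_decoder_init_options={"extract_position": "Up_conv3", "enable_grad_from": "Up5"},
        finetune_network_init_options={"lr": 1e-7, "warmup_max": 10},
    )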
class IICContrastTrainer(ContrastTrainer):

    def pretrain_encoder_init(self, group_option: str, lr=1e-6, weight_decay=1e-5, multiplier=300,
                              warmup_max=10, num_clusters=10, num_subheads=10, iic_weight=1,
                              disable_contrastive=False, ctemperature=1, ctype: str = "linear",
                              ptype: str = "mlp", extract_position: str = "Conv5"):
        # adding optimizer and scheduler
        self._extract_position = extract_position
        self._feature_extractor = UNetFeatureExtractor(self._extract_position)
        self._projector_contrastive = ProjectionHead(
            input_dim=UNet.dimension_dict[self._extract_position], output_dim=256, head_type=ptype)  # noqa
        self._projector_iic = ClusterHead(
            input_dim=UNet.dimension_dict[self._extract_position], num_clusters=num_clusters,
            head_type=ctype, T=ctemperature, num_subheads=num_subheads)
        self._optimizer = torch.optim.Adam(
            itertools.chain(self._model.parameters(),  # noqa
                            self._projector_contrastive.parameters(),
                            self._projector_iic.parameters()),  # noqa
            lr=lr, weight_decay=weight_decay)  # noqa
        self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self._optimizer, self._max_epoch_train_encoder - warmup_max, 0)  # noqa
        self._scheduler = GradualWarmupScheduler(self._optimizer, multiplier, warmup_max, self._scheduler)  # noqa
        self._group_option = group_option  # noqa
        self._disable_contrastive = disable_contrastive

        # set augmentation method as `total_freedom = True`
        assert hasattr(self._pretrain_loader.dataset._transform, "_total_freedom")  # noqa
        self._pretrain_loader.dataset._transform._total_freedom = True  # noqa
        self._pretrain_loader_iter = iter(self._pretrain_loader)  # noqa
        # contrastive loss
        self._contrastive_criterion = SupConLoss()
        # iic weight
        self._iic_weight = iic_weight

    def pretrain_encoder_run(self):
        self.to(self._device)
        self._model.disable_grad_all()
        self._model.enable_grad(from_="Conv1", util=self._extract_position)

        for self._cur_epoch in range(self._start_epoch, self._max_epoch_train_encoder):
            pretrain_encoder_dict = IICPretrainEcoderEpoch(
                model=self._model,
                projection_head=self._projector_contrastive,
                projection_classifier=self._projector_iic,
                optimizer=self._optimizer,
                pretrain_encoder_loader=self._pretrain_loader_iter,
                contrastive_criterion=self._contrastive_criterion,
                num_batches=self._num_batches,
                cur_epoch=self._cur_epoch,
                device=self._device,
                group_option=self._group_option,
                feature_extractor=self._feature_extractor,
                iic_weight=self._iic_weight,
                disable_contrastive=self._disable_contrastive,
            ).run()
            self._scheduler.step()
            storage_dict = StorageIncomeDict(PRETRAIN_ENCODER=pretrain_encoder_dict)
            self._pretrain_encoder_storage.put_from_dict(storage_dict, epoch=self._cur_epoch)
            self._writer.add_scalar_with_StorageDict(storage_dict, self._cur_epoch)
            self._save_to("last.pth", path=os.path.join(self._save_dir, "pretrain_encoder"))

    def pretrain_decoder_init(self, lr: float = 1e-6, weight_decay: float = 0.0, multiplier: int = 300,
                              warmup_max=10, num_clusters=20, ctemperature=1, num_subheads=10,
                              extract_position="Up_conv3", enable_grad_from="Conv1", ptype="mlp",
                              ctype="mlp", iic_weight=1, disable_contrastive=False, padding=0,
                              patch_size=512):
        # feature extractor
        self._extract_position = extract_position
        self._feature_extractor = UNetFeatureExtractor(self._extract_position)
        projector_input_dim = UNet.dimension_dict[extract_position]
        # from which block gradients are enabled (everything before stays frozen)
        self._enable_grad_from = enable_grad_from

        # adding optimizer and scheduler
        self._projector_contrastive = LocalProjectionHead(projector_input_dim, head_type=ptype,
                                                          output_size=(4, 4))
        self._projector_iic = LocalClusterHead(projector_input_dim, num_clusters=num_clusters,
                                               num_subheads=num_subheads, head_type=ctype,
                                               T=ctemperature)
        self._optimizer = torch.optim.Adam(
            itertools.chain(self._model.parameters(),
                            self._projector_contrastive.parameters(),
                            self._projector_iic.parameters()),
            lr=lr, weight_decay=weight_decay)
        self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self._optimizer, self._max_epoch_train_decoder - warmup_max, 0)
        self._scheduler = GradualWarmupScheduler(self._optimizer, multiplier, warmup_max, self._scheduler)

        # set augmentation method as `total_freedom = False`
        assert hasattr(self._pretrain_loader.dataset._transform, "_total_freedom")  # noqa
        self._pretrain_loader.dataset._transform._total_freedom = False  # noqa
        self._pretrain_loader_iter = iter(self._pretrain_loader)  # noqa

        # contrastive loss
        self._contrastive_criterion = SupConLoss()
        self._disable_contrastive = disable_contrastive
        self._iicseg_criterion = IIDSegmentationSmallPathLoss(padding=padding, patch_size=patch_size)
        print(self._iicseg_criterion)
        # iic weight
        self._iic_weight = iic_weight

    def pretrain_decoder_run(self):
        self._model.disable_grad_all()
        self._model.enable_grad(from_=self._enable_grad_from, util=self._extract_position)
        self.to(self._device)

        for self._cur_epoch in range(self._start_epoch, self._max_epoch_train_decoder):
            pretrain_decoder_dict = IICPretrainDecoderEpoch(
                model=self._model,
                projection_head=self._projector_contrastive,
                projection_classifier=self._projector_iic,
                optimizer=self._optimizer,
                pretrain_decoder_loader=self._pretrain_loader_iter,
                contrastive_criterion=self._contrastive_criterion,
                iicseg_criterion=self._iicseg_criterion,
                num_batches=self._num_batches,
                cur_epoch=self._cur_epoch,
                device=self._device,
                disable_contrastive=self._disable_contrastive,
                iic_weight=self._iic_weight,
                feature_extractor=self._feature_extractor,
            ).run()
            self._scheduler.step()
            storage_dict = StorageIncomeDict(PRETRAIN_DECODER=pretrain_decoder_dict)
            self._pretrain_decoder_storage.put_from_dict(storage_dict, epoch=self._cur_epoch)
            self._writer.add_scalar_with_StorageDict(storage_dict, self._cur_epoch)
            self._save_to("last.pth", path=os.path.join(self._save_dir, "pretrain_decoder"))
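# A hedged sketch of how the IIC-specific knobs could be passed through the inherited
# `start_training` entry point. The option values simply echo the defaults of the init
# methods above and the "partition" group_option is hypothetical; trainer construction
# is the same as for ContrastTrainer and is elided here.
def _example_iic_options():
    pretrain_encoder_init_options = {
        "group_option": "partition",  # hypothetical value, see pretrain_encoder_init
        "num_clusters": 10,
        "num_subheads": 10,
        "iic_weight": 1,
        "disable_contrastive": False,
    }
    pretrain_decoder_init_options = {
        "extract_position": "Up_conv3",
        "num_clusters": 20,
        "iic_weight": 1,
        "padding": 0,
        "patch_size": 512,
    }
    return pretrain_encoder_init_options, pretrain_decoder_init_options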