コード例 #1
0
    def __init__(self,
                 cfg: RetrievalConfig,
                 model_mgr: model_retrieval.RetrievalModelManager,
                 exp_group: str,
                 exp_name: str,
                 run_name: str,
                 train_loader_length: int,
                 *,
                 log_dir: str = "experiments",
                 log_level: Optional[int] = None,
                 logger: Optional[logging.Logger] = None,
                 print_graph: bool = False,
                 reset: bool = False,
                 load_best: bool = False,
                 load_epoch: Optional[int] = None,
                 load_model: Optional[str] = None,
                 inference_only: bool = False):
        """
        Set up the retrieval trainer: losses, extra metrics, optimizer and LR scheduler.

        All arguments except ``inference_only`` are forwarded unchanged to the base trainer,
        together with the experiment type ``ExperimentTypesConst.RETRIEVAL``.

        Args:
            cfg: Retrieval experiment configuration.
            model_mgr: Model manager whose parameters (``get_all_params``) are optimized.
            exp_group: Experiment group, forwarded to the base trainer.
            exp_name: Experiment name, forwarded to the base trainer.
            run_name: Run name, forwarded to the base trainer.
            train_loader_length: Length of the training data loader; forwarded to the base trainer
                and used to build the LR scheduler.
            log_dir: Log directory (default ``"experiments"``), forwarded to the base trainer.
            log_level: Logging level, forwarded to the base trainer.
            logger: Optional logger, forwarded to the base trainer.
            print_graph: Forwarded to the base trainer.
            reset: Forwarded to the base trainer.
            load_best: Forwarded to the base trainer.
            load_epoch: Forwarded to the base trainer.
            load_model: Forwarded to the base trainer.
            inference_only: Forwarded to the base trainer as ``is_test``. When ``self.is_test`` is
                set, no optimizer or LR scheduler is created (both stay ``None``).

        Raises:
            ValueError: If ``cfg.train.loss_func`` is not the contrastive loss.
        """
        super().__init__(cfg,
                         model_mgr,
                         exp_group,
                         exp_name,
                         run_name,
                         train_loader_length,
                         ExperimentTypesConst.RETRIEVAL,
                         log_dir=log_dir,
                         log_level=log_level,
                         logger=logger,
                         print_graph=print_graph,
                         reset=reset,
                         load_best=load_best,
                         load_epoch=load_epoch,
                         load_model=load_model,
                         is_test=inference_only)

        # ---------- setup ----------

        # update type hints from base classes to inherited classes
        self.cfg: RetrievalConfig = self.cfg
        self.model_mgr: model_retrieval.RetrievalModelManager = self.model_mgr

        # overwrite default state with inherited trainer state in case we need additional state fields
        self.state = RetrievalTrainerState()

        # ---------- loss ----------

        # contrastive loss is the only supported main loss. Validate with an explicit raise instead of
        # assert so the check is not stripped when Python runs with -O.
        if self.cfg.train.loss_func != loss_fn.LossesConst.CONTRASTIVE:
            raise ValueError(
                f"Retrieval trainer only supports loss_func={loss_fn.LossesConst.CONTRASTIVE!r}, "
                f"got {self.cfg.train.loss_func!r}")
        self.loss_contr = ContrastiveLoss(
            self.cfg.train.contrastive_loss_config.margin,
            use_cuda=self.cfg.use_cuda)

        # cycle consistency loss, only built when cfg.train.loss_cycle_cons is non-zero (presumably a
        # loss weight). The attribute is always defined (None when disabled) so it exists regardless
        # of the config.
        self.loss_cycle_cons: Optional[CycleConsistencyLoss] = None
        if self.cfg.train.loss_cycle_cons != 0:
            self.loss_cycle_cons = CycleConsistencyLoss(
                use_cuda=self.cfg.use_cuda)

        # ---------- additional metrics ----------

        # loss proportions
        self.metrics.add_meter(CMeters.VAL_LOSS_CC, use_avg=False)
        self.metrics.add_meter(CMeters.VAL_LOSS_CONTRASTIVE, use_avg=False)
        self.metrics.add_meter(CMeters.TRAIN_LOSS_CC,
                               per_step=True,
                               use_avg=False)
        self.metrics.add_meter(CMeters.TRAIN_LOSS_CONTRASTIVE,
                               per_step=True,
                               use_avg=False)

        # retrieval validation metrics must be constructed as product of two lists
        for modality in CMeters.RET_MODALITIES:
            # modality: retrieval from where to where
            for metric in CMeters.RET_METRICS:
                # metric: retrieval@1, mean, ...
                if metric == "r1":
                    # log r1 metric to the overview class
                    metric_class = "val_base"
                else:
                    # log all other metrics to the detail class
                    metric_class = "val_ret"
                self.metrics.add_meter(f"{metric_class}/{modality}-{metric}",
                                       use_avg=False)

        # ---------- optimization ----------
        self.optimizer = None
        self.lr_scheduler = None
        # skip optimizer if not training
        if not self.is_test:
            # create optimizer
            params, _param_names, _params_flat = self.model_mgr.get_all_params()
            self.optimizer = optimization.make_optimizer(
                self.cfg.optimizer, params)

            # create lr scheduler
            self.lr_scheduler = lr_scheduler.make_lr_scheduler(
                self.optimizer,
                self.cfg.lr_scheduler,
                self.cfg.optimizer.lr,
                self.cfg.train.num_epochs,
                self.train_loader_length,
                logger=self.logger)

        # post init hook for checkpoint loading
        self.hook_post_init()
コード例 #2
0
    def __init__(self,
                 cfg: MLPMNISTExperimentConfig,
                 model_mgr: MLPModelManager,
                 exp_dir: str,
                 exp_name: str,
                 run_name: str,
                 train_loader_length: int,
                 *,
                 log_dir: str = Paths.DIR_EXPERIMENTS,
                 log_level: Optional[int] = None,
                 logger: Optional[logging.Logger] = None,
                 print_graph: bool = False,
                 reset: bool = False,
                 load_best: bool = False,
                 load_epoch: Optional[int] = None,
                 inference_only: bool = False):
        """
        Set up the MLP MNIST trainer: cross-entropy loss, extra metrics, optimizer and LR scheduler.

        All arguments except ``inference_only`` are forwarded unchanged to the base trainer,
        together with the string ``"mlpmnist"`` (presumably the experiment type).

        Args:
            cfg: MLP MNIST experiment configuration.
            model_mgr: Model manager whose parameters (``get_all_params``) are optimized.
            exp_dir: Forwarded to the base trainer.
            exp_name: Experiment name, forwarded to the base trainer.
            run_name: Run name, forwarded to the base trainer.
            train_loader_length: Length of the training data loader; forwarded to the base trainer
                and used to build the LR scheduler.
            log_dir: Log directory (default ``Paths.DIR_EXPERIMENTS``), forwarded to the base trainer.
            log_level: Logging level, forwarded to the base trainer.
            logger: Optional logger, forwarded to the base trainer.
            print_graph: Forwarded to the base trainer.
            reset: Forwarded to the base trainer.
            load_best: Forwarded to the base trainer.
            load_epoch: Forwarded to the base trainer.
            inference_only: Forwarded to the base trainer as ``is_test``.

        Raises:
            ValueError: If ``cfg.train.loss_func`` is not ``"crossentropy"``.
        """
        super().__init__(cfg,
                         model_mgr,
                         exp_dir,
                         exp_name,
                         run_name,
                         train_loader_length,
                         "mlpmnist",
                         log_dir=log_dir,
                         log_level=log_level,
                         logger=logger,
                         print_graph=print_graph,
                         reset=reset,
                         load_best=load_best,
                         load_epoch=load_epoch,
                         is_test=inference_only)
        # ---------- setup ----------

        # update type hints from base classes to inherited classes
        self.cfg: MLPMNISTExperimentConfig = self.cfg
        self.model_mgr: MLPModelManager = self.model_mgr

        # update trainer state if loading is requested
        if self.load:
            self.state.current_epoch = self.load_ep

        # ---------- loss ----------

        # cross-entropy classification loss, the only supported loss. Validate with an explicit raise
        # instead of assert so the check is not stripped when Python runs with -O.
        if self.cfg.train.loss_func != "crossentropy":
            raise ValueError(
                f"MLP MNIST trainer only supports loss_func='crossentropy', "
                f"got {self.cfg.train.loss_func!r}")
        self.loss_ce = nn.CrossEntropyLoss()

        # ---------- additional metrics ----------

        # metrics logged once per epoch, log only value
        for field in ("val_base/accuracy", ):
            self.metrics.add_meter(field, use_avg=False)

        # ---------- optimization ----------

        # NOTE(review): the optimizer and LR scheduler are created even when inference_only=True;
        # confirm this is intended.

        # create optimizer
        params, _param_names, _params_flat = self.model_mgr.get_all_params()
        self.optimizer = optimization.make_optimizer(self.cfg.optimizer,
                                                     params)

        # create lr scheduler
        self.lr_scheduler = lr_scheduler.make_lr_scheduler(
            self.optimizer,
            self.cfg.lr_scheduler,
            self.cfg.optimizer.lr,
            self.cfg.train.num_epochs,
            self.train_loader_length,
            logger=self.logger)

        # post init hook (called last, after optimizer and scheduler exist)
        self.hook_post_init()