def __init__(self, config: Dict) -> None:
    super().__init__(config)
    self.name = "config_default"
    self.train = trainer_configs.BaseTrainConfig(config.pop(Conf.TRAIN))
    self.val = trainer_configs.BaseValConfig(config.pop(Conf.VAL))
    self.dataset_train = data.BaseDatasetConfig(config.pop(Conf.DATASET_TRAIN))
    self.dataset_val = data.BaseDatasetConfig(config.pop(Conf.DATASET_VAL))
    self.logging = utils.BaseLoggingConfig(config.pop(Conf.LOGGING))
    self.saving = trainer_configs.BaseSavingConfig(config.pop(Conf.SAVING))
    self.optimizer = optimization.OptimizerConfig(config.pop(Conf.OPTIMIZER))
    self.lr_scheduler = lr_scheduler.SchedulerConfig(config.pop(Conf.LR_SCHEDULER))
    self.mlp = MLPNetConfig(config.pop("mlp"))
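
# Usage sketch (not repository code): how a config class built around the __init__ above is
# typically constructed from a loaded YAML dict. The file path "config/default.yaml" and the
# class name DefaultExperimentConfig are hypothetical placeholders. The pattern assumed here
# is that each top-level group (train, val, dataset_train, ...) is popped from the raw dict
# and parsed into its typed sub-config, so any leftover keys point to typos in the file.
import yaml  # requires PyYAML

with open("config/default.yaml", "r", encoding="utf8") as fh:  # hypothetical path
    raw_config = yaml.safe_load(fh)
cfg = DefaultExperimentConfig(raw_config)  # hypothetical subclass using the __init__ above
print(cfg.name)  # "config_default"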
def __init__(self, config: Dict[str, Any], *, is_train: bool = True) -> None:
    super().__init__(config)
    self.name = "config_ret"
    self.dim_feat_global: int = config.pop("dim_feat_global", 768)
    self.dim_feat_local: int = config.pop("dim_feat_local", 384)
    if not is_train:
        # Disable dataset caching during evaluation-only runs.
        logger = logging.getLogger(utils.LOGGER_NAME)
        logger.debug("Disabling dataset caching during validation.")
        config["dataset_val"]["preload_vid_feat"] = False
        config["dataset_val"]["preload_text_feat"] = False
    try:
        self.train = RetrievalTrainConfig(config.pop(Conf.TRAIN))
        self.val = RetrievalValConfig(config.pop(Conf.VAL))
        self.dataset_train = RetrievalDatasetConfig(config.pop(Conf.DATASET_TRAIN))
        self.dataset_val = RetrievalDatasetConfig(config.pop(Conf.DATASET_VAL))
        self.logging = trainer_configs.BaseLoggingConfig(config.pop(Conf.LOGGING))
        self.saving = trainer_configs.BaseSavingConfig(config.pop(Conf.SAVING))
        self.optimizer = optimization.OptimizerConfig(config.pop(Conf.OPTIMIZER))
        self.lr_scheduler = lr_scheduler.SchedulerConfig(config.pop(Conf.LR_SCHEDULER))
        self.model_cfgs = {}
        for key in RetrievalNetworksConst.values():
            self.model_cfgs[key] = models.TransformerConfig(config.pop(key))
    except KeyError as e:
        print()
        print(traceback.format_exc())
        print(f"ERROR: {e} not defined in config {self.__class__.__name__}\n")
        raise e
    self.post_init()
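
# Hedged usage sketch: constructing the retrieval config for evaluation only. The class name
# RetrievalConfig and the YAML path are assumptions, not confirmed by this file. The point is
# that is_train=False mutates the raw dict before the sub-configs are parsed, so the
# validation dataset does not preload video/text features into memory.
import yaml  # requires PyYAML

with open("config/retrieval.yaml", "r", encoding="utf8") as fh:  # hypothetical path
    raw = yaml.safe_load(fh)
cfg = RetrievalConfig(raw, is_train=False)  # assumed class name for the __init__ above
# Assuming the dataset sub-config exposes these flags as attributes:
assert cfg.dataset_val.preload_vid_feat is False
assert cfg.dataset_val.preload_text_feat is False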
def __init__(self, config: Dict[str, Any]) -> None:
    super().__init__(config)
    self.name = "config_ret"

    # mandatory groups, needed for nntrainer to work correctly
    self.train = trainer_configs.BaseTrainConfig(config.pop("train"))
    self.val = trainer_configs.BaseValConfig(config.pop("val"))
    self.dataset_train = MartDatasetConfig(config.pop("dataset_train"))
    self.dataset_val = MartDatasetConfig(config.pop("dataset_val"))
    self.logging = trainer_configs.BaseLoggingConfig(config.pop("logging"))
    self.saving = trainer_configs.BaseSavingConfig(config.pop("saving"))

    # more training
    self.label_smoothing: float = config.pop("label_smoothing")

    # more validation
    self.save_mode: str = config.pop("save_mode")
    self.use_beam: bool = config.pop("use_beam")
    self.beam_size: int = config.pop("beam_size")
    self.n_best: int = config.pop("n_best")
    self.min_sen_len: int = config.pop("min_sen_len")
    self.max_sen_len: int = config.pop("max_sen_len")
    self.block_ngram_repeat: int = config.pop("block_ngram_repeat")
    self.length_penalty_name: str = config.pop("length_penalty_name")
    self.length_penalty_alpha: float = config.pop("length_penalty_alpha")

    # dataset
    self.max_n_sen: int = config.pop("max_n_sen")
    self.max_n_sen_add_val: int = config.pop("max_n_sen_add_val")
    self.max_t_len: int = config.pop("max_t_len")
    self.max_v_len: int = config.pop("max_v_len")
    self.type_vocab_size: int = config.pop("type_vocab_size")
    self.word_vec_size: int = config.pop("word_vec_size")

    # dataset: coot features
    self.coot_model_name: Optional[str] = config.pop("coot_model_name")
    self.coot_dim_clip: int = config.pop("coot_dim_clip")
    self.coot_dim_vid: int = config.pop("coot_dim_vid")
    self.coot_mode: str = config.pop("coot_mode")
    self.video_feature_size: int = config.pop("video_feature_size")

    # technical
    self.debug: bool = config.pop("debug")

    # model
    self.attention_probs_dropout_prob: float = config.pop("attention_probs_dropout_prob")
    self.hidden_dropout_prob: float = config.pop("hidden_dropout_prob")
    self.hidden_size: int = config.pop("hidden_size")
    self.intermediate_size: int = config.pop("intermediate_size")
    self.layer_norm_eps: float = config.pop("layer_norm_eps")
    self.memory_dropout_prob: float = config.pop("memory_dropout_prob")
    self.num_attention_heads: int = config.pop("num_attention_heads")
    self.num_hidden_layers: int = config.pop("num_hidden_layers")
    self.n_memory_cells: int = config.pop("n_memory_cells")
    self.share_wd_cls_weight: bool = config.pop("share_wd_cls_weight")
    self.recurrent: bool = config.pop("recurrent")
    self.untied: bool = config.pop("untied")
    self.mtrans: bool = config.pop("mtrans")
    self.xl: bool = config.pop("xl")
    self.xl_grad: bool = config.pop("xl_grad")
    self.use_glove: bool = config.pop("use_glove")
    self.freeze_glove: bool = config.pop("freeze_glove")

    # optimization
    self.ema_decay: float = config.pop("ema_decay")
    self.initializer_range: float = config.pop("initializer_range")
    self.lr: float = config.pop("lr")
    self.lr_warmup_proportion: float = config.pop("lr_warmup_proportion")
    self.infty: int = config.pop("infty", 0)
    self.eps: float = config.pop("eps", 1e-6)

    # max position embeddings is calculated as the max joint sequence length
    self.max_position_embeddings: int = self.max_v_len + self.max_t_len

    # must be set manually as it depends on the dataset
    self.vocab_size: Optional[int] = None

    # assert the config is valid
    if self.xl:
        assert self.recurrent, "recurrent must be True if TransformerXL is used."
    if self.xl_grad:
        assert self.xl, "xl must be True when using xl_grad"
    assert not (self.recurrent and self.untied), "recurrent and untied cannot both be True"
    assert not (self.recurrent and self.mtrans), "recurrent and mtrans cannot both be True"
    assert not (self.untied and self.mtrans), "untied and mtrans cannot both be True"
    if self.share_wd_cls_weight:
        assert self.word_vec_size == self.hidden_size, (
            "hidden size has to be the same as word embedding size when "
            "sharing the word embedding weight and the final classifier weight")

    # infer model type
    if self.recurrent:
        # recurrent paragraphs
        if self.xl:
            if self.xl_grad:
                self.model_type = "xl_grad"
            else:
                self.model_type = "xl"
        else:
            self.model_type = "re"
    else:
        # single sentence
        if self.untied:
            self.model_type = "untied_single"
        elif self.mtrans:
            self.model_type = "mtrans_single"
        else:
            self.model_type = "single"

    self.post_init()
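
# Standalone restatement (for reference, not repository code) of the flag -> model_type
# mapping implemented above, so the valid combinations are easy to see at a glance:
#   recurrent + xl + xl_grad -> "xl_grad"     recurrent + xl -> "xl"
#   recurrent                -> "re"          untied         -> "untied_single"
#   mtrans                   -> "mtrans_single"   (no flags) -> "single"
def infer_model_type(recurrent: bool, untied: bool, mtrans: bool, xl: bool, xl_grad: bool) -> str:
    # same validity checks as in the config above
    assert not (recurrent and untied) and not (recurrent and mtrans) and not (untied and mtrans)
    if xl:
        assert recurrent
    if xl_grad:
        assert xl
    if recurrent:
        return "xl_grad" if xl and xl_grad else ("xl" if xl else "re")
    return "untied_single" if untied else ("mtrans_single" if mtrans else "single")

# quick self-checks of the mapping
assert infer_model_type(True, False, False, True, True) == "xl_grad"
assert infer_model_type(False, False, False, False, False) == "single"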