Example No. 1
    def __init__(self, config: Dict) -> None:
        super().__init__(config)
        self.name = "config_default"
        # Parse each mandatory config group by popping it out of the raw dict.
        self.train = trainer_configs.BaseTrainConfig(config.pop(Conf.TRAIN))
        self.val = trainer_configs.BaseValConfig(config.pop(Conf.VAL))
        self.dataset_train = data.BaseDatasetConfig(
            config.pop(Conf.DATASET_TRAIN))
        self.dataset_val = data.BaseDatasetConfig(config.pop(Conf.DATASET_VAL))
        self.logging = utils.BaseLoggingConfig(config.pop(Conf.LOGGING))
        self.saving = trainer_configs.BaseSavingConfig(config.pop(Conf.SAVING))
        self.optimizer = optimization.OptimizerConfig(
            config.pop(Conf.OPTIMIZER))
        self.lr_scheduler = lr_scheduler.SchedulerConfig(
            config.pop(Conf.LR_SCHEDULER))
        self.mlp = MLPNetConfig(config.pop("mlp"))
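For context, here is a minimal, self-contained sketch of the pop-based group parsing pattern used above. The class _GroupConfig, its batch_size field, and the input dict are illustrative only and not part of the actual library:

from typing import Any, Dict


class _GroupConfig:
    """Toy stand-in for a sub-config such as BaseTrainConfig."""

    def __init__(self, config: Dict[str, Any]) -> None:
        self.batch_size: int = config.pop("batch_size")
        # Any keys still left in the dict were never consumed and are
        # therefore likely typos in the input configuration.
        assert not config, f"Unknown keys in group config: {list(config)}"


raw = {"train": {"batch_size": 32}, "val": {"batch_size": 64}}
train_cfg = _GroupConfig(raw.pop("train"))
val_cfg = _GroupConfig(raw.pop("val"))
# After all groups are popped, the top-level dict should be empty as well.
assert not raw, f"Unknown config groups: {list(raw)}"
print(train_cfg.batch_size, val_cfg.batch_size)  # 32 64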
Example No. 2
    def __init__(self,
                 config: Dict[str, Any],
                 *,
                 is_train: bool = True) -> None:
        super().__init__(config)
        self.name = "config_ret"
        self.dim_feat_global: int = config.pop("dim_feat_global", 768)
        self.dim_feat_local: int = config.pop("dim_feat_local", 384)
        if not is_train:
            # Disable dataset caching
            logger = logging.getLogger(utils.LOGGER_NAME)
            logger.debug("Disable dataset caching during validation.")
            config["dataset_val"]["preload_vid_feat"] = False
            config["dataset_val"]["preload_text_feat"] = False

        try:
            self.train = RetrievalTrainConfig(config.pop(Conf.TRAIN))
            self.val = RetrievalValConfig(config.pop(Conf.VAL))
            self.dataset_train = RetrievalDatasetConfig(
                config.pop(Conf.DATASET_TRAIN))
            self.dataset_val = RetrievalDatasetConfig(
                config.pop(Conf.DATASET_VAL))
            self.logging = trainer_configs.BaseLoggingConfig(
                config.pop(Conf.LOGGING))
            self.saving = trainer_configs.BaseSavingConfig(
                config.pop(Conf.SAVING))
            self.optimizer = optimization.OptimizerConfig(
                config.pop(Conf.OPTIMIZER))
            self.lr_scheduler = lr_scheduler.SchedulerConfig(
                config.pop(Conf.LR_SCHEDULER))
            self.model_cfgs = {}
            for key in RetrievalNetworksConst.values():
                self.model_cfgs[key] = models.TransformerConfig(
                    config.pop(key))
        except KeyError as e:
            print()
            print(traceback.format_exc())
            print(
                f"ERROR: {e} not defined in config {self.__class__.__name__}\n"
            )
            raise e

        self.post_init()
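The try/except block above exists to report which mandatory group is missing before re-raising. Below is a stand-alone sketch of that pattern, using illustrative group names and a hypothetical parse_groups helper instead of the real Conf.* constants and sub-config classes:

import traceback
from typing import Any, Dict


def parse_groups(config: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    groups = {}
    try:
        # config.pop(key) without a default raises KeyError when a
        # mandatory group is absent from the input configuration.
        for key in ("train", "val", "dataset_train", "dataset_val"):
            groups[key] = config.pop(key)
    except KeyError as e:
        # Print the traceback plus a readable hint before re-raising,
        # mirroring the error handling in Example No. 2.
        print(traceback.format_exc())
        print(f"ERROR: {e} not defined in config")
        raise
    return groups


try:
    parse_groups({"train": {}, "val": {}})  # dataset groups are missing
except KeyError:
    pass  # the printed message already names the missing group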
Example No. 3
    def __init__(self, config: Dict[str, Any]) -> None:
        super().__init__(config)
        self.name = "config_ret"

        # mandatory groups, needed for nntrainer to work correctly
        self.train = trainer_configs.BaseTrainConfig(config.pop("train"))
        self.val = trainer_configs.BaseValConfig(config.pop("val"))
        self.dataset_train = MartDatasetConfig(config.pop("dataset_train"))
        self.dataset_val = MartDatasetConfig(config.pop("dataset_val"))
        self.logging = trainer_configs.BaseLoggingConfig(config.pop("logging"))
        self.saving = trainer_configs.BaseSavingConfig(config.pop("saving"))

        # more training
        self.label_smoothing: float = config.pop("label_smoothing")

        # more validation
        self.save_mode: str = config.pop("save_mode")
        self.use_beam: bool = config.pop("use_beam")
        self.beam_size: int = config.pop("beam_size")
        self.n_best: int = config.pop("n_best")
        self.min_sen_len: int = config.pop("min_sen_len")
        self.max_sen_len: int = config.pop("max_sen_len")
        self.block_ngram_repeat: int = config.pop("block_ngram_repeat")
        self.length_penalty_name: str = config.pop("length_penalty_name")
        self.length_penalty_alpha: float = config.pop("length_penalty_alpha")

        # dataset
        self.max_n_sen: int = config.pop("max_n_sen")
        self.max_n_sen_add_val: int = config.pop("max_n_sen_add_val")
        self.max_t_len: int = config.pop("max_t_len")
        self.max_v_len: int = config.pop("max_v_len")
        self.type_vocab_size: int = config.pop("type_vocab_size")
        self.word_vec_size: int = config.pop("word_vec_size")

        # dataset: coot features
        self.coot_model_name: Optional[str] = config.pop("coot_model_name")
        self.coot_dim_clip: int = config.pop("coot_dim_clip")
        self.coot_dim_vid: int = config.pop("coot_dim_vid")
        self.coot_mode: str = config.pop("coot_mode")
        self.video_feature_size: int = config.pop("video_feature_size")

        # technical
        self.debug: bool = config.pop("debug")

        # model
        self.attention_probs_dropout_prob: float = config.pop(
            "attention_probs_dropout_prob")
        self.hidden_dropout_prob: float = config.pop("hidden_dropout_prob")
        self.hidden_size: int = config.pop("hidden_size")
        self.intermediate_size: int = config.pop("intermediate_size")
        self.layer_norm_eps: float = config.pop("layer_norm_eps")
        self.memory_dropout_prob: float = config.pop("memory_dropout_prob")
        self.num_attention_heads: int = config.pop("num_attention_heads")
        self.num_hidden_layers: int = config.pop("num_hidden_layers")
        self.n_memory_cells: int = config.pop("n_memory_cells")
        self.share_wd_cls_weight: bool = config.pop("share_wd_cls_weight")
        self.recurrent: bool = config.pop("recurrent")
        self.untied: bool = config.pop("untied")
        self.mtrans: bool = config.pop("mtrans")
        self.xl: bool = config.pop("xl")
        self.xl_grad: bool = config.pop("xl_grad")
        self.use_glove: bool = config.pop("use_glove")
        self.freeze_glove: bool = config.pop("freeze_glove")

        # optimization
        self.ema_decay: float = config.pop("ema_decay")
        self.initializer_range: float = config.pop("initializer_range")
        self.lr: float = config.pop("lr")
        self.lr_warmup_proportion: float = config.pop("lr_warmup_proportion")
        self.infty: int = config.pop("infty", 0)
        self.eps: float = config.pop("eps", 1e-6)

        # max position embeddings is computed as the maximum joint (video + text) sequence length
        self.max_position_embeddings: int = self.max_v_len + self.max_t_len

        # must be set manually as it depends on the dataset
        self.vocab_size: Optional[int] = None

        # assert the config is valid
        if self.xl:
            assert self.recurrent, "recurrent must be True if TransformerXL is used."
        if self.xl_grad:
            assert self.xl, "xl must be True when using xl_grad"
        assert not (self.recurrent and self.untied), "recurrent and untied cannot both be True"
        assert not (self.recurrent and self.mtrans), "recurrent and mtrans cannot both be True"
        assert not (self.untied and self.mtrans), "untied and mtrans cannot both be True"
        if self.share_wd_cls_weight:
            assert self.word_vec_size == self.hidden_size, (
                "hidden size has to be the same as word embedding size when "
                "sharing the word embedding weight and the final classifier weight"
            )

        # infer model type
        if self.recurrent:  # recurrent paragraphs
            if self.xl:
                if self.xl_grad:
                    self.model_type = "xl_grad"
                else:
                    self.model_type = "xl"
            else:
                self.model_type = "re"
        else:  # single sentence
            if self.untied:
                self.model_type = "untied_single"
            elif self.mtrans:
                self.model_type = "mtrans_single"
            else:
                self.model_type = "single"

        self.post_init()
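The flag checks and model-type inference at the end of Example No. 3 form a small decision table. A stand-alone sketch of that mapping follows; the function name infer_model_type is illustrative and not part of the original code:

def infer_model_type(recurrent: bool, xl: bool, xl_grad: bool,
                     untied: bool, mtrans: bool) -> str:
    """Map the boolean model flags to the model_type string used above."""
    if recurrent:  # recurrent paragraphs
        if xl:
            return "xl_grad" if xl_grad else "xl"
        return "re"
    # single sentence variants
    if untied:
        return "untied_single"
    if mtrans:
        return "mtrans_single"
    return "single"


assert infer_model_type(True, True, True, False, False) == "xl_grad"
assert infer_model_type(True, False, False, False, False) == "re"
assert infer_model_type(False, False, False, False, True) == "mtrans_single"
assert infer_model_type(False, False, False, False, False) == "single"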