Example #1
    def build_heads(self):
        """Build the different heads for the model. These can be either the
        pretraining heads or the classifier heads.
        """
        self.heads = nn.ModuleList()
        head_configs = self.config.get("heads", [])
        for head_config in head_configs:
            # Look up the head class registered for this type (defaults to
            # "mlp") and instantiate it from its config.
            head_type = head_config.get("type", "mlp")
            head_class = registry.get_transformer_head_class(head_type)
            self.heads.append(head_class(head_config))
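
The method above depends on MMF's transformer-head registry, so it only runs inside the framework. Below is a standalone sketch of the same lookup-and-instantiate pattern, using a hypothetical mini-registry and a made-up MLP head class (neither is MMF's actual API) to show how a list of head configs becomes an nn.ModuleList:

from torch import nn

# Hypothetical stand-in for registry.get_transformer_head_class.
_HEAD_REGISTRY = {}

def register_head(name):
    def wrap(cls):
        _HEAD_REGISTRY[name] = cls
        return cls
    return wrap

@register_head("mlp")
class MLPHead(nn.Module):
    """Minimal classifier head driven entirely by its config dict."""

    def __init__(self, config):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(config["hidden_size"], config["hidden_size"]),
            nn.ReLU(),
            nn.Linear(config["hidden_size"], config["num_labels"]),
        )

    def forward(self, hidden_states):
        return self.mlp(hidden_states)

def build_heads(head_configs):
    # Same loop as build_heads above: look up each head by type, instantiate
    # it from its config, and collect the results in an nn.ModuleList.
    heads = nn.ModuleList()
    for head_config in head_configs:
        head_class = _HEAD_REGISTRY[head_config.get("type", "mlp")]
        heads.append(head_class(head_config))
    return heads

heads = build_heads([{"type": "mlp", "hidden_size": 768, "num_labels": 2}])
print(heads)

Registering heads under a string type keeps the model code free of per-head imports: a new head only needs to register itself under a name that the config can reference.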
Example #2
    def __init__(
        self,
        head_configs: Optional[Dict] = None,
        loss_configs: Optional[Dict] = None,
        tasks: Union[List, str] = DEFAULT_PRETRAINING_TASKS,
        mask_probability: float = 0,
        random_init: bool = False,
        bert_model_name: str = "bert-base-uncased",
        img_dim: int = 2048,
        hidden_size: int = 768,
        hidden_dropout_prob: float = 0,
        text_embeddings: Any = EMPTY_CONFIG,
        encoder: Any = EMPTY_CONFIG,
    ):
        super().__init__()
        if head_configs is None:
            # Fall back to the default pretraining head configs; deep-copy so
            # the shared defaults are never mutated.
            head_configs = copy.deepcopy(DEFAULT_PRETRAINING_HEAD_CONFIGS)

        if loss_configs is None:
            loss_configs = {}

        self.loss_configs = loss_configs
        self.mask_probability = mask_probability
        self.uniter = UNITERModelBase(
            random_init=random_init,
            bert_model_name=bert_model_name,
            img_dim=img_dim,
            hidden_size=hidden_size,
            hidden_dropout_prob=hidden_dropout_prob,
            text_embeddings=text_embeddings,
            encoder=encoder,
        )

        self.heads = nn.ModuleDict()

        self.tasks = tasks
        # Tasks may be given either as a list or as a comma-separated string.
        if isinstance(self.tasks, str):
            self.tasks = self.tasks.split(",")

        for task in self.tasks:
            head_config = head_configs[task]
            head_type = head_config.get("type", "mlp")
            head_class = registry.get_transformer_head_class(head_type)
            if head_type == "mrfr":
                # The MRFR head shares the image embedding's projection
                # weight, so it is given that parameter explicitly.
                self.heads[task] = head_class(
                    self.uniter.img_embeddings.img_linear.weight, **head_config
                )
            elif head_type in ("itm", "mlm", "mlp"):
                # These heads take the whole config object positionally.
                self.heads[task] = head_class(head_config)
            else:
                # Any other registered head is built from keyword arguments.
                self.heads[task] = head_class(**head_config)

        self.init_losses()
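
The one head built differently is mrfr, which receives the image embedding's projection weight so the regression head can share parameters with the image-to-hidden projection. A minimal, self-contained sketch of that kind of weight sharing (the class below is invented for illustration and is not MMF's MRFR head):

import torch
from torch import nn
import torch.nn.functional as F

class SharedWeightFeatureRegression(nn.Module):
    """Illustrative only: not MMF's actual MRFR head."""

    def __init__(self, img_linear_weight):
        super().__init__()
        # img_linear_weight has shape (hidden_size, img_dim). Assigning an
        # existing nn.Parameter registers the very same tensor here, so the
        # head and the image projection share (and jointly update) it.
        self.img_linear_weight = img_linear_weight

    def forward(self, hidden_states):
        # Decode hidden states back to image-feature space by reusing the
        # transposed projection weight.
        return F.linear(hidden_states, self.img_linear_weight.t())

img_linear = nn.Linear(2048, 768)  # image features -> hidden states
head = SharedWeightFeatureRegression(img_linear.weight)

hidden = torch.randn(4, 768)
print(head(hidden).shape)                            # torch.Size([4, 2048])
print(head.img_linear_weight is img_linear.weight)   # True: shared parameter

Because both modules hold the very same nn.Parameter, gradients flowing through the head also update the image projection.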
Example #3
    def __init__(
        self,
        head_configs: Dict,
        loss_configs: Dict,
        tasks: Union[str, List],
        random_init: bool = False,
        bert_model_name: str = "bert-base-uncased",
        img_dim: int = 2048,
        hidden_size: int = 768,
        hidden_dropout_prob: float = 0,
        text_embeddings: Any = EMPTY_CONFIG,
        encoder: Any = EMPTY_CONFIG,
    ):
        super().__init__()
        self.loss_configs = loss_configs
        self.uniter = UNITERModelBase(
            random_init=random_init,
            bert_model_name=bert_model_name,
            img_dim=img_dim,
            hidden_size=hidden_size,
            hidden_dropout_prob=hidden_dropout_prob,
            text_embeddings=text_embeddings,
            encoder=encoder,
        )

        self.heads = nn.ModuleDict()
        self.tasks = tasks
        if isinstance(self.tasks, str):
            self.tasks = self.tasks.split(",")

        for task in self.tasks:
            assert task in head_configs, (
                f"Task {task} is specified in your model configs"
                + " but there is no head configured for the task. "
                + "Head configs can be added under model_config.heads "
                + "in your yaml configs. Either remove this task if UNITER"
                + f" is not meant to run on a dataset named {task}"
                + " or add a head config."
            )
            head_config = head_configs[task]
            head_type = head_config.get("type", "mlp")
            head_class = registry.get_transformer_head_class(head_type)
            self.heads[task] = head_class(head_config)

        self.init_losses()
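
The assert above expects head_configs to be keyed by task name, mirroring the model_config.heads section of the YAML configs. A hypothetical illustration of that structure and of the comma-separated tasks handling (task names and config keys are made up here; the keys each head actually accepts depend on the registered head class):

# Hypothetical per-task head configs; in MMF these would live under
# model_config.heads in the YAML configs.
head_configs = {
    "vqa2": {"type": "mlp", "num_labels": 3129, "hidden_size": 768},
    "nlvr2": {"type": "mlp", "num_labels": 2, "hidden_size": 768},
}
tasks = "vqa2,nlvr2"

# Same normalization as in __init__: a comma-separated string becomes a list.
if isinstance(tasks, str):
    tasks = tasks.split(",")

# Same validation as the assert in __init__.
for task in tasks:
    assert task in head_configs, f"No head configured for task {task}"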
Example #4
def head_from_config(config):
    # Look up the head class registered for the given type (defaults to
    # "mlp") and instantiate it from its config.
    head_type = config.get("type", "mlp")
    head_class = registry.get_transformer_head_class(head_type)
    return head_class(config)