def build_heads(self):
    """Build the different heads for the model. These can be either the
    pretraining heads or the classifier heads.
    """
    self.heads = nn.ModuleList()
    head_configs = self.config.get("heads", [])
    for head_config in head_configs:
        head_type = head_config.get("type", "mlp")
        head_class = registry.get_transformer_head_class(head_type)
        self.heads.append(head_class(head_config))
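# Hedged sketch (assumption: the exact per-head keys depend on each head
# class; only the "type" key and the list layout follow from build_heads
# above). A config driving build_heads might look like:
heads_config_example = {
    "heads": [
        {"type": "mlp", "num_labels": 2},  # "mlp" is also the default type
        {"type": "itm"},
    ]
}
# Each entry resolves via registry.get_transformer_head_class and is
# instantiated with its own dict, then appended to self.heads (nn.ModuleList).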
def __init__(
    self,
    head_configs: Optional[Dict] = None,
    loss_configs: Optional[Dict] = None,
    tasks: Union[List, str] = DEFAULT_PRETRAINING_TASKS,
    mask_probability: float = 0,
    random_init: bool = False,
    bert_model_name: str = "bert-base-uncased",
    img_dim: int = 2048,
    hidden_size: int = 768,
    hidden_dropout_prob: float = 0,
    text_embeddings: Any = EMPTY_CONFIG,
    encoder: Any = EMPTY_CONFIG,
):
    super().__init__()
    if head_configs is None:
        head_configs = copy.deepcopy(DEFAULT_PRETRAINING_HEAD_CONFIGS)

    if loss_configs is None:
        loss_configs = {}

    self.loss_configs = loss_configs
    self.mask_probability = mask_probability
    self.uniter = UNITERModelBase(
        random_init=random_init,
        bert_model_name=bert_model_name,
        img_dim=img_dim,
        hidden_size=hidden_size,
        hidden_dropout_prob=hidden_dropout_prob,
        text_embeddings=text_embeddings,
        encoder=encoder,
    )

    self.heads = nn.ModuleDict()
    self.tasks = tasks
    if isinstance(self.tasks, str):
        self.tasks = self.tasks.split(",")

    for task in self.tasks:
        head_config = head_configs[task]
        head_type = head_config.get("type", "mlp")
        head_class = registry.get_transformer_head_class(head_type)
        if head_type == "mrfr":
            # The MRFR head regresses masked image-region features, so it
            # reuses the image-embedding projection weights from the base
            # model.
            self.heads[task] = head_class(
                self.uniter.img_embeddings.img_linear.weight, **head_config
            )
        elif head_type in ("itm", "mlm", "mlp"):
            # These heads take the whole config dict as a single argument.
            self.heads[task] = head_class(head_config)
        else:
            # Remaining heads expect the config expanded as keyword arguments.
            self.heads[task] = head_class(**head_config)

    self.init_losses()
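# Construction sketch (hypothetical: assumes this __init__ belongs to a class
# named UNITERForPretraining, which is not shown in this snippet, and that
# 0.15 is a reasonable mask probability). A comma-separated task string is
# split into a list, and one head per task is built from
# DEFAULT_PRETRAINING_HEAD_CONFIGS when head_configs is None.
model = UNITERForPretraining(tasks="mlm,itm", mask_probability=0.15)
assert sorted(model.heads.keys()) == ["itm", "mlm"]  # one head per task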
def __init__(
    self,
    head_configs: Dict,
    loss_configs: Dict,
    tasks: Union[str, List],
    random_init: bool = False,
    bert_model_name: str = "bert-base-uncased",
    img_dim: int = 2048,
    hidden_size: int = 768,
    hidden_dropout_prob: float = 0,
    text_embeddings: Any = EMPTY_CONFIG,
    encoder: Any = EMPTY_CONFIG,
):
    super().__init__()
    self.loss_configs = loss_configs
    self.uniter = UNITERModelBase(
        random_init=random_init,
        bert_model_name=bert_model_name,
        img_dim=img_dim,
        hidden_size=hidden_size,
        hidden_dropout_prob=hidden_dropout_prob,
        text_embeddings=text_embeddings,
        encoder=encoder,
    )

    self.heads = nn.ModuleDict()
    self.tasks = tasks
    if isinstance(self.tasks, str):
        self.tasks = self.tasks.split(",")

    for task in self.tasks:
        assert task in head_configs, (
            f"Task {task} is specified in your model configs"
            + " but there is no head configured for the task. "
            + "Head configs can be added under model_config.heads "
            + "in your yaml configs. Either remove this task if UNITER"
            + f" is not meant to run on a dataset named {task},"
            + " or add a head config."
        )
        head_config = head_configs[task]
        head_type = head_config.get("type", "mlp")
        head_class = registry.get_transformer_head_class(head_type)
        self.heads[task] = head_class(head_config)

    self.init_losses()
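# Sketch for the classification variant (the class name
# UNITERForClassification and the config values below are assumptions, not
# shown in this snippet). Unlike the pretraining variant, head_configs and
# loss_configs are required here, every head is built uniformly from its
# config dict, and a task listed without a head config fails the assert above.
head_configs = {"vqa2": {"type": "mlp", "num_labels": 3129}}  # illustrative
loss_configs = {"vqa2": {"type": "logit_bce"}}                # illustrative
model = UNITERForClassification(
    head_configs=head_configs,
    loss_configs=loss_configs,
    tasks="vqa2",
)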
def head_from_config(config):
    head_type = config.get("type", "mlp")
    head_class = registry.get_transformer_head_class(head_type)
    return head_class(config)
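# Usage sketch: head_from_config is the free-function form of the registry
# lookup used in the constructors above; "mlp" is the default head type when
# "type" is absent from the config. The num_labels key is illustrative and
# depends on the concrete head class.
head = head_from_config({"type": "mlp", "num_labels": 2})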