# Example No. 1
# 0
class BertForClassification(nn.Module):
    """BERT encoder with dropout and a single linear classification head."""

    def __init__(self, config):
        """Initialize the model from a config object.

        Args:
            config: object whose settings are read as attributes (so a plain
                ``dict`` does not work; e.g. an ``argparse.Namespace``):
                config.pretrained: bool; if true, load BERT weights with
                    ``BertModel.from_pretrained(config.bert_model_path)``,
                    otherwise build a randomly initialized BERT from
                    ``config.bert_model_path + '/config.json'``
                config.bert_model_path: pretrained model path or model type,
                    e.g. 'bert-base-chinese'
                config.hidden_size: must equal the BERT hidden size, usually 768
                config.num_classes: int, e.g. 2
                config.dropout: float between 0 and 1
        """
        super().__init__()
        if config.pretrained:
            self.bert = BertModel.from_pretrained(config.bert_model_path)
        else:
            # Use a separate name for the BERT architecture config so that
            # ``config`` keeps referring to the caller's settings, which are
            # read again below (dropout / hidden_size / num_classes).
            bert_config = BertConfig.from_json_file(config.bert_model_path +
                                                    '/config.json')
            self.bert = BertModel(config=bert_config)

        # Fine-tune the whole encoder.
        for param in self.bert.parameters():
            param.requires_grad = True
        self.dropout = nn.Dropout(config.dropout)
        self.linear = nn.Linear(config.hidden_size, config.num_classes)
        self.num_classes = config.num_classes

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Encode the inputs and return per-class probabilities.

        Args:
            input_ids: (batch_size, max_seq_len)
            attention_mask: (batch_size, max_seq_len)
            token_type_ids: (batch_size, max_seq_len)

        Returns:
            (batch_size, num_classes) tensor of softmax probabilities.
            NOTE(review): the output is already softmax-normalized, so it is
            not raw logits. Feeding it to ``nn.CrossEntropyLoss``, which
            applies log-softmax itself, would normalize twice. Confirm how
            callers compute the loss.
        """
        batch_size = input_ids.shape[0]
        bert_output = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # bert_output[0]: (batch_size, sequence_length, hidden_size)
        # bert_output[1]: (batch_size, hidden_size) pooled [CLS] representation
        pooled_output = self.dropout(bert_output[1])
        logits = self.linear(pooled_output).view(batch_size, self.num_classes)
        return nn.functional.softmax(logits, dim=-1)
# Example No. 2
# 0
class ParallelAdapterBertModel(BertModel):
    """BERT with a second BERT encoder ("parabert") running in parallel.

    Every parameter inherited from ``BertModel`` is frozen; only the
    parameters of the parallel ``parabert`` encoder stay trainable. The
    outputs of both encoders are concatenated.
    """

    def __init__(self, config):
        """Build the main and parallel encoders, then freeze the main one.

        Args:
            config: BERT config for the main encoder. It must also expose
                ``config.parabert_config``, which is used to build the
                parallel adapter encoder.
        """
        super().__init__(config)
        self.parabert = BertModel(config.parabert_config)
        self.freeze_original_params()

    def freeze_original_params(self):
        """Enable gradients only for parameters that belong to ``parabert``."""
        trainable_ids = {id(p) for p in self.parabert.parameters()}
        for param in self.parameters():
            param.requires_grad = id(param) in trainable_ids

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        """Run both encoders on the same inputs and concatenate their outputs.

        Only ``input_ids``, ``attention_mask`` and ``token_type_ids`` are
        passed on (positionally) to the two encoders. The remaining arguments
        are accepted for signature compatibility but are ignored.

        Returns:
            tuple ``(sequence_output, pooled_output)``:
            - ``sequence_output`` joins element 0 of both encoder outputs
              along dim 2;
            - ``pooled_output`` joins element 1 of both encoder outputs
              along dim 1.
            The main encoder's output comes first in each concatenation.
        """
        shared_inputs = (input_ids, attention_mask, token_type_ids)
        main_result = super().forward(*shared_inputs)
        adapter_result = self.parabert(*shared_inputs)

        pooled_output = torch.cat([main_result[1], adapter_result[1]], dim=1)
        sequence_output = torch.cat([main_result[0], adapter_result[0]], dim=2)
        return sequence_output, pooled_output
# Example No. 3
# 0
class JointBERT(BertPreTrainedModel):
    """Joint intent classification and slot filling on top of BERT.

    The pooled ([CLS]) output feeds ``intent_classifier``. The per-token
    sequence output feeds ``slot_classifier``, optionally scored by a CRF.
    """

    def __init__(self, config, args, intent_label_lst, slot_label_lst):
        """Build the encoder and the two task heads.

        Args:
            config: BERT config passed to ``BertModel``; ``config.hidden_size``
                also sizes both classifier heads.
            args: namespace-like object. This class reads ``freeze_bert``,
                ``dropout_rate``, ``use_crf``, ``ignore_index`` and
                ``slot_loss_coef`` from it.
            intent_label_lst: intent labels (only the length is used).
            slot_label_lst: slot labels (only the length is used).
        """
        super(JointBERT, self).__init__(config)
        self.args = args
        self.num_intent_labels = len(intent_label_lst)
        self.num_slot_labels = len(slot_label_lst)
        # Builds the architecture from ``config``. Presumably pretrained
        # weights are loaded through ``JointBERT.from_pretrained(...)``;
        # confirm at the call site.
        self.bert = BertModel(config=config)
        if args.freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

        self.intent_classifier = IntentClassifier(config.hidden_size,
                                                  self.num_intent_labels,
                                                  args.dropout_rate)
        self.slot_classifier = SlotClassifier(config.hidden_size,
                                              self.num_slot_labels,
                                              args.dropout_rate)

        if args.use_crf:
            self.crf = CRF(num_tags=self.num_slot_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, token_type_ids,
                intent_label_ids, slot_labels_ids):
        """Compute intent/slot logits and, when labels are given, the loss.

        Args:
            input_ids, attention_mask, token_type_ids: passed to ``self.bert``.
                ``attention_mask`` may be None; it is then not used to mask
                the slot loss.
            intent_label_ids: intent targets, or None to skip the intent loss.
            slot_labels_ids: per-token slot targets, or None to skip the
                slot loss.

        Returns:
            tuple ``(total_loss, (intent_logits, slot_logits), *extra)``, where
            ``extra`` is whatever ``self.bert`` returned after its first two
            outputs (e.g. hidden states / attentions when enabled).
            ``total_loss`` is the int ``0`` when both label arguments are None.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )  # sequence_output, pooled_output, (hidden_states), (attentions)
        sequence_output = outputs[0]
        pooled_output = outputs[1]  # [CLS]

        intent_logits = self.intent_classifier(pooled_output)
        slot_logits = self.slot_classifier(sequence_output)

        total_loss = 0
        # 1. Intent loss
        if intent_label_ids is not None:
            if self.num_intent_labels == 1:
                # Single output -> regression. MSELoss needs the target in
                # the same floating dtype as the prediction.
                intent_loss_fct = nn.MSELoss()
                intent_loss = intent_loss_fct(
                    intent_logits.view(-1),
                    intent_label_ids.view(-1).to(intent_logits.dtype))
            else:
                intent_loss_fct = nn.CrossEntropyLoss()
                intent_loss = intent_loss_fct(
                    intent_logits.view(-1, self.num_intent_labels),
                    intent_label_ids.view(-1))
            total_loss += intent_loss

        # 2. Slot loss
        if slot_labels_ids is not None:
            if self.args.use_crf:
                # The CRF returns a log-likelihood; a mask of None means
                # every position is treated as valid.
                crf_mask = (None if attention_mask is None else
                            attention_mask.byte())
                slot_loss = self.crf(slot_logits,
                                     slot_labels_ids,
                                     mask=crf_mask,
                                     reduction='mean')
                slot_loss = -1 * slot_loss  # negative log-likelihood
            else:
                slot_loss_fct = nn.CrossEntropyLoss(
                    ignore_index=self.args.ignore_index)
                if attention_mask is not None:
                    # Only keep the unmasked (attention_mask == 1) positions.
                    active_loss = attention_mask.view(-1) == 1
                    active_logits = slot_logits.view(
                        -1, self.num_slot_labels)[active_loss]
                    active_labels = slot_labels_ids.view(-1)[active_loss]
                    slot_loss = slot_loss_fct(active_logits, active_labels)
                else:
                    slot_loss = slot_loss_fct(
                        slot_logits.view(-1, self.num_slot_labels),
                        slot_labels_ids.view(-1))
            total_loss += self.args.slot_loss_coef * slot_loss

        # Keep any extra encoder outputs (hidden states / attentions).
        outputs = ((intent_logits, slot_logits), ) + outputs[2:]
        outputs = (total_loss, ) + outputs

        # (loss), (intent_logits, slot_logits), (hidden_states), (attentions)
        return outputs
# Example No. 4
# 0
class BertForMultiTaskClassification(BertPreTrainedModel):
    """
    PyTorch BERT class for multitask learning. This model allows you to load
    in some pretrained tasks in addition to creating new ones.

    Examples
    --------
    To instantiate a completely new instance of BertForMultiTaskClassification
    and load the weights into this architecture you can use the `from_pretrained`
    method of the base class by specifying the name of the weights to load, e.g.::

        model = BertForMultiTaskClassification.from_pretrained(
            'bert-base-uncased',
            new_task_dict=new_task_dict
        )

        # DO SOME TRAINING

        model.save(SOME_FOLDER, SOME_MODEL_ID)

    To instantiate an instance of BertForMultiTaskClassification that has layers for
    pretrained tasks and new tasks, you would do the following::

        model = BertForMultiTaskClassification.from_pretrained(
            'bert-base-uncased',
            pretrained_task_dict=pretrained_task_dict,
            new_task_dict=new_task_dict
        )

        model.load(SOME_FOLDER, SOME_MODEL_ID)

        # DO SOME TRAINING

    Parameters
    ----------
    config: BERT config object
        Defines the BERT model architecture; `config.hidden_size` also sizes
        the classifier layers.
        Note: you will most likely be instantiating the class with the `from_pretrained` method
        so you don't need to come up with your own config.
    pretrained_task_dict: dict
        dictionary mapping each pretrained task to the number of labels it has
    new_task_dict: dict
        dictionary mapping each new task to the number of labels it has
    dropout: float
        dropout probability of the Dropout layer applied to the pooled output
    """
    def __init__(self,
                 config,
                 pretrained_task_dict=None,
                 new_task_dict=None,
                 dropout=1e-1):
        super(BertForMultiTaskClassification, self).__init__(config)
        self.bert = BertModel(config)

        self.dropout = torch.nn.Dropout(dropout)

        # One linear head per task. The attributes are only created when
        # the corresponding dict is given; other methods check with hasattr.
        if pretrained_task_dict is not None:
            pretrained_layers = {}
            for key, task_size in pretrained_task_dict.items():
                pretrained_layers[key] = nn.Linear(config.hidden_size,
                                                   task_size)
            self.pretrained_classifiers = nn.ModuleDict(pretrained_layers)
        if new_task_dict is not None:
            new_layers = {}
            for key, task_size in new_task_dict.items():
                new_layers[key] = nn.Linear(config.hidden_size, task_size)
            self.new_classifiers = nn.ModuleDict(new_layers)

    def forward(self, tokenized_input):
        """
        Defines forward pass for Bert model

        Parameters
        ----------
        tokenized_input: torch tensor of integers
            integers represent tokens for each word

        Returns
        ----------
        A dictionary mapping each task to its logits
        (new tasks overwrite pretrained tasks of the same name)
        """
        outputs = self.bert(tokenized_input)

        # outputs[1] is the pooled output shared by every task head.
        pooled_output = self.dropout(outputs[1])

        logit_dict = {}
        if hasattr(self, 'pretrained_classifiers'):
            for key, classifier in self.pretrained_classifiers.items():
                logit_dict[key] = classifier(pooled_output)
        if hasattr(self, 'new_classifiers'):
            for key, classifier in self.new_classifiers.items():
                logit_dict[key] = classifier(pooled_output)

        return logit_dict

    def freeze_bert(self):
        """Freeze all core Bert layers"""
        for param in self.bert.parameters():
            param.requires_grad = False

    def freeze_pretrained_classifiers_and_bert(self):
        """Freeze pretrained classifier layers and core Bert layers"""
        self.freeze_bert()
        if hasattr(self, 'pretrained_classifiers'):
            for param in self.pretrained_classifiers.parameters():
                param.requires_grad = False
        else:
            print('There are no pretrained_classifier layers to be frozen.')

    def unfreeze_pretrained_classifiers(self):
        """Unfreeze pretrained classifier layers"""
        if hasattr(self, 'pretrained_classifiers'):
            for param in self.pretrained_classifiers.parameters():
                param.requires_grad = True
        else:
            print('There are no pretrained_classifier layers to be unfrozen.')

    def unfreeze_pretrained_classifiers_and_bert(self):
        """Unfreeze pretrained classifiers and core Bert layers"""
        for param in self.bert.parameters():
            param.requires_grad = True

        self.unfreeze_pretrained_classifiers()

    @staticmethod
    def _torch_load(path):
        """`torch.load` a state dict, mapping tensors to CPU when CUDA is unavailable."""
        if torch.cuda.is_available():
            return torch.load(path)
        return torch.load(path, map_location=lambda storage, loc: storage)

    def save(self, folder, model_id):
        """
        Saves the model state dicts to a specific folder.
        Each part of the model is saved separately to allow for
        new classifiers to be added later.

        Note: if the model has `pretrained_classifiers` and `new_classifiers`,
        they will be combined into the `pretrained_classifiers_dict` file
        (a new task with the same name as a pretrained task replaces it).
        If the model has neither, an empty classifier dict is saved.

        Parameters
        ----------
        folder: str or Path
            place to store state dictionaries
        model_id: int
            unique id for this model

        Side Effects
        ------------
        saves three files:
            - folder / f'bert_dict_{model_id}.pth'
            - folder / f'dropout_dict_{model_id}.pth'
            - folder / f'pretrained_classifiers_dict_{model_id}.pth'
        """
        # Gather all classifier heads into one ModuleDict. The layers are
        # referenced rather than copied: only the state dict is written, and
        # the model itself is left untouched.
        classifiers_to_save = nn.ModuleDict()
        for attr_name in ('pretrained_classifiers', 'new_classifiers'):
            for key, module in getattr(self, attr_name, {}).items():
                classifiers_to_save[key] = module

        folder = Path(folder)
        folder.mkdir(parents=True, exist_ok=True)

        torch.save(self.bert.state_dict(),
                   folder / f'bert_dict_{model_id}.pth')
        torch.save(self.dropout.state_dict(),
                   folder / f'dropout_dict_{model_id}.pth')

        torch.save(classifiers_to_save.state_dict(),
                   folder / f'pretrained_classifiers_dict_{model_id}.pth')

    def load(self, folder, model_id):
        """
        Loads the model state dicts from a specific folder.

        The saved classifier dict is loaded into `pretrained_classifiers`,
        so the model must have been built with a matching
        `pretrained_task_dict`.

        Parameters
        ----------
        folder: str or Path
            place where state dictionaries are stored
        model_id: int
            unique id for this model

        Side Effects
        ------------
        loads from three files:
            - folder / f'bert_dict_{model_id}.pth'
            - folder / f'dropout_dict_{model_id}.pth'
            - folder / f'pretrained_classifiers_dict_{model_id}.pth'
        """
        folder = Path(folder)

        self.bert.load_state_dict(
            self._torch_load(folder / f'bert_dict_{model_id}.pth'))
        self.dropout.load_state_dict(
            self._torch_load(folder / f'dropout_dict_{model_id}.pth'))
        self.pretrained_classifiers.load_state_dict(
            self._torch_load(folder /
                             f'pretrained_classifiers_dict_{model_id}.pth'))

    def export(self, folder, model_id, model_name=None):
        """
        Exports the entire model state dict to a specific folder.

        Note: if the model has `new_classifiers`, their entries are saved
        under the `pretrained_classifiers` prefix, so the file can be loaded
        into a model built with all tasks as `pretrained_task_dict`. A new
        task with the same name as a pretrained task replaces it. The model
        itself is not modified.

        Parameters
        ----------
        folder: str or Path
            place to store state dictionaries
        model_id: int
            unique id for this model
        model_name: str (defaults to None)
            Name to store model under, if None, will default to `multi_task_bert_{model_id}.pth`

        Side Effects
        ------------
        saves one file:
            - folder / model_name
        """
        from collections import OrderedDict

        new_prefix = 'new_classifiers'

        def as_pretrained(key):
            # 'new_classifiers[.rest]' -> 'pretrained_classifiers[.rest]'
            if key == new_prefix or key.startswith(new_prefix + '.'):
                return 'pretrained_classifiers' + key[len(new_prefix):]
            return key

        # Build the renamed state dict instead of temporarily re-wiring
        # submodules. Re-wiring (and restoring deep copies afterwards) would
        # replace the model's parameter objects and detach them from any
        # optimizer already holding them.
        state = self.state_dict()
        exported = OrderedDict(
            (as_pretrained(key), value) for key, value in state.items())
        metadata = getattr(state, '_metadata', None)
        if metadata is not None:
            # Keep per-module version info so load_state_dict sees it.
            exported._metadata = OrderedDict(
                (as_pretrained(key), value) for key, value in metadata.items())

        if model_name is None:
            model_name = f'multi_task_bert_{model_id}.pth'

        folder = Path(folder)
        folder.mkdir(parents=True, exist_ok=True)

        torch.save(exported, folder / model_name)

    def import_model(self, folder, file):
        """
        Imports the entire model state dict from a specific folder.

        Note: to export a model based on the import_model from this method,
        use the export method

        Parameters
        ----------
        folder: str or Path
            place to store state dictionaries
        file: str
            filename for the exported model object
        """
        folder = Path(folder)
        self.load_state_dict(self._torch_load(folder / file))
# Example No. 5
# 0
class BertResnetEnsembleForMultiTaskClassification(nn.Module):
    """
    PyTorch ensemble class for multitask learning consisting of a text and image models

    This model is made up of multiple component models:
    - for text: Google's BERT model
    - for images: multiple ResNet50's (the exact number depends on how
    the image model tasks were split up)

    You may need to train the component image and text models first
    before combining them into an ensemble model to get good results.

    Note: For explicitness, `vanilla` refers to the
    `transformers` BERT or `PyTorch` ResNet50 weights while
    `pretrained` refers to previously trained Tonks weights.

    Examples
    --------
    The ensemble model should be used with pretrained
    BERT and ResNet50 component models.
    To initialize a model in this way::

        image_task_dict = {
            'color_pattern': {
                'color': color_train_df['labels'].nunique(),
                'pattern': pattern_train_df['labels'].nunique()
            },
            'dress_sleeve': {
                'dress_length': dl_train_df['labels'].nunique(),
                'sleeve_length': sl_train_df['labels'].nunique()
            },
            'season': {
                'season': season_train_df['labels'].nunique()
            }
        }
        model = BertResnetEnsembleForMultiTaskClassification(
            image_task_dict=image_task_dict
        )

        resnet_model_id_dict = {
            'color_pattern': 'SOME_RESNET_MODEL_ID1',
            'dress_sleeve': 'SOME_RESNET_MODEL_ID2',
            'season': 'SOME_RESNET_MODEL_ID3'
        }

        model.load_core_models(
            folder='SOME_FOLDER',
            bert_model_id='SOME_BERT_MODEL_ID',
            resnet_model_id_dict=resnet_model_id_dict
        )

        # DO SOME TRAINING

        model.save(SOME_FOLDER, SOME_MODEL_ID)

        # OR

        model.export(SOME_FOLDER, SOME_MODEL_ID)

    Parameters
    ----------
    image_task_dict: dict
        dictionary mapping each pretrained ResNet50 models to a dictionary
        of the tasks it was trained on (task name -> number of labels).
        Required in practice: the constructor iterates over it.
    dropout: float
        dropout probability of the Dropout layer applied to the BERT pooled output
    """
    def __init__(self, image_task_dict=None, dropout=1e-1):
        super(BertResnetEnsembleForMultiTaskClassification, self).__init__()

        # Define text architecture
        config = BertConfig()
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(dropout)

        self.image_task_dict = image_task_dict
        self.text_task_dict = self.create_text_dict(image_task_dict)

        # Define image architecture: one ResNet50 (classification head
        # replaced by _Identity) plus dense layers per image_task_dict key.
        image_resnets = {}
        image_dense_layers = {}
        ensemble_layers = {}
        for key in self.image_task_dict.keys():
            resnet = torch_models.resnet50(pretrained=False)
            resnet.fc = _Identity()
            image_resnets[key] = resnet
            # Input is the full-image and cropped-image ResNet features
            # concatenated (2 * 2048).
            image_dense_layers[key] = nn.Sequential(
                _dense_block(2048 * 2, 1024, 2e-3),
                _dense_block(1024, 512, 2e-3), _dense_block(512, 256, 2e-3))

            # Define final ensemble before classifier layers
            # The input is size 768 from BERT and 256 from ResNet50 models
            # so the total size is 1024
            ensemble_layers[key] = nn.Sequential(
                _dense_block(1024, 512, 2e-3),
                _dense_block(512, 512, 2e-3),
                _dense_block(512, 256, 2e-3),
            )

        self.image_resnets = nn.ModuleDict(image_resnets)
        self.image_dense_layers = nn.ModuleDict(image_dense_layers)
        self.ensemble_layers = nn.ModuleDict(ensemble_layers)

        pretrained_layers = {}
        for key, task_size in self.text_task_dict.items():
            pretrained_layers[key] = nn.Linear(256, task_size)
        self.classifiers = nn.ModuleDict(pretrained_layers)

    def forward(self, x):
        """
        Defines forward pass for ensemble model

        Parameters
        ----------
        x: dict
            dictionary of torch tensors with keys:
                - `bert_text`: integers mapping to BERT vocabulary
                - `full_img`: tensor of full image
                - `crop_img`: tensor of cropped image

        Returns
        ----------
        A dictionary mapping each task to its logits
        """
        bert_output = self.bert(x['bert_text'])

        pooled_output = self.dropout(bert_output[1])

        logit_dict = {}

        for key in self.image_task_dict.keys():
            # The same ResNet encodes both the full and the cropped image.
            full_img = self.image_resnets[key](x['full_img'])
            crop_img = self.image_resnets[key](x['crop_img'])
            full_crop_combined = torch.cat((full_img, crop_img), 1)
            dense_layer_output = self.image_dense_layers[key](
                full_crop_combined)
            ensemble_input = torch.cat((pooled_output, dense_layer_output), 1)
            ensemble_layer_output = self.ensemble_layers[key](ensemble_input)

            for task in self.image_task_dict[key].keys():
                classifier = self.classifiers[task]
                logit_dict[task] = classifier(ensemble_layer_output)

        return logit_dict

    def freeze_bert(self):
        """Freeze all core BERT layers"""
        for param in self.bert.parameters():
            param.requires_grad = False

    def freeze_resnets(self):
        """Freeze all core ResNet models layers (and their dense layers)"""
        for key in self.image_resnets.keys():
            for param in self.image_resnets[key].parameters():
                param.requires_grad = False
            for param in self.image_dense_layers[key].parameters():
                param.requires_grad = False

    def freeze_ensemble_layers(self):
        """Freeze all final ensemble layers"""
        for key in self.ensemble_layers.keys():
            for param in self.ensemble_layers[key].parameters():
                param.requires_grad = False

    def freeze_classifiers_and_core(self):
        """Freeze pretrained classifier layers and core BERT/ResNet layers"""
        self.freeze_bert()
        self.freeze_resnets()
        self.freeze_ensemble_layers()
        for param in self.classifiers.parameters():
            param.requires_grad = False

    def unfreeze_classifiers(self):
        """Unfreeze pretrained classifier layers"""
        for param in self.classifiers.parameters():
            param.requires_grad = True

    def unfreeze_classifiers_and_core(self):
        """Unfreeze pretrained classifiers and core BERT/ResNet layers"""
        for param in self.bert.parameters():
            param.requires_grad = True
        # image_resnets, image_dense_layers and ensemble_layers share keys.
        for key in self.image_resnets.keys():
            for param in self.image_resnets[key].parameters():
                param.requires_grad = True
            for param in self.image_dense_layers[key].parameters():
                param.requires_grad = True
            for param in self.ensemble_layers[key].parameters():
                param.requires_grad = True

        self.unfreeze_classifiers()

    @staticmethod
    def _torch_load(path):
        """`torch.load` a state dict, mapping tensors to CPU when CUDA is unavailable."""
        if torch.cuda.is_available():
            return torch.load(path)
        return torch.load(path, map_location=lambda storage, loc: storage)

    def save(self, folder, model_id):
        """
        Saves the model state dicts to a specific folder.
        Each part of the model is saved separately,
        along with the image_task_dict, which is needed to reinstantiate the model.

        Parameters
        ----------
        folder: str or Path
            place to store state dictionaries
        model_id: int
            unique id for this model

        Side Effects
        ------------
        saves seven files:
            - folder / f'bert_dict_{model_id}.pth'
            - folder / f'dropout_dict_{model_id}.pth'
            - folder / f'image_resnets_dict_{model_id}.pth'
            - folder / f'image_dense_layers_dict_{model_id}.pth'
            - folder / f'ensemble_layers_dict_{model_id}.pth'
            - folder / f'classifiers_dict_{model_id}.pth'
            - folder / f'image_task_dict_{model_id}.pickle'
        """
        folder = Path(folder)
        folder.mkdir(parents=True, exist_ok=True)

        # BERT model
        torch.save(self.bert.state_dict(),
                   folder / f'bert_dict_{model_id}.pth')
        torch.save(self.dropout.state_dict(),
                   folder / f'dropout_dict_{model_id}.pth')

        # ResNet model(s)
        torch.save(self.image_resnets.state_dict(),
                   folder / f'image_resnets_dict_{model_id}.pth')
        torch.save(self.image_dense_layers.state_dict(),
                   folder / f'image_dense_layers_dict_{model_id}.pth')

        # Ensemble layers
        torch.save(self.ensemble_layers.state_dict(),
                   folder / f'ensemble_layers_dict_{model_id}.pth')

        # Classifier layers
        torch.save(self.classifiers.state_dict(),
                   folder / f'classifiers_dict_{model_id}.pth')

        # image_task_dict
        joblib.dump(self.image_task_dict,
                    folder / f'image_task_dict_{model_id}.pickle')

    def load(self, folder, model_id):
        """
        Loads the model state dicts for ensemble model
        from a specific folder. This will load all the model
        components including the final ensemble and existing
        pretrained `classifiers`.

        Parameters
        ----------
        folder: str or Path
            place where state dictionaries are stored
        model_id: int
            unique id for this model

        Side Effects
        ------------
        loads from six files:
            - folder / f'bert_dict_{model_id}.pth'
            - folder / f'dropout_dict_{model_id}.pth'
            - folder / f'image_resnets_dict_{model_id}.pth'
            - folder / f'image_dense_layers_dict_{model_id}.pth'
            - folder / f'ensemble_layers_dict_{model_id}.pth'
            - folder / f'classifiers_dict_{model_id}.pth'

        """
        folder = Path(folder)

        self.bert.load_state_dict(
            self._torch_load(folder / f'bert_dict_{model_id}.pth'))
        self.dropout.load_state_dict(
            self._torch_load(folder / f'dropout_dict_{model_id}.pth'))

        self.image_resnets.load_state_dict(
            self._torch_load(folder / f'image_resnets_dict_{model_id}.pth'))
        self.image_dense_layers.load_state_dict(
            self._torch_load(
                folder / f'image_dense_layers_dict_{model_id}.pth'))

        self.ensemble_layers.load_state_dict(
            self._torch_load(folder / f'ensemble_layers_dict_{model_id}.pth'))
        self.classifiers.load_state_dict(
            self._torch_load(folder / f'classifiers_dict_{model_id}.pth'))

    def load_core_models(self, folder, bert_model_id, resnet_model_id_dict):
        """
        Loads the weights from pretrained BERT and ResNet50 Tonks models

        Does not load weights from the final ensemble and classifier layers.
        use case is for loading SR_pretrained component BERT and image model
        weights into a new ensemble model.

        Parameters
        ----------
        folder: str or Path
            place where state dictionaries are stored
        bert_model_id: int
            unique id for pretrained BERT text model
        resnet_model_id_dict: dict
            maps each key of `image_task_dict` to the unique id of the
            pretrained image model whose weights go into that ResNet50
            and its dense layers

        Side Effects
        ------------
        loads from:
            - folder / f'bert_dict_{bert_model_id}.pth'
            - folder / f'dropout_dict_{bert_model_id}.pth'
        and, for each id in `resnet_model_id_dict`:
            - folder / f'resnet_dict_{resnet_model_id}.pth'
            - folder / f'dense_layers_dict_{resnet_model_id}.pth'
        """
        folder = Path(folder)

        self.bert.load_state_dict(
            self._torch_load(folder / f'bert_dict_{bert_model_id}.pth'))
        self.dropout.load_state_dict(
            self._torch_load(folder / f'dropout_dict_{bert_model_id}.pth'))

        for key, resnet_model_id in resnet_model_id_dict.items():
            self.image_resnets[key].load_state_dict(
                self._torch_load(folder / f'resnet_dict_{resnet_model_id}.pth'))
            self.image_dense_layers[key].load_state_dict(
                self._torch_load(
                    folder / f'dense_layers_dict_{resnet_model_id}.pth'))

    def export(self, folder, model_id, model_name=None):
        """
        Exports the entire model state dict to a specific folder,
        along with the image_task_dict, which is needed to reinstantiate the model.

        Parameters
        ----------
        folder: str or Path
            place to store state dictionaries
        model_id: int
            unique id for this model
        model_name: str (defaults to None)
            Name to store model under, if None, will default to `multi_task_ensemble_{model_id}.pth`

        Side Effects
        ------------
        saves two files:
            - folder / model_name
            - folder / f'image_task_dict_{model_id}.pickle'
        """
        folder = Path(folder)
        folder.mkdir(parents=True, exist_ok=True)

        if model_name is None:
            model_name = f'multi_task_ensemble_{model_id}.pth'

        torch.save(self.state_dict(), folder / model_name)

        joblib.dump(self.image_task_dict,
                    folder / f'image_task_dict_{model_id}.pickle')

    @staticmethod
    def create_text_dict(image_task_dict):
        """Create a task dict for the text model from the image task dict.

        Flattens ``{image_model: {task: n_labels}}`` into ``{task: n_labels}``
        and raises TonksError if a task appears under more than one image model.
        """
        text_task_dict = {}
        for joint_task in image_task_dict.keys():
            for task, task_size in image_task_dict[joint_task].items():
                if task in text_task_dict.keys():
                    raise TonksError(
                        'Task {} is in multiple models. Each task can only be in one image model.'
                        .format(task))
                text_task_dict[task] = task_size

        return text_task_dict
class DocumentBertTransformer(BertPreTrainedModel):
    """
    BERT -> TransformerEncoder -> Max over attention output.

    Every BERT sequence of a document is encoded to BERT's pooled output. The
    stack of pooled vectors per document is passed through a 6-layer
    TransformerEncoder, max-pooled over the sequence axis, and classified by
    Dropout -> Linear -> Tanh.
    """
    def __init__(self, bert_model_config: BertConfig):
        super().__init__(bert_model_config)
        self.bert = BertModel(bert_model_config)
        # Custom config attribute: at most this many BERT sequences per
        # document are encoded; later ones are dropped in forward().
        self.bert_batch_size = self.bert.config.bert_batch_size
        self.dropout = nn.Dropout(p=bert_model_config.hidden_dropout_prob)

        # NOTE(review): hidden_size must be divisible by nhead=6
        # (e.g. 768 / 6 = 128).
        encoder_layer = TransformerEncoderLayer(
            d_model=bert_model_config.hidden_size,
            nhead=6,
            dropout=bert_model_config.hidden_dropout_prob)
        self.transformer_encoder = TransformerEncoder(encoder_layer,
                                                      num_layers=6)
        self.classifier = nn.Sequential(
            nn.Dropout(p=bert_model_config.hidden_dropout_prob),
            nn.Linear(bert_model_config.hidden_size,
                      bert_model_config.num_labels), nn.Tanh())

    def forward(self,
                document_batch: torch.Tensor,
                document_sequence_lengths: list,
                device=None):
        """Encode a batch of documents and return one prediction per document.

        Args:
            document_batch: indexed as ``document_batch[doc][seq, channel]``,
                where channel 0 / 1 / 2 hold input_ids / token_type_ids /
                attention_mask. Presumably shaped
                (batch, num_sequences, 3, max_seq_len) -- TODO confirm against
                the data loader.
            document_sequence_lengths: currently unused.
            device: device for the intermediate BERT-output buffer. Defaults
                to ``document_batch.device``. Previously this was hard-coded to
                ``'cuda'``, which broke CPU execution.

        Returns:
            Tensor of shape (batch_size, num_labels) with values in [-1, 1]
            (final Tanh).
        """
        if device is None:
            device = document_batch.device

        num_docs = document_batch.shape[0]
        # Only the first bert_batch_size sequences of each document are used,
        # so the tail of long documents is cut off.
        num_seqs = min(document_batch.shape[1], self.bert_batch_size)
        bert_output = torch.zeros(size=(num_docs,
                                        num_seqs,
                                        self.bert.config.hidden_size),
                                  dtype=torch.float,
                                  device=device)

        for doc_id in range(num_docs):
            sequences = document_batch[doc_id][:self.bert_batch_size]
            # [1] is BERT's pooled output: one vector per sequence.
            bert_output[doc_id][:self.bert_batch_size] = self.dropout(
                self.bert(sequences[:, 0],
                          token_type_ids=sequences[:, 1],
                          attention_mask=sequences[:, 2])[1])

        # The encoder layer is used with its default sequence-first layout:
        # (num_sequences, num_documents, hidden_size).
        transformer_output = self.transformer_encoder(
            bert_output.permute(1, 0, 2))

        # Max over the sequence axis -> (num_documents, hidden_size).
        prediction = self.classifier(
            transformer_output.permute(1, 0, 2).max(dim=1)[0])
        assert prediction.shape[0] == document_batch.shape[0]
        return prediction

    def freeze_bert_encoder(self):
        """Disable gradients for every BERT parameter."""
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        """Enable gradients for every BERT parameter."""
        for param in self.bert.parameters():
            param.requires_grad = True

    def unfreeze_bert_encoder_last_layers(self):
        """Enable gradients for BERT's final encoder layer and its pooler.

        The last layer index comes from ``config.num_hidden_layers``. The
        previous hard-coded ``"encoder.layer.11"`` was only correct for
        12-layer models.
        """
        last_layer = f"encoder.layer.{self.bert.config.num_hidden_layers - 1}."
        for name, param in self.bert.named_parameters():
            if last_layer in name or "pooler" in name:
                param.requires_grad = True

    def unfreeze_bert_encoder_pooler_layer(self):
        """Enable gradients for BERT's pooler parameters only."""
        for name, param in self.bert.named_parameters():
            if "pooler" in name:
                param.requires_grad = True
class DocumentBertLSTM(BertPreTrainedModel):
    """
    BERT output over document in LSTM.

    Every BERT sequence of a document is encoded to BERT's pooled output. The
    resulting series of pooled vectors is run through an LSTM, and the LSTM
    output at the final step is classified by Dropout -> Linear -> Tanh.
    """
    def __init__(self, bert_model_config: BertConfig):
        super().__init__(bert_model_config)
        self.bert = BertModel(bert_model_config)
        # Custom config attribute: at most this many BERT sequences per
        # document are encoded; later ones are dropped in forward().
        self.bert_batch_size = self.bert.config.bert_batch_size
        self.dropout = nn.Dropout(p=bert_model_config.hidden_dropout_prob)
        self.lstm = LSTM(
            bert_model_config.hidden_size,
            bert_model_config.hidden_size,
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=bert_model_config.hidden_dropout_prob),
            nn.Linear(bert_model_config.hidden_size,
                      bert_model_config.num_labels), nn.Tanh())

    def forward(self,
                document_batch: torch.Tensor,
                document_sequence_lengths: list,
                device=None):
        """Encode a batch of documents and return one prediction per document.

        Args:
            document_batch: indexed as ``document_batch[doc][seq, channel]``,
                where channel 0 / 1 / 2 hold input_ids / token_type_ids /
                attention_mask. Presumably shaped
                (batch, num_sequences, 3, max_seq_len) -- TODO confirm against
                the data loader.
            document_sequence_lengths: currently unused.
            device: device for the intermediate BERT-output buffer. Defaults
                to ``document_batch.device``. Previously this was hard-coded to
                ``'cuda'``, which broke CPU execution.

        Returns:
            Tensor of shape (batch_size, num_labels) with values in [-1, 1]
            (final Tanh).
        """
        if device is None:
            device = document_batch.device

        num_docs = document_batch.shape[0]
        # Only the first bert_batch_size sequences of each document are used,
        # so the tail of long documents is cut off.
        num_seqs = min(document_batch.shape[1], self.bert_batch_size)
        bert_output = torch.zeros(size=(num_docs,
                                        num_seqs,
                                        self.bert.config.hidden_size),
                                  dtype=torch.float,
                                  device=device)

        for doc_id in range(num_docs):
            sequences = document_batch[doc_id][:self.bert_batch_size]
            # [1] is BERT's pooled output: one vector per sequence.
            bert_output[doc_id][:self.bert_batch_size] = self.dropout(
                self.bert(sequences[:, 0],
                          token_type_ids=sequences[:, 1],
                          attention_mask=sequences[:, 2])[1])

        # The LSTM expects (num_sequences, num_documents, hidden_size).
        output, _ = self.lstm(bert_output.permute(1, 0, 2))

        # Output at the final sequence step, one row per document.
        # NOTE(review): for documents with fewer real sequences than num_seqs
        # this is the output after padding sequences, because
        # document_sequence_lengths is not used to select the true last step.
        last_step = output[-1]

        prediction = self.classifier(last_step)
        assert prediction.shape[0] == document_batch.shape[0]
        return prediction

    def freeze_bert_encoder(self):
        """Disable gradients for every BERT parameter."""
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        """Enable gradients for every BERT parameter."""
        for param in self.bert.parameters():
            param.requires_grad = True

    def unfreeze_bert_encoder_last_layers(self):
        """Enable gradients for BERT's final encoder layer and its pooler.

        The last layer index comes from ``config.num_hidden_layers``. The
        previous hard-coded ``"encoder.layer.11"`` was only correct for
        12-layer models.
        """
        last_layer = f"encoder.layer.{self.bert.config.num_hidden_layers - 1}."
        for name, param in self.bert.named_parameters():
            if last_layer in name or "pooler" in name:
                param.requires_grad = True

    def unfreeze_bert_encoder_pooler_layer(self):
        """Enable gradients for BERT's pooler parameters only."""
        for name, param in self.bert.named_parameters():
            if "pooler" in name:
                param.requires_grad = True