Example #1
    def from_pretrained(model_id_or_path: str,
                        device: Optional[torch.device] = None):
        torch_model = TorchBertModel.from_pretrained(model_id_or_path)
        model = BertModelNoPooler.from_torch(torch_model, device)
        model.config = torch_model.config
        # Keep a reference so the underlying torch model is not garbage-collected.
        model._torch_model = torch_model
        return model
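A minimal call-site sketch for the wrapper above, assuming the method is reachable as a plain function or static method and that `torch` is imported as in the original module (both are assumptions; the snippet only shows the method body):

# Hypothetical call site for the wrapper above.
model = from_pretrained("bert-base-uncased", device=torch.device("cpu"))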
Example #2
    def __init__(self, tagset_size):
        super(BertForSequenceClassification, self).__init__()
        self.tagset_size = tagset_size

        self.BertModel_single = BertModel.from_pretrained(pretrain_model_dir)
        self.single_hidden2tag = BertClassificationHead(
            bert_hidden_dim, tagset_size)
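The snippet relies on names defined elsewhere in its module (`pretrain_model_dir`, `bert_hidden_dim`, `BertClassificationHead`). A purely illustrative sketch of what those definitions might look like:

import torch.nn as nn

# Illustrative placeholders only; the real definitions live in the original module.
pretrain_model_dir = "bert-base-uncased"
bert_hidden_dim = 768  # hidden size of bert-base checkpoints


class BertClassificationHead(nn.Module):
    def __init__(self, hidden_dim, tagset_size):
        super().__init__()
        self.linear = nn.Linear(hidden_dim, tagset_size)

    def forward(self, x):
        return self.linear(x)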
Example #3
    def __init__(self,
                 args,
                 tokenizer: BertTokenizer,
                 object_features_variant=False,
                 positional_embed_variant=False,
                 latent_transformer=False):
        super().__init__()

        self.args = args
        self.tokenizer = tokenizer

        self.image_projection = nn.Sequential(
            nn.Linear(512, 768), nn.BatchNorm1d(768, momentum=0.01))

        config = BertConfig.from_pretrained('bert-base-uncased')
        self.embeddings = BertEmbeddings(config)

        self.text_encoder = BertModel.from_pretrained("bert-base-uncased",
                                                      return_dict=True)
        self.decoder = BertLMHeadModel.from_pretrained(
            'bert-base-uncased',
            is_decoder=True,
            use_cache=True,
            add_cross_attention=True)

        if object_features_variant:
            self.image_transformer = ImageTransformerEncoder(args)

        self.positional_embed = bool(positional_embed_variant)

        self.latent_transformer = latent_transformer
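The image projection maps 512-dimensional image features into BERT's 768-dimensional hidden space before batch normalisation; a tiny standalone check of that layer (the 512-dim input is an assumption about the upstream image encoder):

import torch
import torch.nn as nn

proj = nn.Sequential(nn.Linear(512, 768), nn.BatchNorm1d(768, momentum=0.01))
feats = torch.randn(8, 512)   # a batch of 8 pooled image feature vectors
print(proj(feats).shape)      # torch.Size([8, 768])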
Example #4
    def __init__(
        self,
        vocab: Vocabulary,
        bert_model: Union[str, BertModel],
        embedding_dropout: float = 0.0,
        initializer: InitializerApplicator = InitializerApplicator(),
        label_smoothing: float = None,
        ignore_span_metric: bool = False,
        srl_eval_path: str = DEFAULT_SRL_EVAL_PATH,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        if isinstance(bert_model, str):
            self.bert_model = BertModel.from_pretrained(bert_model)
        else:
            self.bert_model = bert_model

        self.num_classes = self.vocab.get_vocab_size("labels")
        if srl_eval_path is not None:
            # For the span based evaluation, we don't want to consider labels
            # for verb, because the verb index is provided to the model.
            self.span_metric = SrlEvalScorer(srl_eval_path,
                                             ignore_classes=["V"])
        else:
            self.span_metric = None
        self.tag_projection_layer = Linear(self.bert_model.config.hidden_size,
                                           self.num_classes)

        self.embedding_dropout = Dropout(p=embedding_dropout)
        self._label_smoothing = label_smoothing
        self.ignore_span_metric = ignore_span_metric
        initializer(self)
Example #5
def main():
    if len(sys.argv) != 3:
        print(
            "Usage:\n"
            "    convert_huggingface_bert_to_npz <model_name> (e.g. bert-base-uncased) <output_file>"
        )
        exit(0)
    torch.set_grad_enabled(False)

    model_name = sys.argv[1]
    model = BertModel.from_pretrained(model_name)
    arrays = {k: v.detach() for k, v in model.named_parameters()}

    q_weight_key = 'self.query.weight'
    k_weight_key = 'self.key.weight'
    v_weight_key = 'self.value.weight'

    q_bias_key = 'self.query.bias'
    k_bias_key = 'self.key.bias'
    v_bias_key = 'self.value.bias'

    numpy_dict = {}
    for k in arrays.keys():
        # Fuse the separate Q/K/V projections into a single transposed qkv tensor.
        if k.endswith(q_weight_key):
            v = torch.clone(
                torch.t(
                    torch.cat([
                        arrays[k],
                        arrays[k[:-len(q_weight_key)] + k_weight_key],
                        arrays[k[:-len(q_weight_key)] + v_weight_key]
                    ], 0).contiguous()).contiguous())
            numpy_dict[k[:-len(q_weight_key)] + "qkv.weight"] = v.numpy()
        elif k.endswith(q_bias_key):
            v = torch.cat([
                arrays[k], arrays[k[:-len(q_bias_key)] + k_bias_key],
                arrays[k[:-len(q_bias_key)] + v_bias_key]
            ], 0).numpy()
            numpy_dict[k[:-len(q_bias_key)] + 'qkv.bias'] = v
        elif any((k.endswith(suffix) for suffix in (k_weight_key, v_weight_key,
                                                    k_bias_key, v_bias_key))):
            continue
        elif (k.endswith("attention.output.dense.weight")
              or k.endswith("pooler.dense.weight")
              or k.endswith("output.dense.weight")
              or k.endswith("intermediate.dense.weight")):
            numpy_dict[k] = torch.clone(torch.t(
                arrays[k]).contiguous()).numpy()
        else:
            numpy_dict[k] = arrays[k].numpy()
    del arrays
    del model
    numpy.savez_compressed(sys.argv[2], **numpy_dict)
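A hedged usage sketch for the converter: invoke the script with a model name and an output path, then inspect the fused Q/K/V tensors with NumPy (file names below are examples only):

# python convert_huggingface_bert_to_npz.py bert-base-uncased bert.npz
import numpy

arrays = numpy.load("bert.npz")
# Q, K and V projections are stored concatenated and transposed as *.qkv.weight.
print([k for k in arrays.keys() if k.endswith("qkv.weight")][:2])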
Example #6
    def __init__(self,
                 bert_model_name,
                 trainable=False,
                 output_size=0,
                 activation=gelu,
                 dropout=0.0):
        """Initializes pertrained `BERT` model

        Arguments:
            bert_model_name {str} -- bert model name

        Keyword Arguments:
            output_size {float} -- output size (default: {None})
            activation {nn.Module} -- activation function (default: {gelu})
            dropout {float} -- dropout rate (default: {0.0})
        """

        super().__init__()
        self.bert_model = BertModel.from_pretrained(bert_model_name,
                                                    output_attentions=True,
                                                    output_hidden_states=True)
        logger.info("Load bert model {} successfully.".format(bert_model_name))

        self.output_size = output_size

        if trainable:
            logger.info(
                "Fine-tuning bert model {}.".format(bert_model_name))
        else:
            logger.info("Keeping bert model {} fixed.".format(bert_model_name))

        for param in self.bert_model.parameters():
            param.requires_grad = trainable

        if self.output_size > 0:
            self.mlp = BertLinear(
                input_size=self.bert_model.config.hidden_size,
                output_size=self.output_size,
                activation=activation)
        else:
            self.output_size = self.bert_model.config.hidden_size
            self.mlp = lambda x: x

        if dropout > 0:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = lambda x: x
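The `trainable` flag just toggles `requires_grad` on every BERT parameter; a quick standalone check of that freezing pattern, independent of the wrapper class:

from transformers import BertModel

bert = BertModel.from_pretrained("bert-base-uncased")
for param in bert.parameters():
    param.requires_grad = False   # equivalent to constructing with trainable=False
print(sum(p.numel() for p in bert.parameters() if p.requires_grad))  # 0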
Example #7
    def __init__(self,
                 name_or_path_or_model: Union[str, BertModel],
                 adapter_size: int = 64,
                 adapter_num: int = 12,
                 external_param: Union[bool, List[bool]] = False,
                 **kwargs):
        super().__init__()
        if isinstance(name_or_path_or_model, str):
            self.bert = BertModel.from_pretrained(name_or_path_or_model)
        else:
            self.bert = name_or_path_or_model

        set_requires_grad(self.bert, False)

        if isinstance(external_param, bool):
            param_place = [external_param for _ in range(adapter_num)]
        elif isinstance(external_param, list):
            param_place = [False for _ in range(adapter_num)]
            for i, e in enumerate(external_param, 1):
                param_place[-i] = e
        else:
            raise ValueError("wrong type of external_param!")

        self.adapters = nn.ModuleList([
            nn.ModuleList([
                Adapter(self.bert.config.hidden_size, adapter_size,
                        param_place[i]),
                Adapter(self.bert.config.hidden_size, adapter_size,
                        param_place[i])
            ]) for i in range(adapter_num)
        ])

        for i, adapters in enumerate(self.adapters, 1):
            layer = self.bert.encoder.layer[-i]
            layer.output = AdapterBertOutput(layer.output, adapters[0].forward)
            set_requires_grad(layer.output.base.LayerNorm, True)
            layer.attention.output = AdapterBertOutput(layer.attention.output,
                                                       adapters[1].forward)
            set_requires_grad(layer.attention.output.base.LayerNorm, True)

        self.output_dim = self.bert.config.hidden_size
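When `external_param` is a list, its flags are applied to the topmost encoder layers first (note the `enumerate(..., 1)` combined with negative indexing). A tiny standalone illustration of that placement logic:

# Standalone illustration: a 2-element external_param list on a 12-layer stack.
adapter_num = 12
external_param = [True, False]
param_place = [False for _ in range(adapter_num)]
for i, e in enumerate(external_param, 1):
    param_place[-i] = e
print(param_place)  # last entry True, second-to-last explicitly False, rest default False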
Example #8
    def __init__(
        self,
        vocab: Vocabulary,
        bert_model: Union[str, Dict[str, Any], BertModel],
        embedding_dropout: float = 0.0,
        initializer: InitializerApplicator = InitializerApplicator(),
        label_smoothing: float = None,
        ignore_span_metric: bool = False,
        srl_eval_path: str = DEFAULT_SRL_EVAL_PATH,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        if isinstance(bert_model, str):
            self.bert_model = BertModel.from_pretrained(bert_model)
        elif isinstance(bert_model, dict):
            warnings.warn(
                "Initializing BertModel without pretrained weights. This is fine if you're loading "
                "from an AllenNLP archive, but not if you're training.",
                UserWarning,
            )
            bert_config = BertConfig.from_dict(bert_model)
            self.bert_model = BertModel(bert_config)
        else:
            self.bert_model = bert_model

        self.num_classes = self.vocab.get_vocab_size("labels")
        if srl_eval_path is not None:
            # For the span based evaluation, we don't want to consider labels
            # for verb, because the verb index is provided to the model.
            self.span_metric = SrlEvalScorer(srl_eval_path,
                                             ignore_classes=["V"])
        else:
            self.span_metric = None
        self.tag_projection_layer = Linear(self.bert_model.config.hidden_size,
                                           self.num_classes)

        self.embedding_dropout = Dropout(p=embedding_dropout)
        self._label_smoothing = label_smoothing
        self.ignore_span_metric = ignore_span_metric
        initializer(self)
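The dict branch builds an untrained `BertModel` purely from a configuration. A minimal standalone sketch of that path with HuggingFace's `BertConfig` (the config values are illustrative):

from transformers import BertConfig, BertModel

# Randomly initialised weights: fine when an archive will overwrite them,
# not for training from scratch, as the warning above notes.
bert_config = BertConfig.from_dict({"hidden_size": 768,
                                    "num_attention_heads": 12,
                                    "num_hidden_layers": 12})
model = BertModel(bert_config)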
Example #9
    def test_from_pytorch(self):
        with torch.no_grad():
            with self.subTest("bert-base-cased"):
                tokenizer = BertTokenizerFast.from_pretrained(
                    "bert-base-cased")
                fx_model = FlaxBertModel.from_pretrained("bert-base-cased")
                pt_model = BertModel.from_pretrained("bert-base-cased")

                # Check for simple input
                pt_inputs = tokenizer.encode_plus(
                    "This is a simple input",
                    return_tensors=TensorType.PYTORCH)
                fx_inputs = tokenizer.encode_plus(
                    "This is a simple input", return_tensors=TensorType.JAX)
                pt_outputs = pt_model(**pt_inputs).to_tuple()
                fx_outputs = fx_model(**fx_inputs)

                self.assertEqual(
                    len(fx_outputs), len(pt_outputs),
                    "Output lengths differ between Flax and PyTorch")

                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
                    self.assert_almost_equals(fx_output, pt_output.numpy(),
                                              5e-3)
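`assert_almost_equals` comes from the surrounding test mixin; an equivalent standalone tolerance check could be written with NumPy directly (a sketch, not the original helper):

import numpy as np

# Maximum absolute difference between a Flax output and its PyTorch counterpart.
diff = np.abs(np.asarray(fx_output) - pt_output.numpy()).max()
assert diff < 5e-3, f"outputs diverge by {diff}"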
Example #10
def main(config):
    args = config

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    processor = ATEPCProcessor()
    label_list = processor.get_labels()
    num_labels = len(label_list) + 1

    datasets = {
        'camera': "atepc_datasets/camera",
        'car': "atepc_datasets/car",
        'phone': "atepc_datasets/phone",
        'notebook': "atepc_datasets/notebook",
        'laptop': "atepc_datasets/laptop",
        'restaurant': "atepc_datasets/restaurant",
        'twitter': "atepc_datasets/twitter",
        'mixed': "atepc_datasets/mixed",
    }
    pretrained_bert_models = {
        'camera': "bert-base-chinese",
        'car': "bert-base-chinese",
        'phone': "bert-base-chinese",
        'notebook': "bert-base-chinese",
        'laptop': "bert-base-uncased",
        'restaurant': "bert-base-uncased",
        # for loading domain-adapted BERT
        # 'restaurant': "../bert_pretrained_restaurant",
        'twitter': "bert-base-uncased",
        'mixed': "bert-base-multilingual-uncased",
    }

    args.bert_model = pretrained_bert_models[args.dataset]
    args.data_dir = datasets[args.dataset]

    def convert_polarity(examples):
        for i in range(len(examples)):
            polarities = []
            for polarity in examples[i].polarity:
                if polarity == 2:
                    polarities.append(1)
                else:
                    polarities.append(polarity)
            examples[i].polarity = polarities

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=True)
    train_examples = processor.get_train_examples(args.data_dir)
    eval_examples = processor.get_test_examples(args.data_dir)
    num_train_optimization_steps = int(
        len(train_examples) / args.train_batch_size /
        args.gradient_accumulation_steps) * args.num_train_epochs
    bert_base_model = BertModel.from_pretrained(args.bert_model)
    bert_base_model.config.num_labels = num_labels

    if args.dataset in {'camera', 'car', 'phone', 'notebook'}:
        convert_polarity(train_examples)
        convert_polarity(eval_examples)
    model = LCF_ATEPC(bert_base_model, args=args)

    for arg in vars(args):
        logger.info('>>> {0}: {1}'.format(arg, getattr(args, arg)))

    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    # Note: both groups use the same weight decay below, so the no_decay split
    # currently has no effect.
    optimizer_grouped_parameters = [{
        'params': [p for n, p in param_optimizer
                   if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.00001
    }, {
        'params': [p for n, p in param_optimizer
                   if any(nd in n for nd in no_decay)],
        'weight_decay': 0.00001
    }]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      weight_decay=0.00001)
    eval_features = convert_examples_to_features(eval_examples, label_list,
                                                 args.max_seq_length,
                                                 tokenizer)
    all_spc_input_ids = torch.tensor([f.input_ids_spc for f in eval_features],
                                     dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                  dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                   dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                 dtype=torch.long)
    all_polarities = torch.tensor([f.polarities for f in eval_features],
                                  dtype=torch.long)
    all_valid_ids = torch.tensor([f.valid_ids for f in eval_features],
                                 dtype=torch.long)
    all_lmask_ids = torch.tensor([f.label_mask for f in eval_features],
                                 dtype=torch.long)
    eval_data = TensorDataset(all_spc_input_ids, all_input_mask,
                              all_segment_ids, all_label_ids, all_polarities,
                              all_valid_ids, all_lmask_ids)
    # Run prediction for full data
    eval_sampler = RandomSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    def evaluate(eval_ATE=True, eval_APC=True):
        # evaluate
        apc_result = {'max_apc_test_acc': 0, 'max_apc_test_f1': 0}
        ate_result = 0
        y_true = []
        y_pred = []
        n_test_correct, n_test_total = 0, 0
        test_apc_logits_all, test_polarities_all = None, None
        model.eval()
        label_map = {i: label for i, label in enumerate(label_list, 1)}
        for input_ids_spc, input_mask, segment_ids, label_ids, polarities, valid_ids, l_mask in eval_dataloader:
            input_ids_spc = input_ids_spc.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            valid_ids = valid_ids.to(device)
            label_ids = label_ids.to(device)
            polarities = polarities.to(device)
            l_mask = l_mask.to(device)

            with torch.no_grad():
                ate_logits, apc_logits = model(input_ids_spc,
                                               segment_ids,
                                               input_mask,
                                               valid_ids=valid_ids,
                                               polarities=polarities,
                                               attention_mask_label=l_mask)
            if eval_APC:
                polarities = model.get_batch_polarities(polarities)
                n_test_correct += (torch.argmax(
                    apc_logits, -1) == polarities).sum().item()
                n_test_total += len(polarities)

                if test_polarities_all is None:
                    test_polarities_all = polarities
                    test_apc_logits_all = apc_logits
                else:
                    test_polarities_all = torch.cat(
                        (test_polarities_all, polarities), dim=0)
                    test_apc_logits_all = torch.cat(
                        (test_apc_logits_all, apc_logits), dim=0)

            if eval_ATE:
                if not args.use_bert_spc:
                    label_ids = model.get_batch_token_labels_bert_base_indices(
                        label_ids)
                ate_logits = torch.argmax(F.log_softmax(ate_logits, dim=2),
                                          dim=2)
                ate_logits = ate_logits.detach().cpu().numpy()
                label_ids = label_ids.to('cpu').numpy()
                input_mask = input_mask.to('cpu').numpy()
                for i, label in enumerate(label_ids):
                    temp_1 = []
                    temp_2 = []
                    for j, m in enumerate(label):
                        if j == 0:
                            continue
                        elif label_ids[i][j] == len(label_list):
                            y_true.append(temp_1)
                            y_pred.append(temp_2)
                            break
                        else:
                            temp_1.append(label_map.get(label_ids[i][j], 'O'))
                            temp_2.append(label_map.get(ate_logits[i][j], 'O'))
        if eval_APC:
            test_acc = n_test_correct / n_test_total
            if args.dataset in {'camera', 'car', 'phone', 'notebook'}:
                test_f1 = f1_score(torch.argmax(test_apc_logits_all, -1).cpu(),
                                   test_polarities_all.cpu(),
                                   labels=[0, 1],
                                   average='macro')
            else:
                test_f1 = f1_score(torch.argmax(test_apc_logits_all, -1).cpu(),
                                   test_polarities_all.cpu(),
                                   labels=[0, 1, 2],
                                   average='macro')
            test_acc = round(test_acc * 100, 2)
            test_f1 = round(test_f1 * 100, 2)
            apc_result = {
                'max_apc_test_acc': test_acc,
                'max_apc_test_f1': test_f1
            }

        if eval_ATE:
            report = classification_report(y_true, y_pred, digits=4)
            tmps = report.split()
            ate_result = round(float(tmps[7]) * 100, 2)
        return apc_result, ate_result

    def save_model(path):
        # Save the trained model and the associated configuration
        # (mind the disk usage).
        os.makedirs(path, exist_ok=True)
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model it-self
        model_to_save.save_pretrained(path)
        tokenizer.save_pretrained(path)
        label_map = {i: label for i, label in enumerate(label_list, 1)}
        model_config = {
            "bert_model": args.bert_model,
            "do_lower": True,
            "max_seq_length": args.max_seq_length,
            "num_labels": len(label_list) + 1,
            "label_map": label_map
        }
        json.dump(model_config, open(os.path.join(path, "config.json"), "w"))
        logger.info('save model to: {}'.format(path))

    def train():
        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_spc_input_ids = torch.tensor(
            [f.input_ids_spc for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)
        all_valid_ids = torch.tensor([f.valid_ids for f in train_features],
                                     dtype=torch.long)
        all_lmask_ids = torch.tensor([f.label_mask for f in train_features],
                                     dtype=torch.long)
        all_polarities = torch.tensor([f.polarities for f in train_features],
                                      dtype=torch.long)
        train_data = TensorDataset(all_spc_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids,
                                   all_polarities, all_valid_ids,
                                   all_lmask_ids)

        train_sampler = SequentialSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)
        max_apc_test_acc = 0
        max_apc_test_f1 = 0
        max_ate_test_f1 = 0

        global_step = 0
        for epoch in range(int(args.num_train_epochs)):
            logger.info('#' * 80)
            logger.info('Train {} Epoch{}, Dataset {}'.format(
                args.seed, epoch + 1, args.data_dir))
            logger.info('#' * 80)
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(train_dataloader):
                model.train()
                batch = tuple(t.to(device) for t in batch)
                input_ids_spc, input_mask, segment_ids, label_ids, polarities, valid_ids, l_mask = batch
                loss_ate, loss_apc = model(input_ids_spc, segment_ids,
                                           input_mask, label_ids, polarities,
                                           valid_ids, l_mask)
                loss = loss_ate + loss_apc
                loss.backward()
                nb_tr_examples += input_ids_spc.size(0)
                nb_tr_steps += 1
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1
                if global_step % args.eval_steps == 0:
                    if epoch >= args.num_train_epochs - 2 or args.num_train_epochs <= 2:
                        # evaluate in last 2 epochs
                        apc_result, ate_result = evaluate(
                            eval_ATE=not args.use_bert_spc)

                        # apc_result, ate_result = evaluate()
                        # path = '{0}/{1}_{2}_apcacc_{3}_apcf1_{4}_atef1_{5}'.format(
                        #     args.output_dir,
                        #     args.dataset,
                        #     args.local_context_focus,
                        #     round(apc_result['max_apc_test_acc'], 2),
                        #     round(apc_result['max_apc_test_f1'], 2),
                        #     round(ate_result, 2)
                        # )
                        # if apc_result['max_apc_test_acc'] > max_apc_test_acc or \
                        #     apc_result['max_apc_test_f1'] > max_apc_test_f1 or \
                        #     ate_result > max_ate_test_f1:
                        #     save_model(path)

                        if apc_result['max_apc_test_acc'] > max_apc_test_acc:
                            max_apc_test_acc = apc_result['max_apc_test_acc']
                        if apc_result['max_apc_test_f1'] > max_apc_test_f1:
                            max_apc_test_f1 = apc_result['max_apc_test_f1']
                        if ate_result > max_ate_test_f1:
                            max_ate_test_f1 = ate_result

                        current_apc_test_acc = apc_result['max_apc_test_acc']
                        current_apc_test_f1 = apc_result['max_apc_test_f1']
                        current_ate_test_f1 = round(ate_result, 2)

                        logger.info('*' * 80)
                        logger.info('Train {} Epoch{}, Evaluate for {}'.format(
                            args.seed, epoch + 1, args.data_dir))
                        logger.info(
                            f'APC_test_acc: {current_apc_test_acc}(max: {max_apc_test_acc})  '
                            f'APC_test_f1: {current_apc_test_f1}(max: {max_apc_test_f1})'
                        )
                        if args.use_bert_spc:
                            logger.info(
                                f'ATE_test_F1: {current_apc_test_f1}(max: {max_apc_test_f1})'
                                f' (Unreliable since `use_bert_spc` is "True".)'
                            )
                        else:
                            logger.info(
                                f'ATE_test_f1: {current_ate_test_f1}(max:{max_ate_test_f1})'
                            )
                        logger.info('*' * 80)

        return [max_apc_test_acc, max_apc_test_f1, max_ate_test_f1]

    return train()
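`main()` expects an argparse-style namespace; a hedged sketch of the fields the snippet reads directly (placeholder values, and `LCF_ATEPC` may require further fields not visible here):

from types import SimpleNamespace

config = SimpleNamespace(
    gradient_accumulation_steps=1, train_batch_size=16, eval_batch_size=16,
    seed=42, output_dir="out", dataset="laptop", max_seq_length=80,
    learning_rate=3e-5, num_train_epochs=5, eval_steps=10, use_bert_spc=True)
# main(config)  # runs training plus periodic evaluation and returns best scores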
Example #11
def test_fused_softmax():
    from megatron.model.fused_softmax import FusedScaleMaskSoftmax, SoftmaxFusionTypes
    from megatron.model.gpt2_model import (
        gpt2_attention_mask_func as attention_mask_func, )

    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi hi hi hi hi hi hi"  # 32
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    embedding_output = bert.embeddings(
        input_ids=tokens["input_ids"].cuda(),
        position_ids=None,
        token_type_ids=tokens["token_type_ids"].cuda(),
        inputs_embeds=None,
        past_key_values_length=0,
    )

    # (bsz, 1, 1, seq_len)
    mask = bert.get_extended_attention_mask(
        attention_mask=tokens["attention_mask"].cuda(),
        input_shape=tokens["input_ids"].shape,
        device=bert.device,
    )
    # (bsz, 1, seq_len, seq_len)
    mask = mask.repeat(1, 1, mask.size()[-1], 1)

    attention = bert.encoder.layer[0].attention.self
    key_layer = attention.transpose_for_scores(attention.key(embedding_output))
    query_layer = attention.transpose_for_scores(
        attention.query(embedding_output))

    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_scores /= math.sqrt(key_layer.size()[-1])

    fused_softmax = (FusedScaleMaskSoftmax(
        input_in_fp16=True,
        input_in_bf16=False,
        fusion_type=SoftmaxFusionTypes.general,
        mask_func=attention_mask_func,
        scale=None,
        softmax_in_fp32=False,
    ).cuda().half())

    fused_softmax_output = fused_softmax(
        attention_scores,
        (mask != 0),
    )

    torch_softmax = (FusedScaleMaskSoftmax(
        input_in_fp16=True,
        input_in_bf16=False,
        mask_func=attention_mask_func,
        fusion_type=SoftmaxFusionTypes.none,
        scale=None,
        softmax_in_fp32=False,
    ).cuda().half())

    torch_softmax_output = torch_softmax(
        attention_scores,
        (mask != 0),
    )

    test_result = (fused_softmax_output - torch_softmax_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_softmax"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_softmax"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
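The `repeat` call expands BERT's (bsz, 1, 1, seq_len) extended attention mask into the (bsz, 1, seq_len, seq_len) shape the fused kernel consumes; a shape-only illustration with placeholder sizes:

import torch

mask = torch.zeros(4, 1, 1, 45)            # e.g. bsz=4, seq_len=45 (placeholders)
mask = mask.repeat(1, 1, mask.size()[-1], 1)
print(mask.shape)                          # torch.Size([4, 1, 45, 45])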
Example #12
####################
# Helper functions #
####################
def process(t):
    return torch.stack(t).squeeze().detach().numpy()


# BERT
if __name__ == "__main__":
    excluded_neurons = {0: (0, ), 1: (0, 1), 2: (0, 1, 2)}
    model = BertModel.from_pretrained("bert-base-cased",
                                      output_attentions=True,
                                      output_values=True,
                                      output_dense=True,
                                      output_mlp_activations=True,
                                      output_q_activations=True,
                                      output_k_activations=True,
                                      output_v_activations=True,
                                      excluded_neurons=excluded_neurons)
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    inputs = tokenizer("Hello", return_tensors="pt")
    print("### inputs ###")
    print(inputs.items())
    outputs = model(**inputs)
    print("### values ###")
    print(len(outputs["values"]))
    print(outputs["values"][0].shape)
    values = outputs["values"]
    values = torch.stack(values).squeeze()
    values = values.detach().numpy()
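The final stack/squeeze/detach sequence duplicates the `process` helper defined at the top of the example, so the last three lines could be collapsed to:

values = process(outputs["values"])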
Example #13
    def __init__(self, model_name):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = BertModel.from_pretrained(model_name).eval()
        self.model.cuda()
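A minimal encoding sketch for the wrapper above, assuming an instance `enc` of the (unnamed) class and a recent `transformers` version where `BertModel` returns a model output object:

import torch

# enc = SomeEncoder("bert-base-cased")  # hypothetical construction; the class name is not shown
with torch.no_grad():
    batch = enc.tokenizer("hello world", return_tensors="pt").to("cuda")
    hidden = enc.model(**batch).last_hidden_state   # shape: (1, seq_len, hidden)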
Example #14
    def __init__(self, args, tokenizer) -> None:
        super().__init__(args, tokenizer)
        self.generative_encoder = BertModel.from_pretrained(
            "bert-base-uncased", return_dict=True)
        self.latent_layer = Latent(args)