Example 1
    def __init__(self, num_classes, finetune=False):
        super(VitModel, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        if args.init_method_xavier_uniform:
            self.init_method = torch.nn.init.xavier_uniform_
            self.scaled_init_method = torch.nn.init.xavier_uniform_
        else:
            self.init_method = init_method_normal(args.init_method_std)
            self.scaled_init_method = scaled_init_method_normal(
                args.init_method_std, args.num_layers)

        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.patch_dim = args.patch_dim
        self.img_dim = args.img_dim
        self.finetune = finetune

        assert self.img_dim % self.patch_dim == 0
        self.num_patches_per_dim = self.img_dim // self.patch_dim
        self.num_patches = self.num_patches_per_dim**2
        self.seq_length = self.num_patches + 1
        self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels

        # [CLS] token parameter (the randn values are immediately zeroed below)
        self.cls_token = torch.nn.Parameter(torch.randn(
            1, 1, self.hidden_size))
        torch.nn.init.zeros_(self.cls_token)

        # Linear encoder
        self.linear_encoder = torch.nn.Linear(self.flatten_dim,
                                              self.hidden_size)

        # embedding
        self.position_embeddings = torch.nn.Embedding(self.seq_length,
                                                      self.hidden_size)
        init_method_normal(args.init_method_std)(
            self.position_embeddings.weight)
        self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()

        self.position_embeddings._register_load_state_dict_pre_hook(
            twod_interpolate_position_embeddings_hook)

        self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)

        # Transformer
        self.transformer = ParallelTransformer(self.init_method,
                                               self.scaled_init_method)

        # MLP head
        if not self.finetune:
            self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
        else:
            self.class_head = get_linear_layer(self.hidden_size, num_classes,
                                               torch.nn.init.zeros_)
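
A standalone sketch (not the repository's forward pass) of how an image batch maps onto the shapes set up above: patchify, flatten each patch to flatten_dim, project to hidden_size, and prepend the class token. All sizes are hypothetical.

import torch

B, C, H, W, P, hidden = 2, 3, 224, 224, 16, 1024            # hypothetical sizes
x = torch.randn(B, C, H, W)
# Unfold into non-overlapping P x P patches: [B, C*P*P, num_patches]
# (channel-major flattening; the repository's own patch ordering may differ).
patches = torch.nn.functional.unfold(x, kernel_size=P, stride=P)
patches = patches.transpose(1, 2)                            # [B, num_patches, flatten_dim]
linear_encoder = torch.nn.Linear(C * P * P, hidden)
tokens = linear_encoder(patches)                             # [B, num_patches, hidden]
cls_token = torch.zeros(1, 1, hidden).expand(B, -1, -1)      # zero-initialized, as above
tokens = torch.cat([cls_token, tokens], dim=1)               # [B, seq_length, hidden]
print(tokens.shape)                                          # torch.Size([2, 197, 1024])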
Example 2
    def __init__(self,
                 num_tokentypes=2,
                 add_binary_head=True,
                 parallel_output=True):
        super(BertModel, self).__init__()
        args = get_args()

        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=self.add_binary_head,
            init_method=init_method,
            scaled_init_method=scaled_init_method)

        self.lm_head = BertLMHead(
            self.language_model.embedding.word_embeddings.weight.size(0),
            args.hidden_size, init_method, args.layernorm_epsilon,
            parallel_output)
        self._lm_head_key = 'lm_head'

        if self.add_binary_head:
            self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                init_method)
            self._binary_head_key = 'binary_head'
Example 3
def get_language_model(attention_mask_func,
                       num_tokentypes,
                       add_pooler,
                       init_method=None,
                       scaled_init_method=None):
    """Build language model and return along with the key to save."""
    args = get_args()

    if init_method is None:
        init_method = init_method_normal(args.init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

    # Language model.
    language_model = TransformerLanguageModel(
        attention_mask_func=attention_mask_func,
        init_method=init_method,
        output_layer_init_method=scaled_init_method,
        num_tokentypes=num_tokentypes,
        add_pooler=add_pooler)
    # key used for checkpoints.
    language_model_key = 'language_model'

    return language_model, language_model_key
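
Both defaults above are weight-initializer factories. A hedged sketch of what init_method_normal and scaled_init_method_normal return in Megatron-LM: closures that initialize a tensor in place, with the scaled variant shrinking the standard deviation by sqrt(2 * num_layers). Treat the exact scaling formula as an assumption if your version differs.

import math
import torch

def init_method_normal(std):
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return init_

def scaled_init_method_normal(std, num_layers):
    # Smaller std for output/residual projections; scaling assumed per Megatron-LM.
    scaled_std = std / math.sqrt(2.0 * num_layers)
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=scaled_std)
    return init_

weight = torch.empty(1024, 1024)
init_method_normal(0.02)(weight)            # same call pattern as in the examples
scaled_init_method_normal(0.02, 24)(weight)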
Example 4
    def __init__(self,
                 num_tokentypes=2,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True):
        super(PretrainedBertModel, self).__init__()

        args = get_args()
        tokenizer = get_tokenizer()
        self.pad_id = tokenizer.pad
        self.biencoder_projection_dim = args.biencoder_projection_dim
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process)

        if args.biencoder_projection_dim > 0:
            self.projection_enc = get_linear_layer(
                args.hidden_size, args.biencoder_projection_dim, init_method)
            self._projection_enc_key = 'projection_enc'
Example 5
    def __init__(self, num_tokentypes=2, add_binary_head=True,
                 parallel_output=True):
        super(BertModelBase, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output

        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=self.add_binary_head,
            init_method=init_method,
            scaled_init_method=scaled_init_method)

        self.initialize_word_embeddings(init_method_normal)
        if mpu.is_pipeline_last_stage():
            self.lm_head = BertLMHead(
                self.word_embeddings_weight().size(0),
                args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
            self._lm_head_key = 'lm_head'
            self.binary_head = None
            if self.add_binary_head:
                self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                    init_method)
                self._binary_head_key = 'binary_head'
Example 6
def get_language_model(num_tokentypes,
                       add_pooler,
                       encoder_attn_mask_type,
                       init_method=None,
                       scaled_init_method=None,
                       add_decoder=False,
                       decoder_attn_mask_type=AttnMaskType.causal,
                       pre_process=True,
                       post_process=True):
    """Build language model and return along with the key to save."""
    args = get_args()

    if init_method is None:
        init_method = init_method_normal(args.init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

    # Language model.
    language_model = TransformerLanguageModel(
        init_method,
        scaled_init_method,
        encoder_attn_mask_type,
        num_tokentypes=num_tokentypes,
        add_decoder=add_decoder,
        decoder_attn_mask_type=decoder_attn_mask_type,
        add_pooler=add_pooler,
        pre_process=pre_process,
        post_process=post_process)
    # key used for checkpoints.
    language_model_key = 'language_model'

    return language_model, language_model_key
Example 7
    def __init__(self, num_tokentypes=2):
        super(MultipleChoice, self).__init__()
        args = get_args()

        init_method = init_method_normal(args.init_method_std)

        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            init_method=init_method,
            scaled_init_method=scaled_init_method_normal(
                args.init_method_std, args.num_layers))

        # Multi-choice head.
        self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
        self.multichoice_head = get_linear_layer(args.hidden_size, 1,
                                                 init_method)
        self._multichoice_head_key = 'multichoice_head'
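
get_linear_layer, used here and in several other examples to build small task heads, is roughly a torch.nn.Linear whose weight is initialized with the supplied init method and whose bias is zeroed; a hedged sketch (the exact helper in your version may differ):

import torch

def get_linear_layer(rows, columns, init_method):
    layer = torch.nn.Linear(rows, columns)
    init_method(layer.weight)             # e.g. a closure from init_method_normal(0.02)
    with torch.no_grad():
        layer.bias.zero_()
    return layer

multichoice_head = get_linear_layer(1024, 1, torch.nn.init.xavier_uniform_)
print(multichoice_head)                   # Linear(in_features=1024, out_features=1, bias=True)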
Example 8
    def __init__(self, ict_head_size, num_tokentypes=2, parallel_output=True):
        super(IREncoderBertModel, self).__init__()
        args = get_args()

        self.ict_head_size = ict_head_size
        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            init_method=init_method,
            scaled_init_method=scaled_init_method)

        self.ict_head = get_linear_layer(args.hidden_size, ict_head_size, init_method)
        self._ict_head_key = 'ict_head'
Example 9
    def __init__(
        self,
        model_name,
        vocab_file,
        hidden_size=1024,
        num_attention_heads=16,
        num_layers=24,
        max_seq_length=512,
        tokenizer_type='BertWordPieceLowerCase',
        init_method_std=0.02,
        num_tokentypes=2,
    ):

        super().__init__()

        if not os.path.exists(vocab_file):
            raise ValueError(f'Vocab file not found at {vocab_file}')

        megatron_args = {
            "num_layers": num_layers,
            "hidden_size": hidden_size,
            "num_attention_heads": num_attention_heads,
            "max_position_embeddings": max_seq_length,
            "tokenizer_type": tokenizer_type,
            "vocab_file": vocab_file,
        }

        initialize_megatron(None, megatron_args, ignore_unknown_args=True)
        init_method = init_method_normal(init_method_std)

        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            init_method=init_method,
            scaled_init_method=scaled_init_method_normal(
                init_method_std, num_layers),
        )

        self.language_model.to(self._device)
        self._hidden_size = self.language_model.hidden_size
Example 10
    def __init__(self, num_classes, num_tokentypes=2):
        super(ClassificationBase, self).__init__(share_word_embeddings=False)
        args = get_args()

        self.num_classes = num_classes
        init_method = init_method_normal(args.init_method_std)

        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            init_method=init_method,
            scaled_init_method=scaled_init_method_normal(
                args.init_method_std, args.num_layers))

        # Classification head.
        if mpu.is_pipeline_last_stage():
            self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
            self.classification_head = get_linear_layer(
                args.hidden_size, self.num_classes, init_method)
            self._classification_head_key = 'classification_head'
Example 11
    def __init__(self, num_tokentypes=0, parallel_output=True):
        super(T5Model, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            add_decoder=True,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method)

        self.lm_head = T5LMHead(
            self.language_model.embedding.word_embeddings.weight.size(0),
            parallel_output)
        self._lm_head_key = 'lm_head'
Example 12
    def __init__(self,
                 num_tokentypes=2,
                 add_binary_head=True,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True):
        super(BertModel, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process

        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=self.add_binary_head,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process)

        self.initialize_word_embeddings(init_method_normal)
        if self.post_process:
            self.lm_head = BertLMHead(
                self.word_embeddings_weight().size(0),
                args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
            self._lm_head_key = 'lm_head'
            self.binary_head = None
            if self.add_binary_head:
                self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                    init_method)
                self._binary_head_key = 'binary_head'
Example 13
def get_language_model(attention_mask_func,
                       num_tokentypes,
                       add_pooler,
                       init_method=None,
                       scaled_init_method=None):
    """Build language model and return along with the key to save."""
    args = get_args()

    if init_method is None:
        init_method = init_method_normal(args.init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

    # Select the language-model class for this pipeline stage and collect its
    # constructor arguments: token-type embeddings are only built on the first
    # stage, the pooler only on the last.
    positional_args = [attention_mask_func, init_method, scaled_init_method]
    kwargs = {}
    cls = None
    if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
        cls = TransformerLanguageModel
        kwargs['num_tokentypes'] = num_tokentypes
        kwargs['add_pooler'] = add_pooler
    elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
        cls = TransformerLanguageModelFirstStage
        kwargs['num_tokentypes'] = num_tokentypes
    elif not mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
        cls = TransformerLanguageModelLastStage
        kwargs['add_pooler'] = add_pooler
    else:
        cls = TransformerLanguageModelIntermediateStage

    # Language model.
    language_model = cls(*positional_args, **kwargs)
    # key used for checkpoints.
    language_model_key = 'language_model'

    return language_model, language_model_key
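
A standalone sketch of the stage-selection logic above, with the pipeline checks replaced by plain booleans (is_first / is_last are hypothetical stand-ins for mpu.is_pipeline_first_stage() / mpu.is_pipeline_last_stage()): the token-type kwarg only matters where the embeddings live, the pooler kwarg only on the final stage.

def select_stage_kwargs(is_first, is_last, num_tokentypes, add_pooler):
    kwargs = {}
    if is_first:
        kwargs['num_tokentypes'] = num_tokentypes   # embeddings sit on the first stage
    if is_last:
        kwargs['add_pooler'] = add_pooler           # pooler sits on the last stage
    return kwargs

print(select_stage_kwargs(True, False, num_tokentypes=2, add_pooler=True))
# {'num_tokentypes': 2}
print(select_stage_kwargs(False, False, num_tokentypes=2, add_pooler=True))
# {}  -> intermediate stage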
Example 14
    def __init__(self,
                 num_tokentypes=0,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True,
                 add_encoder=True,
                 add_decoder=True):
        super(T5Model, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)
        self.pre_process = pre_process
        self.post_process = post_process
        self.add_encoder = add_encoder
        self.add_decoder = add_decoder

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            add_encoder=add_encoder,
            add_decoder=add_decoder,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process)

        self.initialize_word_embeddings(init_method_normal)

        if self.post_process and self.add_decoder:
            self.lm_head = T5LMHead(self.word_embeddings_weight().size(0),
                                    parallel_output)
            self._lm_head_key = 'lm_head'
Example 15
    def __init__(self, num_tokentypes=2, pre_process=True, post_process=True):
        super(MultipleChoice, self).__init__(share_word_embeddings=False)
        args = get_args()

        init_method = init_method_normal(args.init_method_std)
        self.pre_process = pre_process
        self.post_process = post_process

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method_normal(
                args.init_method_std, args.num_layers),
            pre_process=self.pre_process,
            post_process=self.post_process)

        # Multi-choice head.
        if self.post_process:
            self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
            self.multichoice_head = get_linear_layer(args.hidden_size, 1,
                                                     init_method)
            self._multichoice_head_key = 'multichoice_head'
Example 16
    def __init__(self,
                 pre_process=True,
                 post_process=True,
                 class_token=True,
                 single_token_output=False):
        super(VitBackbone, self).__init__(share_word_embeddings=False)
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        if args.init_method_xavier_uniform:
            self.init_method = torch.nn.init.xavier_uniform_
            self.scaled_init_method = torch.nn.init.xavier_uniform_
        else:
            self.init_method = init_method_normal(args.init_method_std)
            self.scaled_init_method = scaled_init_method_normal(
                args.init_method_std, args.num_layers)

        self.pre_process = pre_process
        self.post_process = post_process
        self.class_token = class_token
        self.hidden_size = args.hidden_size
        self.patch_dim = args.patch_dim
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.micro_batch_size = args.micro_batch_size
        self.single_token_output = single_token_output

        assert self.img_h % self.patch_dim == 0
        assert self.img_w % self.patch_dim == 0
        self.num_patches_per_dim_h = self.img_h // self.patch_dim
        self.num_patches_per_dim_w = self.img_w // self.patch_dim
        self.num_patches = self.num_patches_per_dim_h * self.num_patches_per_dim_w
        self.seq_length = self.num_patches + (CLASS_TOKEN_LENGTH
                                              if self.class_token else 0)
        self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
        self.input_tensor = None
        self.position_ids = None

        if self.pre_process:
            # cls_token
            if self.class_token:
                self.cls_token = torch.nn.Parameter(
                    torch.randn(1, CLASS_TOKEN_LENGTH, self.hidden_size))
                torch.nn.init.zeros_(self.cls_token)
            self.position_ids = torch.arange(self.seq_length).expand(
                1, -1).cuda()

            # Linear encoder
            self.linear_encoder = torch.nn.Linear(self.flatten_dim,
                                                  self.hidden_size)

            # embedding
            self.position_embeddings = torch.nn.Embedding(
                self.seq_length, self.hidden_size)
            init_method_normal(args.init_method_std)(
                self.position_embeddings.weight)

            args.class_token_present = self.class_token
            self.position_embeddings._register_load_state_dict_pre_hook(
                twod_interpolate_position_embeddings_hook)

            self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)

        # Transformer
        self.transformer = ParallelTransformer(
            self.init_method,
            self.scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process,
        )
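
A minimal sketch of the non-square patch-grid arithmetic above, with hypothetical sizes; CLASS_TOKEN_LENGTH is assumed to be 1 here, while the real constant is defined elsewhere in the module.

CLASS_TOKEN_LENGTH = 1                         # assumption; see module-level constant
img_h, img_w, patch_dim, num_channels = 192, 320, 16, 3
assert img_h % patch_dim == 0 and img_w % patch_dim == 0
num_patches_per_dim_h = img_h // patch_dim     # 12
num_patches_per_dim_w = img_w // patch_dim     # 20
num_patches = num_patches_per_dim_h * num_patches_per_dim_w   # 240
seq_length = num_patches + CLASS_TOKEN_LENGTH                 # 241 with a class token
flatten_dim = patch_dim * patch_dim * num_channels            # 768 values per patch
print(num_patches, seq_length, flatten_dim)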