Example #1
    def from_roberta(roberta_enc: roberta.RobertaModel, args, dictionary):
        encoder = roberta_enc.encoder.sentence_encoder
        vocab_size, embed_dim = encoder.embed_tokens.weight.shape

        if args.share_all_embeddings:
            lm_head = roberta_enc.encoder.lm_head
            assert encoder.embed_tokens.weight is lm_head.weight, (
                "Can't use --share-all-embeddings with a model "
                "that was pretraiend with --untie-weights-roberta_enc")
        else:
            lm_head = roberta.RobertaLMHead(embed_dim, vocab_size,
                                            roberta_enc.args.activation_fn)

        dec_embs = nn.Embedding(vocab_size, embed_dim, dictionary.pad())
        if args.share_all_embeddings or args.share_decoder_input_output_embed:
            # Note: I wasn't able to use the Embedding _weight parameter to achieve this sharing.
            dec_embs.weight = lm_head.weight

        decoder = TransformerDecoder(
            RobertaEncDecModel.read_args_from_roberta(roberta_enc.args),
            dictionary,
            dec_embs,
            no_encoder_attn=False,
            output_projection=lm_head,
        )
        if getattr(args, "pretrained_decoder", False):
            decoder_dict = encoder.state_dict()

            # TODO: hide setting "encoder_attn" layers behind a flag.
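            # Initialise each decoder layer's encoder_attn (cross-attention)
            # parameters from a copy of the corresponding pretrained self_attn
            # parameters.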
            for k, w in list(decoder_dict.items()):
                if ".self_attn" in k:
                    k_enc_attn = k.replace(".self_attn", ".encoder_attn")
                    decoder_dict[k_enc_attn] = w.detach().clone()

            for k, w in lm_head.state_dict().items():
                decoder_dict["output_projection." + k] = w

            missing_keys, unexpected_keys = decoder.load_state_dict(
                decoder_dict, strict=False)
            # missing_keys = [m for m in missing_keys if ".encoder_attn" not in m]
            assert not missing_keys and not unexpected_keys, (
                "Failed to load state dict. "
                f"Missing keys: {missing_keys}. "
                f"Unexpected keys: {unexpected_keys}.")

        if args.share_all_embeddings:
            assert decoder.output_projection.weight is decoder.embed_tokens.weight
            assert encoder.embed_tokens.weight is decoder.embed_tokens.weight
        elif args.share_decoder_input_output_embed:
            assert decoder.output_projection.weight is decoder.embed_tokens.weight
            assert encoder.embed_tokens.weight is not decoder.embed_tokens.weight
        else:
            assert decoder.output_projection.weight is not decoder.embed_tokens.weight
            assert encoder.embed_tokens.weight is not decoder.embed_tokens.weight

        return RobertaEncDecModel(encoder, decoder)
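
A minimal sketch of how from_roberta might be invoked; the checkpoint path, the args namespace, and the dictionary below are illustrative assumptions, not part of the snippet above.

# Hedged usage sketch: wrap a pretrained RoBERTa checkpoint in an
# encoder-decoder model via the factory above.
import argparse
from fairseq.models.roberta import RobertaModel

hub = RobertaModel.from_pretrained('checkpoints/roberta.base',
                                   checkpoint_file='model.pt')
# Flags consumed by from_roberta; the values here are placeholders.
args = argparse.Namespace(share_all_embeddings=False,
                          share_decoder_input_output_embed=True,
                          pretrained_decoder=True)
model = RobertaEncDecModel.from_roberta(hub.model, args,
                                        hub.task.source_dictionary)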
Example #2
class BartDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """
    def __init__(
        self,
        cfg: Wav2BartPoolConfig,
        dictionary=None,
        embed_tokens=None,
        no_encoder_attn=False,
    ):
        super().__init__(dictionary)
        self.cfg = cfg
        # bart = torch.hub.load('pytorch/fairseq', 'bart.base')
        from fairseq.models.bart import BARTModel
        if os.path.isfile(os.path.join(cfg.bart_path, 'model.pt')):
            print('loading bart from cfg path')
            bart = BARTModel.from_pretrained(cfg.bart_path,
                                             checkpoint_file='model.pt')
        else:
            print('loading bart from relative path')
            bart = BARTModel.from_pretrained('models/bart.base',
                                             checkpoint_file='model.pt')

        bart_decoder = bart.model.decoder
        self.decoder = TransformerDecoder(bart_decoder.args,
                                          bart_decoder.dictionary,
                                          bart_decoder.embed_tokens)
        self.decoder.load_state_dict(bart_decoder.state_dict())

    def forward(self,
                prev_output_tokens,
                encoder_out=None,
                incremental_state=None,
                **unused):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (Tensor, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        # with torch.no_grad() if self.cfg.fix_decoder else contextlib.ExitStack():
        x, extra = self.decoder(prev_output_tokens, encoder_out,
                                incremental_state)

        return x, extra

    def extract_features(self,
                         prev_output_tokens,
                         encoder_out=None,
                         incremental_state=None,
                         **unused):
        return self.decoder.extract_features(prev_output_tokens, encoder_out,
                                             incremental_state)

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.decoder.max_positions()

    def buffered_future_mask(self, tensor):
        return self.decoder.buffered_future_mask(tensor)

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict
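
For orientation, a hedged sketch of how this decoder might be exercised; Wav2BartPoolConfig is assumed to be a dataclass whose fields (including bart_path) have defaults, and the dummy tokens are illustrative.

# Hedged usage sketch; not part of the original snippet.
import torch

cfg = Wav2BartPoolConfig(bart_path='models/bart.base')  # assumed constructor
decoder = BartDecoder(cfg)
prev_tokens = torch.randint(4, 100, (2, 7))   # (batch, tgt_len) dummy token ids
logits, extra = decoder(prev_tokens)          # encoder_out omitted in this sketch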
Example #3
    def build_model(cls, args, task):
        encoder = TrOCREncoder(args=args, dictionary=task.source_dictionary)

        args.encoder_embed_dim = encoder.deit.embed_dim

        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        if getattr(args, "decoder_pretrained", None).startswith('roberta2'):
            logger.info(
                'Using the learned positional embedding version when loading roberta.')
            decoder_embed_tokens = cls.build_embedding(args,
                                                       task.target_dictionary,
                                                       args.decoder_embed_dim,
                                                       args.decoder_embed_path)

            pretrained_model = getattr(args, "decoder_pretrained", None)
            specified = pretrained_model.find('-') != -1

            if specified:
                pretrained_model = pretrained_model.replace('-', '.')
                logger.info(
                    'Load pre-trained decoder parameters from {}'.format(
                        pretrained_model))
                roberta = torch.hub.load('pytorch/fairseq:main',
                                         pretrained_model)
            elif args.decoder_layers == 6:
                logger.info(
                    'Load pre-trained decoder parameters from roberta.base')
                roberta = torch.hub.load('pytorch/fairseq:main',
                                         'roberta.base')
            elif args.decoder_layers == 12:
                logger.info(
                    'Load pre-trained decoder parameters from roberta.large')
                roberta = torch.hub.load('pytorch/fairseq:main',
                                         'roberta.large')
            else:
                raise AttributeError('Cannot determine the pre-trained model')

            roberta.model.args.encoder_layers = args.decoder_layers
            roberta.model.args.fp16 = args.fp16
            roberta_args = TrOCRModel.read_args_from_roberta(
                roberta.model.args)
            roberta_args.encoder_embed_dim = args.encoder_embed_dim

            decoder = TransformerDecoder(
                roberta_args,
                task.target_dictionary,
                decoder_embed_tokens,
                no_encoder_attn=False,
            )

            roberta_layers = roberta.model.encoder.sentence_encoder.layers
            decoder_layers = decoder.layers
            offset = len(roberta_layers) - len(decoder_layers)
            assert offset >= 0

            decoder_dict = roberta.state_dict()
            new_decoder_dict = {}
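            # Re-index the RoBERTa encoder layers: drop the bottom `offset`
            # layers and shift the remaining keys down so that the top
            # len(decoder_layers) layers line up with the decoder layers.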
            for key, val in decoder_dict.items():
                if key.startswith('model.encoder.sentence_encoder.layers.'):
                    layer_num = int(
                        key[len('model.encoder.sentence_encoder.layers.'
                                ):].split('.')[0])
                    if layer_num - offset < 0:
                        continue
                    else:
                        new_key = 'model.encoder.sentence_encoder.layers.{}.'.format(
                            str(layer_num - offset)
                        ) + '.'.join(
                            key[len('model.encoder.sentence_encoder.layers.'
                                    ):].split('.')[1:])
                        new_decoder_dict[new_key] = val
                else:
                    new_decoder_dict[key] = val
            decoder_dict = new_decoder_dict

            for k, w in list(decoder_dict.items()):
                if '.lm_head' in k:
                    k_proj = "output_projection." + k[
                        len('model.encoder.lm_head.'):]
                    decoder_dict[k_proj] = w.detach().clone()
                    del decoder_dict[k]

            del decoder_dict['_float_tensor']
            del decoder_dict['output_projection.weight']
            del decoder_dict['output_projection.bias']
            del decoder_dict['output_projection.dense.weight']
            del decoder_dict['output_projection.dense.bias']
            del decoder_dict['output_projection.layer_norm.weight']
            del decoder_dict['output_projection.layer_norm.bias']

            new_decoder_dict = {}
            for key, val in decoder_dict.items():
                if "sentence_encoder" in key:
                    key = key[len('model.encoder.sentence_encoder.'):]
                elif "encoder" in key:
                    key = key[len('model.encoder.'):]
                new_decoder_dict[key] = val

            missing_keys, unexpected_keys = decoder.load_state_dict(
                new_decoder_dict, strict=False)

        elif getattr(args, "decoder_pretrained", None) == 'unilm':
            logger.info('Decoder is pretrained using unilm.')

            prefix_of_parameter = 'bert'

            decoder_embed_tokens = cls.build_embedding(args,
                                                       task.target_dictionary,
                                                       args.decoder_embed_dim,
                                                       args.decoder_embed_path)

            decoder = UniLMDecoder(
                args,
                task.target_dictionary,
                decoder_embed_tokens,
                no_encoder_attn=False,
            )

            if getattr(args, 'decoder_pretrained_url', None):
                unilm_url = args.decoder_pretrained_url
                logger.info('The unilm model url: {}.'.format(
                    unilm_url[:unilm_url.find('?')]))
                unilm_state_dict = torch.hub.load_state_dict_from_url(
                    unilm_url)

                unilm_layers = OrderedDict([
                    (k, unilm_state_dict[k]) for k in unilm_state_dict.keys()
                    if k.startswith(prefix_of_parameter + '.encoder.layer.')
                ])
                unilm_layers_num = []
                for k in unilm_layers.keys():
                    t = k.replace(prefix_of_parameter + '.encoder.layer.', '')
                    t = t[:t.find('.')]
                    unilm_layers_num.append(int(t))
                unilm_layers_num = max(unilm_layers_num) + 1

                offset = unilm_layers_num - len(decoder.layers)
                assert offset == 0

                decoder_dict = decoder.state_dict()
                # embedding
                new_pos_weight = torch.zeros_like(
                    decoder_dict['embed_positions.weight'])
                # fairseq positional embeddings reserve indices up to the
                # padding index, so copy the pretrained position weights
                # starting at padding_idx + 1
                new_pos_weight[task.target_dictionary.pad() +
                               1:, :] = unilm_state_dict[
                                   prefix_of_parameter +
                                   '.embeddings.position_embeddings.weight']
                new_decoder_dict = {
                    'embed_tokens.weight':
                    unilm_state_dict[prefix_of_parameter +
                                     '.embeddings.word_embeddings.weight'],
                    'embed_positions.weight':
                    new_pos_weight,
                    'layernorm_embedding.weight':
                    unilm_state_dict[prefix_of_parameter +
                                     '.embeddings.LayerNorm.weight'],
                    'layernorm_embedding.bias':
                    unilm_state_dict[prefix_of_parameter +
                                     '.embeddings.LayerNorm.bias']
                }

                # layers: map fairseq decoder parameter names to the
                # corresponding UniLM (BERT-style) parameter names
                key_map = {
                    'self_attn.k_proj': 'attention.self.key',
                    'self_attn.v_proj': 'attention.self.value',
                    'self_attn.q_proj': 'attention.self.query',
                    'self_attn.out_proj': 'attention.output.dense',
                    'self_attn_layer_norm': 'attention.output.LayerNorm',
                    'fc1': 'intermediate.dense',
                    'fc2': 'output.dense',
                    'final_layer_norm': 'output.LayerNorm'
                }
                for layer_id in range(unilm_layers_num):
                    unilm_prefix = prefix_of_parameter + '.encoder.layer.{}.'.format(
                        layer_id)
                    decoder_prefix = 'layers.{}.'.format(layer_id)

                    for key in key_map:
                        for suffix in ['.weight', '.bias']:
                            decoder_key = decoder_prefix + key + suffix
                            unilm_key = unilm_prefix + key_map[key] + suffix
                            if decoder_key in decoder_dict and unilm_key in unilm_state_dict:
                                new_decoder_dict[
                                    decoder_key] = unilm_state_dict[unilm_key]

                if hasattr(args, "reset_dictionary") and args.reset_dictionary:
                    logger.info(
                        'Reset token embedding weights during decoder initialization.'
                    )
                    del new_decoder_dict['embed_tokens.weight']
                elif hasattr(args,
                             "adapt_dictionary") and args.adapt_dictionary:
                    unilm_embed_tokens_weight = new_decoder_dict[
                        'embed_tokens.weight']
                    logger.info(
                        'Adapt token embedding weights during decoder initialization from {} to {}'
                        .format(unilm_embed_tokens_weight.shape[0],
                                decoder_embed_tokens.weight.shape[0]))
                    new_decoder_dict['embed_tokens.weight'] = torch.zeros_like(
                        decoder_dict['embed_tokens.weight'])
                    new_decoder_dict['embed_tokens.weight'][:min(
                        unilm_embed_tokens_weight.
                        shape[0], decoder_dict['embed_tokens.weight'].shape[0]
                    ), :] = unilm_embed_tokens_weight[:min(
                        unilm_embed_tokens_weight.shape[0],
                        decoder_dict['embed_tokens.weight'].shape[0]), :]

                missing_keys, unexpected_keys = decoder.load_state_dict(
                    new_decoder_dict, strict=False)
            else:
                logger.warning(
                    'No unilm model url was specified, so the decoder is randomly initialized.'
                )

            # freeze k_proj bias
            for layer in decoder.layers:
                layer.self_attn.k_proj.bias.requires_grad = False

        elif getattr(args, "decoder_pretrained",
                     None).upper() == 'None' or getattr(
                         args, "decoder_pretrained", None) == None:
            logger.info('Decoder is randomly initialized.')
            decoder_embed_tokens = cls.build_embedding(args,
                                                       task.target_dictionary,
                                                       args.decoder_embed_dim,
                                                       args.decoder_embed_path)
            decoder = TransformerDecoder(args=args,
                                         dictionary=task.target_dictionary,
                                         embed_tokens=decoder_embed_tokens,
                                         no_encoder_attn=False)

        elif getattr(args, "decoder_pretrained", None).startswith('roberta'):
            logger.info('Using the old version when loading roberta.')
            decoder_embed_tokens = cls.build_embedding(args,
                                                       task.target_dictionary,
                                                       args.decoder_embed_dim,
                                                       args.decoder_embed_path)
            decoder = TransformerDecoder(args=args,
                                         dictionary=task.target_dictionary,
                                         embed_tokens=decoder_embed_tokens,
                                         no_encoder_attn=False)

            pretrained_model = getattr(args, "decoder_pretrained", None)
            specified = pretrained_model.find('-') != -1

            if specified:
                pretrained_model = pretrained_model.replace('-', '.')
                logger.info(
                    'Load pre-trained decoder parameters from {}'.format(
                        pretrained_model))
                roberta = torch.hub.load('pytorch/fairseq:main',
                                         pretrained_model)
            elif args.decoder_layers == 6:
                logger.info(
                    'Load pre-trained decoder parameters from roberta.base')
                roberta = torch.hub.load('pytorch/fairseq:main',
                                         'roberta.base')
            elif args.decoder_layers == 12:
                logger.info(
                    'Load pre-trained decoder parameters from roberta.large')
                roberta = torch.hub.load('pytorch/fairseq:main',
                                         'roberta.large')
            else:
                raise AttributeError('Cannot determine the pre-trained model')

            decoder.embed_tokens.load_state_dict(
                roberta.model.encoder.sentence_encoder.embed_tokens.state_dict(
                ))
            roberta_layers = roberta.model.encoder.sentence_encoder.layers
            decoder_layers = decoder.layers
            offset = len(roberta_layers) - len(decoder_layers)
            assert offset >= 0

            for i in range(len(decoder_layers)):
                roberta_i = i + offset
                decoder_layers[i].self_attn.load_state_dict(
                    roberta_layers[roberta_i].self_attn.state_dict())
                decoder_layers[i].self_attn_layer_norm.load_state_dict(
                    roberta_layers[roberta_i].self_attn_layer_norm.state_dict(
                    ))

        else:
            raise Exception('Undefined decoder pretraining method.')
        model = cls(encoder, decoder)
        return model
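
Both roberta branches call a read_args_from_roberta helper that is not shown above. Purely as an assumption about what it does (mirroring RoBERTa's encoder_* hyperparameters onto the decoder_* fields a TransformerDecoder expects), a minimal sketch could look like this; it is not the actual TrOCR implementation.

# Hypothetical sketch only: copy RoBERTa encoder hyperparameters onto the
# decoder-side attribute names used by TransformerDecoder.
import argparse


def read_args_from_roberta_sketch(roberta_args):
    args = argparse.Namespace(**vars(roberta_args))
    attr_map = [
        ("encoder_embed_dim", "decoder_embed_dim"),
        ("encoder_ffn_embed_dim", "decoder_ffn_embed_dim"),
        ("encoder_attention_heads", "decoder_attention_heads"),
        ("encoder_layers", "decoder_layers"),
        ("encoder_normalize_before", "decoder_normalize_before"),
        ("encoder_learned_pos", "decoder_learned_pos"),
    ]
    for enc_name, dec_name in attr_map:
        if hasattr(roberta_args, enc_name):
            setattr(args, dec_name, getattr(roberta_args, enc_name))
    return args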
Example #4
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_lm_architecture(args)

        if args.decoder_layers_to_keep:
            args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = getattr(args, "tokens_per_sample",
                                                DEFAULT_MAX_TARGET_POSITIONS)

        if args.character_embeddings:
            embed_tokens = CharacterTokenEmbedder(
                task.source_dictionary,
                eval(args.character_filters),
                args.character_embedding_dim,
                args.decoder_embed_dim,
                args.char_embedder_highway_layers,
            )
        elif args.adaptive_input:
            embed_tokens = AdaptiveInput(
                len(task.source_dictionary),
                task.source_dictionary.pad(),
                args.decoder_input_dim,
                args.adaptive_input_factor,
                args.decoder_embed_dim,
                options.eval_str_list(args.adaptive_input_cutoff, type=int),
                args.quant_noise_pq,
                args.quant_noise_pq_block_size,
            )
        else:
            embed_tokens = cls.build_embedding(args, task.source_dictionary,
                                               args.decoder_input_dim)

        if args.tie_adaptive_weights:
            assert args.adaptive_input
            assert args.adaptive_input_factor == args.adaptive_softmax_factor
            assert (args.adaptive_softmax_cutoff == args.adaptive_input_cutoff
                    ), "{} != {}".format(args.adaptive_softmax_cutoff,
                                         args.adaptive_input_cutoff)
            assert args.decoder_input_dim == args.decoder_output_dim

        decoder = TransformerDecoder(args,
                                     task.target_dictionary,
                                     embed_tokens,
                                     no_encoder_attn=True)

        if getattr(args, "lm_path", None):
            print('load Transformer_LM from {}'.format(args.lm_path))
            state = checkpoint_utils.load_checkpoint_to_cpu(args.lm_path)
            lm_args = state["args"]
            lm_args.data = args.data
            assert getattr(lm_args, "lm_path", None) is None

            task = tasks.setup_task(lm_args)
            decoder = task.build_model(lm_args)
            print('restore Transformer_LM from {}'.format(args.lm_path))
            decoder.load_state_dict(state["model"], strict=True)
        decoder.dim_output = len(task.dictionary)

        return cls(decoder)
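
The lm_path branch above replaces the freshly built decoder with a fully trained language model restored from a checkpoint. Stripped of the surrounding build_model plumbing, that pattern looks roughly like the sketch below; the checkpoint path is a placeholder and older checkpoints are assumed to store an argparse namespace under "args".

# Standalone sketch of the lm_path branch: restore a trained LM checkpoint and
# reuse the resulting model as the decoder.
from fairseq import checkpoint_utils, tasks

state = checkpoint_utils.load_checkpoint_to_cpu('checkpoints/lm/checkpoint_best.pt')
lm_args = state["args"]
# lm_args.data may need to be repointed at a local data directory, as the
# original code does with lm_args.data = args.data.
lm_task = tasks.setup_task(lm_args)
decoder = lm_task.build_model(lm_args)
decoder.load_state_dict(state["model"], strict=True)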
Example #5
class BartDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self,
        cfg: Wav2Vec2BartConfig,
        dictionary=None,
        embed_tokens=None,
        no_encoder_attn=False,
        transform_embed=None,
        bart=None,
    ):
        super().__init__(dictionary)
        self.cfg = cfg
        # bart = torch.hub.load('pytorch/fairseq', 'bart.base')
        bart_decoder = bart.model.decoder
        self.decoder = TransformerDecoder(bart_decoder.args, bart_decoder.dictionary, bart_decoder.embed_tokens)
        self.decoder.load_state_dict(bart_decoder.state_dict())
        self.decoder.embed_tokens = EmbeddingTransformed(self.decoder.embed_tokens, transform_embed)

    def forward(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (Tensor, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        # with torch.no_grad() if self.cfg.fix_decoder else contextlib.ExitStack():
        x, extra = self.decoder(prev_output_tokens, encoder_out, incremental_state)

        for k in ['wav2vec_logits', 'wav2vec_padding_mask', 'ctc_weight', 'ce_weight']:
            extra[k] = encoder_out[k]

        print('bart decoder extra.keys()', extra.keys())
        return x, extra

    def extract_features(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
    ):
        return self.decoder.extract_features(prev_output_tokens, encoder_out, incremental_state)

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.decoder.max_positions()

    def buffered_future_mask(self, tensor):
        return self.decoder.buffered_future_mask(tensor)

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict
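
The snippet wraps the pretrained BART embeddings in an EmbeddingTransformed module that is not shown. Purely as an assumption about its intent, a minimal wrapper of that kind could look like the sketch below; the real class may differ.

# Hypothetical sketch of an embedding wrapper in the spirit of the
# EmbeddingTransformed used above: apply a learned transform on top of the
# pretrained BART token embeddings.
import torch.nn as nn


class EmbeddingTransformedSketch(nn.Module):
    def __init__(self, embed_tokens, transform):
        super().__init__()
        self.embed_tokens = embed_tokens
        self.transform = transform
        # Mirror the attributes callers may inspect on an nn.Embedding.
        self.embedding_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx

    def forward(self, tokens):
        # Look up the pretrained embeddings, then apply the transform.
        return self.transform(self.embed_tokens(tokens))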