Code example #1
File: lstm_lm.py  Project: veralily/fairseq
    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_architecture(args)

        if getattr(args, "max_target_positions", None) is not None:
            max_target_positions = args.max_target_positions
        else:
            max_target_positions = getattr(args, "tokens_per_sample",
                                           DEFAULT_MAX_TARGET_POSITIONS)

        def load_pretrained_embedding_from_file(embed_path, dictionary,
                                                embed_dim):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
            embed_dict = utils.parse_embedding(embed_path)
            utils.print_embed_overlap(embed_dict, dictionary)
            return utils.load_embedding(embed_dict, dictionary, embed_tokens)

        pretrained_decoder_embed = None
        if args.decoder_embed_path:
            pretrained_decoder_embed = load_pretrained_embedding_from_file(
                args.decoder_embed_path, task.target_dictionary,
                args.decoder_embed_dim)

        if args.share_decoder_input_output_embed:
            # double check all parameters combinations are valid
            if task.source_dictionary != task.target_dictionary:
                raise ValueError(
                    "--share-decoder-input-output-embeddings requires a joint dictionary"
                )

            if args.decoder_embed_dim != args.decoder_out_embed_dim:
                raise ValueError(
                    "--share-decoder-input-output-embeddings requires "
                    "--decoder-embed-dim to match --decoder-out-embed-dim")

        decoder = LSTMDecoder(
            dictionary=task.dictionary,
            embed_dim=args.decoder_embed_dim,
            hidden_size=args.decoder_hidden_size,
            out_embed_dim=args.decoder_out_embed_dim,
            num_layers=args.decoder_layers,
            dropout_in=args.decoder_dropout_in,
            dropout_out=args.decoder_dropout_out,
            attention=False,  # decoder-only language model doesn't support attention
            encoder_output_units=0,
            pretrained_embed=pretrained_decoder_embed,
            share_input_output_embed=args.share_decoder_input_output_embed,
            adaptive_softmax_cutoff=(
                utils.eval_str_list(args.adaptive_softmax_cutoff, type=int)
                if args.criterion == "adaptive_loss" else None
            ),
            max_target_positions=max_target_positions,
            residuals=args.residuals,
        )

        return cls(decoder)
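
Note: build_model is a classmethod that fairseq invokes through the task rather than directly. A minimal sketch of that dispatch flow, hedged against version differences (the exact --arch name depends on the @register_model_architecture entries in lstm_lm.py):

    from fairseq import options, tasks

    # parse command-line args; --arch selects the base_architecture defaults
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)

    # the task owns the dictionaries the model is built against
    task = tasks.setup_task(args)

    # dispatches to the registered model class's build_model(cls, args, task)
    model = task.build_model(args)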
Code example #2
    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # make sure that all args are properly defaulted (in case there are any new ones)
        base_architecture(args)

        encoder = TextRecognitionEncoder(args=args)
        decoder = LSTMDecoder(
            dictionary=task.target_dictionary,
            embed_dim=encoder.embed_dim,
            hidden_size=args.decoder_hidden_size,
            out_embed_dim=args.decoder_out_embed_dim,
            num_layers=args.decoder_layers,
            dropout_in=args.decoder_dropout_in,
            dropout_out=args.decoder_dropout_out,
            attention=options.eval_bool(args.decoder_attention),
            encoder_output_units=encoder.embed_dim,
            adaptive_softmax_cutoff=(
                options.eval_str_list(args.adaptive_softmax_cutoff, type=int)
                if args.criterion == 'adaptive_loss' else None
            ),
        )
        return cls(encoder, decoder)
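
Note how this variant differs from example #1: it pairs the decoder with a TextRecognitionEncoder, feeds the encoder's embedding width through encoder_output_units, and leaves attention configurable. The attention flag and the adaptive-softmax cutoff arrive from the command line as strings, so the options helpers parse them. A quick illustration of what those helpers return (input values here are hypothetical):

    from fairseq import options

    options.eval_bool('True')                         # -> True
    options.eval_str_list('10000,50000', type=int)    # -> [10000, 50000]
    options.eval_str_list(None, type=int)             # -> None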
Code example #3
File: lstm_lm.py  Project: eastonYi/fairseq
    @classmethod
    def build_model(cls, args, task, dictionary=None):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_architecture(args)

        if getattr(args, 'max_target_positions', None) is not None:
            max_target_positions = args.max_target_positions
        else:
            max_target_positions = getattr(args, 'tokens_per_sample', DEFAULT_MAX_TARGET_POSITIONS)

        def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
            embed_dict = utils.parse_embedding(embed_path)
            utils.print_embed_overlap(embed_dict, dictionary)
            return utils.load_embedding(embed_dict, dictionary, embed_tokens)

        pretrained_decoder_embed = None
        if args.decoder_embed_path:
            pretrained_decoder_embed = load_pretrained_embedding_from_file(
                args.decoder_embed_path,
                task.target_dictionary,
                args.decoder_embed_dim
            )

        if args.share_decoder_input_output_embed:
            # double check all parameters combinations are valid
            if task.source_dictionary != task.target_dictionary:
                raise ValueError('--share-decoder-input-output-embeddings requires a joint dictionary')

            if args.decoder_embed_dim != args.decoder_out_embed_dim:
                raise ValueError(
                    '--share-decoder-input-output-embeddings requires '
                    '--decoder-embed-dim to match --decoder-out-embed-dim'
                )

        decoder = LSTMDecoder(
            dictionary=dictionary if dictionary else task.dictionary,
            embed_dim=args.decoder_embed_dim,
            hidden_size=args.decoder_hidden_size,
            out_embed_dim=args.decoder_out_embed_dim,
            num_layers=args.decoder_layers,
            dropout_in=args.decoder_dropout_in,
            dropout_out=args.decoder_dropout_out,
            attention=False,  # decoder-only language model doesn't support attention
            encoder_output_units=0,
            pretrained_embed=pretrained_decoder_embed,
            share_input_output_embed=args.share_decoder_input_output_embed,
            adaptive_softmax_cutoff=(
                utils.eval_str_list(args.adaptive_softmax_cutoff, type=int)
                if args.criterion == 'adaptive_loss' else None
            ),
            max_target_positions=max_target_positions,
            residuals=args.residuals
        )

        if getattr(args, "lm_path", None):
            # args.lm_path = '../libri/wav2vec2_small.pt'
            print('load LSTM_LM from {}'.format(args.lm_path))
            state = checkpoint_utils.load_checkpoint_to_cpu(args.lm_path)
            lm_args = state["args"]
            lm_args.data = args.data
            assert getattr(lm_args, "lm_path", None) is None

            # note: `task` is rebound to the LM task here, so the
            # len(task.dictionary) lookup below uses the loaded LM's dictionary
            task = tasks.setup_task(lm_args)
            decoder = task.build_model(lm_args)
            print('restore LSTM_LM from {}'.format(args.lm_path))
            decoder.load_state_dict(state["model"], strict=True)
        decoder.dim_output = len(task.dictionary)

        return cls(decoder)
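
The lm_path branch in example #3 follows a common fairseq pattern for warm-starting from a checkpoint: load the state to CPU, rebuild the model from the args stored in the checkpoint, then restore the weights. The same pattern in isolation (the path is hypothetical; older fairseq checkpoints store training args under state["args"], while newer ones use state["cfg"]):

    from fairseq import checkpoint_utils, tasks

    state = checkpoint_utils.load_checkpoint_to_cpu('/path/to/lm.pt')  # hypothetical path
    lm_args = state['args']               # training args saved with the checkpoint
    lm_task = tasks.setup_task(lm_args)   # rebuild the task and its dictionary
    lm = lm_task.build_model(lm_args)     # construct an uninitialized model
    lm.load_state_dict(state['model'], strict=True)  # restore trained weights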