Beispiel #1
0
    def build_model(cls, args, task):
        """Build a new model instance.

        Args:
            args: model arguments. ``base_architecture`` is applied to them
                first so that every option has a value.
            task: task object; ``feat_dim``, ``feat_in_channels`` and
                ``target_dictionary`` are read from it.

        Returns:
            ``cls(encoder, decoder)`` with a ``SpeechFConvEncoder`` and a
            ``SpeechFConvDecoder``.

        Raises:
            AssertionError: if ``task.feat_dim`` is not a multiple of
                ``task.feat_in_channels``, or a stride entry is neither an
                int nor a non-empty list/tuple.
        """
        # make sure that all args are properly defaulted (in case there are any new ones)
        base_architecture(args)

        # Optional pretrained decoder embeddings, parsed from a file.
        decoder_embed_dict = None
        if args.decoder_embed_path:
            decoder_embed_dict = utils.parse_embedding(args.decoder_embed_path)
            utils.print_embed_overlap(decoder_embed_dict,
                                      task.target_dictionary)

        # The conv front-end is described by string-encoded (nested) lists.
        out_channels = speech_utils.eval_str_nested_list_or_tuple(
            args.encoder_conv_channels, type=int)
        kernel_sizes = speech_utils.eval_str_nested_list_or_tuple(
            args.encoder_conv_kernel_sizes, type=int)
        strides = speech_utils.eval_str_nested_list_or_tuple(
            args.encoder_conv_strides, type=int)
        logger.info('input feature dimension: {}, channels: {}'.format(
            task.feat_dim, task.feat_in_channels))
        assert task.feat_dim % task.feat_in_channels == 0
        conv_layers = ConvBNReLU(
            out_channels,
            kernel_sizes,
            strides,
            in_channels=task.feat_in_channels,
        ) if out_channels is not None else None

        fconv_encoder_input_size = task.feat_dim // task.feat_in_channels
        if conv_layers is not None:
            for stride in strides:
                if isinstance(stride, (list, tuple)):
                    assert len(stride) > 0
                    # For a pair the second element is used (presumably the
                    # feature-axis stride -- confirm against ConvBNReLU).
                    s = stride[1] if len(stride) > 1 else stride[0]
                else:
                    assert isinstance(stride, int)
                    s = stride
                # ceiling division: size after one layer with stride s
                fconv_encoder_input_size = (fconv_encoder_input_size + s -
                                            1) // s
            fconv_encoder_input_size *= out_channels[-1]
        else:
            # Without a conv front-end nothing splits the features into
            # channels, so the encoder consumes the full feature dimension
            # (same fallback as the other speech models' build_model).
            fconv_encoder_input_size = task.feat_dim

        # NOTE(security): eval() runs arbitrary Python. The layer/attention
        # specs must only ever come from a trusted command line or checkpoint.
        encoder = SpeechFConvEncoder(
            conv_layers_before=conv_layers,
            input_size=fconv_encoder_input_size,
            embed_dim=args.encoder_embed_dim,
            convolutions=eval(args.encoder_layers),
            dropout=args.dropout,
        )
        decoder = SpeechFConvDecoder(
            dictionary=task.target_dictionary,
            embed_dim=args.decoder_embed_dim,
            embed_dict=decoder_embed_dict,
            convolutions=eval(args.decoder_layers),
            out_embed_dim=args.decoder_out_embed_dim,
            attention=eval(args.decoder_attention),
            dropout=args.dropout,
            max_positions=args.max_target_positions,
            share_embed=args.share_input_output_embed,
            positional_embeddings=args.decoder_positional_embed,
        )
        return cls(encoder, decoder)
Beispiel #2
0
    def build_model(cls, args, task):
        """Build a new model instance.

        Parses the string-encoded per-layer settings in ``args``, builds a
        ``SpeechTdnnEncoder`` sized by ``task.feat_dim`` / ``task.num_targets``
        and returns ``cls`` wrapping it, together with the task's
        ``initial_state_prior`` when the task defines one.
        """
        # make sure that all args are properly defaulted (in case there are any new ones)
        base_architecture(args)

        parse_spec = speech_utils.eval_str_nested_list_or_tuple
        layer_hidden_sizes = parse_spec(args.hidden_sizes, type=int)
        layer_kernel_sizes = parse_spec(args.kernel_sizes, type=int)
        layer_strides = parse_spec(args.strides, type=int)
        layer_dilations = parse_spec(args.dilations, type=int)

        message = "input feature dimension: {}, output dimension: {}".format(
            task.feat_dim, task.num_targets
        )
        logger.info(message)

        # Chunk-related settings are optional on the task, hence getattr.
        encoder_kwargs = dict(
            input_size=task.feat_dim,
            output_size=task.num_targets,
            hidden_sizes=layer_hidden_sizes,
            kernel_sizes=layer_kernel_sizes,
            strides=layer_strides,
            dilations=layer_dilations,
            num_layers=args.num_layers,
            dropout_in=args.dropout_in,
            dropout_out=args.dropout_out,
            residual=args.residual,
            chunk_width=getattr(task, "chunk_width", None),
            chunk_left_context=getattr(task, "chunk_left_context", 0),
            training_stage=getattr(task, "training_stage", True),
        )
        encoder = SpeechTdnnEncoder(**encoder_kwargs)

        prior = getattr(task, "initial_state_prior", None)
        return cls(encoder, state_prior=prior)
Beispiel #3
0
    def build_model(cls, args, task):
        """Build a new model instance.

        Assembles an optional ``ConvBNReLU`` front-end followed by a
        ``SpeechChunkLSTMEncoder`` and returns ``cls`` wrapping that encoder,
        together with the task's ``initial_state_prior`` when it has one.
        """
        # make sure that all args are properly defaulted (in case there are any new ones)
        base_architecture(args)
        max_src_len = getattr(
            args, "max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS
        )

        parse_spec = speech_utils.eval_str_nested_list_or_tuple
        conv_channels = parse_spec(args.encoder_conv_channels, type=int)
        conv_kernel_sizes = parse_spec(args.encoder_conv_kernel_sizes, type=int)
        conv_strides = parse_spec(args.encoder_conv_strides, type=int)

        logger.info(
            "input feature dimension: {}, channels: {}".format(
                task.feat_dim, task.feat_in_channels
            )
        )
        assert task.feat_dim % task.feat_in_channels == 0

        # The conv front-end exists only when channels were specified.
        if conv_channels is None:
            front_end = None
        else:
            front_end = ConvBNReLU(
                conv_channels,
                conv_kernel_sizes,
                conv_strides,
                in_channels=task.feat_in_channels,
            )

        per_channel_dim = task.feat_dim // task.feat_in_channels
        if front_end is None:
            # no front-end: the LSTM sees the full feature dimension
            encoder_input_dim = task.feat_dim
        else:
            encoder_input_dim = per_channel_dim
            for layer_stride in conv_strides:
                if not isinstance(layer_stride, (list, tuple)):
                    assert isinstance(layer_stride, int)
                    step = layer_stride
                else:
                    assert len(layer_stride) > 0
                    # second element of a pair, or the only element
                    step = (
                        layer_stride[0]
                        if len(layer_stride) <= 1
                        else layer_stride[1]
                    )
                # ceiling division by this layer's stride
                encoder_input_dim = (encoder_input_dim + step - 1) // step
            encoder_input_dim *= conv_channels[-1]

        encoder_kwargs = dict(
            conv_layers_before=front_end,
            input_size=encoder_input_dim,
            hidden_size=args.encoder_rnn_hidden_size,
            num_layers=args.encoder_rnn_layers,
            dropout_in=args.encoder_rnn_dropout_in,
            dropout_out=args.encoder_rnn_dropout_out,
            bidirectional=args.encoder_rnn_bidirectional,
            residual=args.encoder_rnn_residual,
            src_bucketed=(getattr(task.args, "num_batch_buckets", 0) > 0),
            # targets for encoder-only model
            num_targets=getattr(task, "num_targets", None),
            chunk_width=getattr(task, "chunk_width", None),
            chunk_left_context=getattr(task, "chunk_left_context", 0),
            training_stage=getattr(task, "training_stage", True),
            max_source_positions=max_src_len,
        )
        encoder = SpeechChunkLSTMEncoder(**encoder_kwargs)

        prior = getattr(task, "initial_state_prior", None)
        return cls(encoder, state_prior=prior)
Beispiel #4
0
    def build_model(cls, args, task):
        """Build a new model instance.

        Args:
            args: model arguments (embedding, conv front-end, RNN, attention,
                scheduled-sampling and pretrained-LM options are read).
            task: task object; ``feat_dim``, ``feat_in_channels``,
                ``target_dictionary`` and ``cfg`` are read from it.

        Returns:
            ``cls(encoder, decoder, pretrained_lm)`` where ``pretrained_lm``
            is ``None`` unless ``args.pretrained_lm_checkpoint`` is set.

        Raises:
            ValueError: if input/output embedding sharing is requested with
                mismatching dimensions, or if freezing the decoder embedding
                is requested without a pretrained embedding file.
            AssertionError: if ``task.feat_dim`` is not a multiple of
                ``task.feat_in_channels``, or a stride entry is malformed.
        """

        max_source_positions = getattr(
            args, "max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS
        )
        max_target_positions = getattr(
            args, "max_target_positions", DEFAULT_MAX_TARGET_POSITIONS
        )

        def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
            """Create an embedding for ``dictionary`` and load ``embed_path`` into it."""
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
            embed_dict = utils.parse_embedding(embed_path)
            utils.print_embed_overlap(embed_dict, dictionary)
            return utils.load_embedding(embed_dict, dictionary, embed_tokens)

        # separate decoder input embeddings
        pretrained_decoder_embed = None
        if args.decoder_embed_path:
            pretrained_decoder_embed = load_pretrained_embedding_from_file(
                args.decoder_embed_path,
                task.target_dictionary,
                args.decoder_embed_dim,
            )
        # one last double check of parameter combinations
        if args.share_decoder_input_output_embed and (
            args.decoder_embed_dim != args.decoder_out_embed_dim
        ):
            raise ValueError(
                "--share-decoder-input-output-embed requires "
                "--decoder-embed-dim to match --decoder-out-embed-dim"
            )

        if args.decoder_freeze_embed:
            # Only an embedding loaded above can be frozen here. Without one
            # the attribute access below used to fail on None with an opaque
            # AttributeError, so report the bad option combination instead.
            if pretrained_decoder_embed is None:
                raise ValueError(
                    "--decoder-freeze-embed requires --decoder-embed-path"
                )
            pretrained_decoder_embed.weight.requires_grad = False

        # The conv front-end is described by string-encoded (nested) lists.
        out_channels = speech_utils.eval_str_nested_list_or_tuple(
            args.encoder_conv_channels, type=int
        )
        kernel_sizes = speech_utils.eval_str_nested_list_or_tuple(
            args.encoder_conv_kernel_sizes, type=int
        )
        strides = speech_utils.eval_str_nested_list_or_tuple(
            args.encoder_conv_strides, type=int
        )
        logger.info(
            "input feature dimension: {}, channels: {}".format(
                task.feat_dim, task.feat_in_channels
            )
        )
        assert task.feat_dim % task.feat_in_channels == 0
        conv_layers = (
            ConvBNReLU(
                out_channels,
                kernel_sizes,
                strides,
                in_channels=task.feat_in_channels,
            )
            if out_channels is not None
            else None
        )

        rnn_encoder_input_size = task.feat_dim // task.feat_in_channels
        if conv_layers is not None:
            for stride in strides:
                if isinstance(stride, (list, tuple)):
                    assert len(stride) > 0
                    # For a pair the second element is used (presumably the
                    # feature-axis stride -- confirm against ConvBNReLU).
                    s = stride[1] if len(stride) > 1 else stride[0]
                else:
                    assert isinstance(stride, int)
                    s = stride
                # ceiling division: size after one layer with stride s
                rnn_encoder_input_size = (rnn_encoder_input_size + s - 1) // s
            rnn_encoder_input_size *= out_channels[-1]
        else:
            # no conv front-end: the RNN consumes the full feature dimension
            rnn_encoder_input_size = task.feat_dim

        # The two options are treated as mutually exclusive: the residual
        # flag is dropped (and logged) when the single-module RNN is requested.
        if args.encoder_multilayer_rnn_as_single_module and args.encoder_rnn_residual:
            args.encoder_rnn_residual = False
            logger.info(
                "--encoder-rnn-residual is set to False when --encoder-multilayer-rnn-as-single-module=True"
            )

        scheduled_sampling_rate_scheduler = ScheduledSamplingRateScheduler(
            args.scheduled_sampling_probs,
            args.start_scheduled_sampling_epoch,
        )

        encoder = SpeechLSTMEncoder(
            pre_encoder=conv_layers,
            input_size=rnn_encoder_input_size,
            hidden_size=args.encoder_rnn_hidden_size,
            num_layers=args.encoder_rnn_layers,
            dropout_in=args.encoder_rnn_dropout_in,
            dropout_out=args.encoder_rnn_dropout_out,
            bidirectional=args.encoder_rnn_bidirectional,
            residual=args.encoder_rnn_residual,
            src_bucketed=(getattr(task.cfg, "num_batch_buckets", 0) > 0),
            max_source_positions=max_source_positions,
            multilayer_rnn_as_single_module=args.encoder_multilayer_rnn_as_single_module,
        )
        decoder = SpeechLSTMDecoder(
            dictionary=task.target_dictionary,
            embed_dim=args.decoder_embed_dim,
            hidden_size=args.decoder_hidden_size,
            out_embed_dim=args.decoder_out_embed_dim,
            num_layers=args.decoder_layers,
            dropout_in=args.decoder_dropout_in,
            dropout_out=args.decoder_dropout_out,
            encoder_output_units=encoder.output_units,
            attn_type=args.attention_type,
            attn_dim=args.attention_dim,
            need_attn=args.need_attention,
            residual=args.decoder_rnn_residual,
            pretrained_embed=pretrained_decoder_embed,
            share_input_output_embed=args.share_decoder_input_output_embed,
            # cutoffs are only parsed when the adaptive loss criterion is used
            adaptive_softmax_cutoff=(
                utils.eval_str_list(args.adaptive_softmax_cutoff, type=int)
                if args.criterion_name == "adaptive_loss"
                else None
            ),
            max_target_positions=max_target_positions,
            scheduled_sampling_rate_scheduler=scheduled_sampling_rate_scheduler,
        )
        pretrained_lm = None
        if args.pretrained_lm_checkpoint:
            logger.info(
                "loading pretrained LM from {}".format(args.pretrained_lm_checkpoint)
            )
            # load_model_ensemble's result is indexed [0][0] to pick the
            # first model it loaded
            pretrained_lm = checkpoint_utils.load_model_ensemble(
                args.pretrained_lm_checkpoint, task=task
            )[0][0]
            pretrained_lm.make_generation_fast_()
            # freeze pretrained model
            for param in pretrained_lm.parameters():
                param.requires_grad = False
        return cls(encoder, decoder, pretrained_lm)
Beispiel #5
0
    def build_model(cls, args, task):
        """Build a new model instance.

        Normalizes a few arguments in place, assembles an optional
        ``ConvBNReLU`` front-end, builds the encoder through
        ``cls.build_encoder`` (then passes it through ``fsdp_wrap``) and
        returns ``cls`` wrapping it, together with the task's
        ``initial_state_prior`` when it has one.
        """

        # make sure that all args are properly defaulted (in case there are any new ones)
        base_architecture(args)

        if args.encoder_layers_to_keep:
            kept_layers = args.encoder_layers_to_keep.split(",")
            args.encoder_layers = len(kept_layers)

        if getattr(args, "max_source_positions", None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS

        # offloading implies checkpointing
        if getattr(args, "offload_activations", False):
            args.checkpoint_activations = True

        parse_spec = speech_utils.eval_str_nested_list_or_tuple
        conv_channels = parse_spec(args.encoder_conv_channels, type=int)
        conv_kernel_sizes = parse_spec(args.encoder_conv_kernel_sizes, type=int)
        conv_strides = parse_spec(args.encoder_conv_strides, type=int)

        logger.info(
            "input feature dimension: {}, channels: {}".format(
                task.feat_dim, task.feat_in_channels
            )
        )
        assert task.feat_dim % task.feat_in_channels == 0

        # The conv front-end exists only when channels were specified.
        if conv_channels is None:
            front_end = None
        else:
            front_end = ConvBNReLU(
                conv_channels,
                conv_kernel_sizes,
                conv_strides,
                in_channels=task.feat_in_channels,
            )

        per_channel_dim = task.feat_dim // task.feat_in_channels
        if front_end is None:
            # no front-end: the encoder sees the full feature dimension
            encoder_input_dim = task.feat_dim
        else:
            encoder_input_dim = per_channel_dim
            for layer_stride in conv_strides:
                if not isinstance(layer_stride, (list, tuple)):
                    assert isinstance(layer_stride, int)
                    step = layer_stride
                else:
                    assert len(layer_stride) > 0
                    # second element of a pair, or the only element
                    step = (
                        layer_stride[0]
                        if len(layer_stride) <= 1
                        else layer_stride[1]
                    )
                # ceiling division by this layer's stride
                encoder_input_dim = (encoder_input_dim + step - 1) // step
            encoder_input_dim *= conv_channels[-1]

        context = parse_spec(args.encoder_transformer_context, type=int)
        if context is not None:
            # exactly two entries, each either None or a non-negative int
            assert len(context) == 2
            for idx in range(2):
                bound = context[idx]
                assert bound is None or (isinstance(bound, int) and bound >= 0)

        encoder = cls.build_encoder(
            args,
            pre_encoder=front_end,
            input_size=encoder_input_dim,
            transformer_context=context,
            # targets for encoder-only model
            num_targets=getattr(task, "num_targets", None),
            chunk_width=getattr(task, "chunk_width", None),
            chunk_left_context=getattr(task, "chunk_left_context", 0),
            training_stage=getattr(task, "training_stage", True),
        )
        # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
        encoder = fsdp_wrap(encoder, min_num_params=1e8)

        prior = getattr(task, "initial_state_prior", None)
        return cls(args, encoder, state_prior=prior)
    def build_model(cls, args, task):
        """Build a new model instance.

        Normalizes a few arguments in place, builds the decoder token
        embedding, an optional ``ConvBNReLU`` front-end and a scheduled
        sampling rate scheduler, then creates the encoder and decoder through
        ``cls.build_encoder`` / ``cls.build_decoder`` and returns
        ``cls(args, encoder, decoder)``.
        """

        # make sure all arguments are present in older models
        base_architecture(args)

        # When an explicit list of layers to keep is given, the layer count
        # is derived from it.
        if args.encoder_layers_to_keep:
            kept_encoder_layers = args.encoder_layers_to_keep.split(",")
            args.encoder_layers = len(kept_encoder_layers)
        if args.decoder_layers_to_keep:
            kept_decoder_layers = args.decoder_layers_to_keep.split(",")
            args.decoder_layers = len(kept_decoder_layers)

        if getattr(args, "max_source_positions", None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        target_vocab = task.target_dictionary

        embed_tokens = cls.build_embedding(
            args, target_vocab, args.decoder_input_dim, args.decoder_embed_path
        )

        parse_spec = speech_utils.eval_str_nested_list_or_tuple
        conv_channels = parse_spec(args.encoder_conv_channels, type=int)
        conv_kernel_sizes = parse_spec(args.encoder_conv_kernel_sizes, type=int)
        conv_strides = parse_spec(args.encoder_conv_strides, type=int)

        logger.info(
            "input feature dimension: {}, channels: {}".format(
                task.feat_dim, task.feat_in_channels
            )
        )
        assert task.feat_dim % task.feat_in_channels == 0

        # The conv front-end exists only when channels were specified.
        if conv_channels is None:
            front_end = None
        else:
            front_end = ConvBNReLU(
                conv_channels,
                conv_kernel_sizes,
                conv_strides,
                in_channels=task.feat_in_channels,
            )

        per_channel_dim = task.feat_dim // task.feat_in_channels
        if front_end is None:
            # no front-end: the encoder sees the full feature dimension
            encoder_input_dim = task.feat_dim
        else:
            encoder_input_dim = per_channel_dim
            for layer_stride in conv_strides:
                if not isinstance(layer_stride, (list, tuple)):
                    assert isinstance(layer_stride, int)
                    step = layer_stride
                else:
                    assert len(layer_stride) > 0
                    # second element of a pair, or the only element
                    step = (
                        layer_stride[0]
                        if len(layer_stride) <= 1
                        else layer_stride[1]
                    )
                # ceiling division by this layer's stride
                encoder_input_dim = (encoder_input_dim + step - 1) // step
            encoder_input_dim *= conv_channels[-1]

        context = parse_spec(args.encoder_transformer_context, type=int)
        if context is not None:
            # exactly two entries, each either None or a non-negative int
            assert len(context) == 2
            for idx in range(2):
                bound = context[idx]
                assert bound is None or (isinstance(bound, int) and bound >= 0)

        sampling_scheduler = ScheduledSamplingRateScheduler(
            args.scheduled_sampling_probs,
            args.start_scheduled_sampling_epoch,
        )

        encoder = cls.build_encoder(
            args,
            conv_layers_before=front_end,
            input_size=encoder_input_dim,
            transformer_context=context,
        )
        decoder = cls.build_decoder(
            args,
            target_vocab,
            embed_tokens,
            scheduled_sampling_rate_scheduler=sampling_scheduler,
        )
        return cls(args, encoder, decoder)
Beispiel #7
0
    def build_model(cls, args, task):
        """Build a new model instance.

        Args:
            args: model arguments. ``base_architecture`` is applied to them
                first, and missing ``max_source_positions`` /
                ``max_target_positions`` attributes are filled in.
            task: task object; ``feat_dim``, ``feat_in_channels`` and
                ``target_dictionary`` are read from it.

        Returns:
            ``cls(encoder, decoder)`` built through ``cls.build_encoder`` and
            ``cls.build_decoder``.

        Raises:
            AssertionError: if ``task.feat_dim`` is not a multiple of
                ``task.feat_in_channels``, or a stride entry is malformed.
        """

        # make sure all arguments are present in older models
        base_architecture(args)

        if not hasattr(args, 'max_source_positions'):
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if not hasattr(args, 'max_target_positions'):
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        # named tgt_dict (not `dict`) so the builtin is not shadowed
        tgt_dict = task.target_dictionary

        def build_embedding(dictionary, embed_dim, path=None):
            """Create an embedding for ``dictionary``, optionally loaded from ``path``."""
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        decoder_embed_tokens = build_embedding(tgt_dict, args.decoder_embed_dim,
                                               args.decoder_embed_path)

        # The conv front-end is described by string-encoded (nested) lists.
        out_channels = speech_utils.eval_str_nested_list_or_tuple(
            args.encoder_conv_channels, type=int)
        kernel_sizes = speech_utils.eval_str_nested_list_or_tuple(
            args.encoder_conv_kernel_sizes, type=int)
        strides = speech_utils.eval_str_nested_list_or_tuple(
            args.encoder_conv_strides, type=int)
        logger.info('input feature dimension: {}, channels: {}'.format(
            task.feat_dim, task.feat_in_channels))
        assert task.feat_dim % task.feat_in_channels == 0
        conv_layers = ConvBNReLU(
            out_channels,
            kernel_sizes,
            strides,
            in_channels=task.feat_in_channels,
        ) if out_channels is not None else None

        transformer_encoder_input_size = task.feat_dim // task.feat_in_channels
        if conv_layers is not None:
            for stride in strides:
                if isinstance(stride, (list, tuple)):
                    assert len(stride) > 0
                    # For a pair the second element is used (presumably the
                    # feature-axis stride -- confirm against ConvBNReLU).
                    s = stride[1] if len(stride) > 1 else stride[0]
                else:
                    assert isinstance(stride, int)
                    s = stride
                # ceiling division: size after one layer with stride s
                transformer_encoder_input_size = \
                    (transformer_encoder_input_size + s - 1) // s
            transformer_encoder_input_size *= out_channels[-1]
        else:
            # Without a conv front-end nothing splits the features into
            # channels, so the encoder consumes the full feature dimension
            # (same fallback as the other speech models' build_model).
            transformer_encoder_input_size = task.feat_dim

        encoder = cls.build_encoder(
            args,
            conv_layers_before=conv_layers,
            input_size=transformer_encoder_input_size,
        )
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
        return cls(encoder, decoder)
Beispiel #8
0
    def build_model(cls, cfg, task):
        """Build a new model instance.

        Normalizes a few ``cfg`` fields in place, builds the decoder token
        embedding and an optional ``ConvBNReLU`` front-end, creates the
        encoder and decoder through ``cls.build_encoder`` /
        ``cls.build_decoder``, passes both through ``fsdp_wrap`` and returns
        ``cls(cfg, encoder, decoder)``.
        """

        if cfg.encoder.layers_to_keep:
            kept_layers = cfg.encoder.layers_to_keep.split(",")
            cfg.encoder.layers = len(kept_layers)

        if cfg.max_source_positions is None:
            cfg.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if cfg.max_target_positions is None:
            cfg.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        target_vocab = task.target_dictionary

        embed_tokens = cls.build_embedding(
            cfg, target_vocab, cfg.decoder.embed_dim, cfg.decoder.embed_path
        )
        # offloading implies checkpointing
        if cfg.offload_activations:
            cfg.checkpoint_activations = True

        parse_spec = speech_utils.eval_str_nested_list_or_tuple
        conv_channels = parse_spec(cfg.encoder.conv_channels, type=int)
        conv_kernel_sizes = parse_spec(cfg.encoder.conv_kernel_sizes, type=int)
        conv_strides = parse_spec(cfg.encoder.conv_strides, type=int)

        logger.info(
            "input feature dimension: {}, channels: {}".format(
                task.feat_dim, task.feat_in_channels
            )
        )
        assert task.feat_dim % task.feat_in_channels == 0

        # The conv front-end exists only when channels were specified.
        if conv_channels is None:
            front_end = None
        else:
            front_end = ConvBNReLU(
                conv_channels,
                conv_kernel_sizes,
                conv_strides,
                in_channels=task.feat_in_channels,
            )

        per_channel_dim = task.feat_dim // task.feat_in_channels
        if front_end is None:
            # no front-end: the encoder sees the full feature dimension
            encoder_input_dim = task.feat_dim
        else:
            encoder_input_dim = per_channel_dim
            for layer_stride in conv_strides:
                if not isinstance(layer_stride, (list, tuple)):
                    assert isinstance(layer_stride, int)
                    step = layer_stride
                else:
                    assert len(layer_stride) > 0
                    # second element of a pair, or the only element
                    step = (
                        layer_stride[0]
                        if len(layer_stride) <= 1
                        else layer_stride[1]
                    )
                # ceiling division by this layer's stride
                encoder_input_dim = (encoder_input_dim + step - 1) // step
            encoder_input_dim *= conv_channels[-1]

        context = parse_spec(cfg.encoder.transformer_context, type=int)
        if context is not None:
            # exactly two entries, each either None or a non-negative int
            assert len(context) == 2
            for idx in range(2):
                bound = context[idx]
                assert bound is None or (isinstance(bound, int) and bound >= 0)

        encoder = cls.build_encoder(
            cfg,
            pre_encoder=front_end,
            input_size=encoder_input_dim,
            transformer_context=context,
        )
        decoder = cls.build_decoder(cfg, target_vocab, embed_tokens)
        # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
        encoder = fsdp_wrap(encoder, min_num_params=cfg.min_params_to_wrap)
        decoder = fsdp_wrap(decoder, min_num_params=cfg.min_params_to_wrap)
        return cls(cfg, encoder, decoder)
Beispiel #9
0
    def build_model(cls, cfg, task):
        """Build a new model instance.

        Normalizes a few ``cfg`` fields in place, builds the decoder token
        embedding, an optional ``ConvBNReLU`` front-end and a scheduled
        sampling rate scheduler, then creates the encoder and decoder through
        ``cls.build_encoder`` / ``cls.build_decoder`` and returns
        ``cls(cfg, encoder, decoder)``.
        """

        # -- TODO T96535332
        # bug caused by interaction between OmegaConf II and argparsing
        decoder_cfg = cfg.decoder
        decoder_cfg.input_dim = int(decoder_cfg.input_dim)
        decoder_cfg.output_dim = int(decoder_cfg.output_dim)
        # --

        # When an explicit list of layers to keep is given, the layer count
        # is derived from it.
        if cfg.encoder.layers_to_keep:
            kept_encoder_layers = cfg.encoder.layers_to_keep.split(",")
            cfg.encoder.layers = len(kept_encoder_layers)
        if cfg.decoder.layers_to_keep:
            kept_decoder_layers = cfg.decoder.layers_to_keep.split(",")
            cfg.decoder.layers = len(kept_decoder_layers)

        if cfg.max_source_positions is None:
            cfg.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if cfg.max_target_positions is None:
            cfg.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        target_vocab = task.target_dictionary

        embed_tokens = cls.build_embedding(
            cfg, target_vocab, cfg.decoder.input_dim, cfg.decoder.embed_path
        )
        # offloading implies checkpointing
        if cfg.offload_activations:
            cfg.checkpoint_activations = True

        parse_spec = speech_utils.eval_str_nested_list_or_tuple
        conv_channels = parse_spec(cfg.encoder.conv_channels, type=int)
        conv_kernel_sizes = parse_spec(cfg.encoder.conv_kernel_sizes, type=int)
        conv_strides = parse_spec(cfg.encoder.conv_strides, type=int)

        logger.info(
            "input feature dimension: {}, channels: {}".format(
                task.feat_dim, task.feat_in_channels
            )
        )
        assert task.feat_dim % task.feat_in_channels == 0

        # The conv front-end exists only when channels were specified.
        if conv_channels is None:
            front_end = None
        else:
            front_end = ConvBNReLU(
                conv_channels,
                conv_kernel_sizes,
                conv_strides,
                in_channels=task.feat_in_channels,
            )

        per_channel_dim = task.feat_dim // task.feat_in_channels
        if front_end is None:
            # no front-end: the encoder sees the full feature dimension
            encoder_input_dim = task.feat_dim
        else:
            encoder_input_dim = per_channel_dim
            for layer_stride in conv_strides:
                if not isinstance(layer_stride, (list, tuple)):
                    assert isinstance(layer_stride, int)
                    step = layer_stride
                else:
                    assert len(layer_stride) > 0
                    # second element of a pair, or the only element
                    step = (
                        layer_stride[0]
                        if len(layer_stride) <= 1
                        else layer_stride[1]
                    )
                # ceiling division by this layer's stride
                encoder_input_dim = (encoder_input_dim + step - 1) // step
            encoder_input_dim *= conv_channels[-1]

        context = parse_spec(cfg.encoder.transformer_context, type=int)
        if context is not None:
            # exactly two entries, each either None or a non-negative int
            assert len(context) == 2
            for idx in range(2):
                bound = context[idx]
                assert bound is None or (isinstance(bound, int) and bound >= 0)

        sampling_scheduler = ScheduledSamplingRateScheduler(
            cfg.scheduled_sampling_probs,
            cfg.start_scheduled_sampling_epoch,
        )

        encoder = cls.build_encoder(
            cfg,
            pre_encoder=front_end,
            input_size=encoder_input_dim,
            transformer_context=context,
        )
        decoder = cls.build_decoder(
            cfg,
            target_vocab,
            embed_tokens,
            scheduled_sampling_rate_scheduler=sampling_scheduler,
        )
        return cls(cfg, encoder, decoder)