def build_model(cls, args, task):
    """Build a new model instance.

    Args:
        args: argparse-style namespace of model hyperparameters; mutated
            in place to default missing fields for older checkpoints.
        task: fairseq task providing ``source_dictionary`` and
            ``target_dictionary``.

    Returns:
        An instance of ``cls`` wrapping the built encoder and decoder.

    Raises:
        ValueError: if ``--share-all-embeddings`` is requested with an
            incompatible dictionary, embed dim, or decoder embed path.
    """
    # make sure all arguments are present in older models
    base_architecture(args)

    # Pruning layers overrides the configured layer counts.
    if args.encoder_layers_to_keep:
        args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
    if args.decoder_layers_to_keep:
        args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

    if getattr(args, "max_source_positions", None) is None:
        args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
    if getattr(args, "max_target_positions", None) is None:
        args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

    src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

    if args.share_all_embeddings:
        # Sharing requires a joined dictionary and matching embed dims.
        if src_dict != tgt_dict:
            raise ValueError("--share-all-embeddings requires a joined dictionary")
        if args.encoder_embed_dim != args.decoder_embed_dim:
            raise ValueError(
                "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
            )
        if args.decoder_embed_path and (
            args.decoder_embed_path != args.encoder_embed_path
        ):
            raise ValueError(
                "--share-all-embeddings not compatible with --decoder-embed-path"
            )
        encoder_embed_tokens = cls.build_embedding(
            args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
        )
        # Decoder reuses the encoder embedding table.
        decoder_embed_tokens = encoder_embed_tokens
        args.share_decoder_input_output_embed = True
    else:
        encoder_embed_tokens = cls.build_embedding(
            args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
        )
        decoder_embed_tokens = cls.build_embedding(
            args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
        )
    if getattr(args, "offload_activations", False):
        args.checkpoint_activations = True  # offloading implies checkpointing

    # build domain embeddings from pretrained graph embeddings
    # NOTE(review): 49 appears to be the tag vocabulary size and 0 the
    # padding index — confirm against the tag set used by the encoder.
    tag_embedding = Embedding(49, args.encoder_embed_dim, 0)

    encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens, tag_embedding)
    decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
    if not args.share_all_embeddings:
        min_params_to_wrap = getattr(
            args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
        )
        # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
        encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap)
        decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap)
    return cls(args, encoder, decoder)
def build_model(cls, cfg, task): """Build a new model instance.""" # -- TODO T96535332 # bug caused by interaction between OmegaConf II and argparsing cfg.decoder.input_dim = int(cfg.decoder.input_dim) cfg.decoder.output_dim = int(cfg.decoder.output_dim) # -- if cfg.encoder.layers_to_keep: cfg.encoder.layers = len(cfg.encoder.layers_to_keep.split(",")) if cfg.decoder.layers_to_keep: cfg.decoder.layers = len(cfg.decoder.layers_to_keep.split(",")) src_dict, tgt_dict = task.source_dictionary, task.target_dictionary if cfg.share_all_embeddings: if src_dict != tgt_dict: raise ValueError( "--share-all-embeddings requires a joined dictionary") if cfg.encoder.embed_dim != cfg.decoder.embed_dim: raise ValueError( "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" ) if cfg.decoder.embed_path and (cfg.decoder.embed_path != cfg.encoder.embed_path): raise ValueError( "--share-all-embeddings not compatible with --decoder-embed-path" ) encoder_embed_tokens = cls.build_embedding(cfg, src_dict, cfg.encoder.embed_dim, cfg.encoder.embed_path) decoder_embed_tokens = encoder_embed_tokens cfg.share_decoder_input_output_embed = True else: encoder_embed_tokens = cls.build_embedding(cfg, src_dict, cfg.encoder.embed_dim, cfg.encoder.embed_path) decoder_embed_tokens = cls.build_embedding(cfg, tgt_dict, cfg.decoder.embed_dim, cfg.decoder.embed_path) if cfg.offload_activations: cfg.checkpoint_activations = True # offloading implies checkpointing encoder = cls.build_encoder(cfg, src_dict, encoder_embed_tokens) decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens) if not cfg.share_all_embeddings: # fsdp_wrap is a no-op when --ddp-backend != fully_sharded encoder = fsdp_wrap(encoder, min_num_params=cfg.min_params_to_wrap) decoder = fsdp_wrap(decoder, min_num_params=cfg.min_params_to_wrap) return cls(cfg, encoder, decoder)
def build_decoder_layer(self, cfg, no_encoder_attn=False):
    """Construct one Primer decoder layer, sharing a single compression
    projection across layers when ``shared_layer_qkv_conv`` is enabled.

    The layer is optionally wrapped for activation checkpointing and
    always passed through ``fsdp_wrap``.
    """
    # Lazily build the projection shared by all layers; created at most
    # once (guarded by ``self.compress_layer is None``).
    if self.cfg.shared_layer_qkv_conv == 1 and self.compress_layer is None:
        proj = nn.Linear(
            self.cfg.max_positions,
            cfg.compressed_dim,
        )
        nn.init.xavier_uniform_(proj.weight, gain=1 / math.sqrt(2))
        if self.cfg.freeze_conv == 1:
            # Keep the projection fixed during training.
            proj.weight.requires_grad = False
        self.compress_layer = proj

    block = primer_layer.PrimerDecoderLayerBase(cfg, self.compress_layer, no_encoder_attn)

    use_ckpt = cfg.checkpoint_activations
    if use_ckpt:
        block = checkpoint_wrapper(block, offload_to_cpu=cfg.offload_activations)
    # A checkpointed layer must always be FSDP-wrapped, regardless of its
    # parameter count — hence the threshold of 0 in that case.
    return fsdp_wrap(
        block, min_num_params=0 if use_ckpt else cfg.min_params_to_wrap
    )
def build_encoder_layer(self, args: Wav2Vec2Config):
    """Construct one encoder layer of the configured type.

    Args:
        args: wav2vec 2.0 config; ``args.layer_type`` selects between a
            plain transformer layer and a conformer layer.

    Returns:
        The (possibly FSDP-wrapped and checkpoint-wrapped) layer module.

    Raises:
        NotImplementedError: if ``args.layer_type`` is not one of
            "transformer" or "conformer". (Previously an unknown type
            left ``layer`` unbound and crashed with a NameError.)
    """
    if args.layer_type == "transformer":
        layer = TransformerSentenceEncoderLayer(
            embedding_dim=self.embedding_dim,
            ffn_embedding_dim=args.encoder_ffn_embed_dim,
            num_attention_heads=args.encoder_attention_heads,
            dropout=self.dropout,
            attention_dropout=args.attention_dropout,
            activation_dropout=args.activation_dropout,
            activation_fn=args.activation_fn,
            layer_norm_first=args.layer_norm_first,
        )
    elif args.layer_type == "conformer":
        layer = ConformerWav2Vec2EncoderLayer(
            embed_dim=self.embedding_dim,
            ffn_embed_dim=args.encoder_ffn_embed_dim,
            attention_heads=args.encoder_attention_heads,
            dropout=args.dropout,
            depthwise_conv_kernel_size=args.depthwise_conv_kernel_size,
            activation_fn="swish",
            attn_type=args.attn_type,
            use_fp16=args.fp16,
            pos_enc_type="abs",
        )
    else:
        # Fail loudly instead of falling through with an unbound `layer`.
        raise NotImplementedError(
            "Unsupported layer_type: {}".format(args.layer_type)
        )
    layer = fsdp_wrap(layer)
    if args.checkpoint_activations:
        layer = checkpoint_wrapper(layer)
    return layer
def build_encoder_layer(self, args):
    """Create one transformer encoder layer and wrap it for FSDP.

    ``fsdp_wrap`` is a no-op unless the fully-sharded DDP backend is
    active; the wrap threshold falls back to the module default.
    """
    wrap_threshold = getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP)
    return fsdp_wrap(
        TransformerEncoderLayer(args), min_num_params=wrap_threshold
    )
def build_encoder_layer(self, cfg):
    """Create one transformer encoder layer, optionally wrapped for
    activation checkpointing, then wrapped for FSDP."""
    block = transformer_layer.TransformerEncoderLayerBase(cfg)
    use_ckpt = cfg.checkpoint_activations
    if use_ckpt:
        block = checkpoint_wrapper(block, offload_to_cpu=cfg.offload_activations)
    # A checkpointed layer must always be FSDP-wrapped regardless of its
    # parameter count, so the wrap threshold drops to 0.
    return fsdp_wrap(
        block, min_num_params=0 if use_ckpt else cfg.min_params_to_wrap
    )
def build_decoder_layer(self, cfg, no_encoder_attn=False):
    """Create one UniLM decoder layer, optionally wrapped for activation
    checkpointing, then wrapped for FSDP."""
    block = UniLMDecoderLayer(cfg, no_encoder_attn)
    use_ckpt = cfg.checkpoint_activations
    if use_ckpt:
        block = checkpoint_wrapper(block, offload_to_cpu=cfg.offload_activations)
    # A checkpointed layer must always be FSDP-wrapped regardless of its
    # parameter count, so the wrap threshold drops to 0.
    return fsdp_wrap(
        block, min_num_params=0 if use_ckpt else cfg.min_params_to_wrap
    )
def build_encoder_layer(self, args):
    """Create one transformer encoder layer, optionally wrapped for
    activation checkpointing, then wrapped for FSDP.

    Missing config fields default to "off" / the module default via
    ``getattr`` so older argument namespaces still work.
    """
    block = TransformerEncoderLayer(args)
    use_ckpt = getattr(args, "checkpoint_activations", False)
    if use_ckpt:
        block = checkpoint_wrapper(
            block, offload_to_cpu=getattr(args, "offload_activations", False)
        )
    # A checkpointed layer must always be FSDP-wrapped regardless of its
    # parameter count, so the wrap threshold drops to 0.
    if use_ckpt:
        wrap_threshold = 0
    else:
        wrap_threshold = getattr(
            args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
        )
    return fsdp_wrap(block, min_num_params=wrap_threshold)
def __init__(self, args):
    """Build the transformer encoder stack with a convolutional
    positional embedding front-end.

    Args:
        args: namespace with encoder hyperparameters (embed dim, layer
            count, dropout rates, conv positional-embedding settings).
    """
    super().__init__()

    self.dropout = args.dropout
    self.embedding_dim = args.encoder_embed_dim

    # Convolutional relative positional embedding.
    self.pos_conv = nn.Conv1d(
        self.embedding_dim,
        self.embedding_dim,
        kernel_size=args.conv_pos,
        padding=args.conv_pos // 2,
        groups=args.conv_pos_groups,
    )
    # Init scale follows the original formula with a dropout term of 0.
    pe_dropout = 0
    pe_std = math.sqrt((4 * (1.0 - pe_dropout)) / (args.conv_pos * self.embedding_dim))
    nn.init.normal_(self.pos_conv.weight, mean=0, std=pe_std)
    nn.init.constant_(self.pos_conv.bias, 0)

    self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
    self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())

    def _new_layer():
        # One self-attention block; optionally FSDP- and
        # checkpoint-wrapped when activation checkpointing is on.
        blk = TransformerSentenceEncoderLayer(
            embedding_dim=self.embedding_dim,
            ffn_embedding_dim=args.encoder_ffn_embed_dim,
            num_attention_heads=args.encoder_attention_heads,
            dropout=self.dropout,
            attention_dropout=args.attention_dropout,
            activation_dropout=args.activation_dropout,
            activation_fn=args.activation_fn,
            layer_norm_first=args.layer_norm_first,
        )
        if args.checkpoint_activations:
            blk = fsdp_wrap(blk)
            blk = checkpoint_wrapper(blk)
        return blk

    self.layers = nn.ModuleList(
        [_new_layer() for _ in range(args.encoder_layers)]
    )

    self.layer_norm_first = args.layer_norm_first
    self.layer_norm = LayerNorm(self.embedding_dim)
    self.layerdrop = args.encoder_layerdrop

    self.apply(init_bert_params)
def build_encoder_layer(self, args):
    """Create one Conformer encoder block, FSDP-wrapped and optionally
    checkpoint-wrapped."""
    block = ConformerWav2Vec2EncoderLayer(
        embed_dim=self.embedding_dim,
        ffn_embed_dim=args.encoder_ffn_embed_dim,
        attention_heads=args.encoder_attention_heads,
        dropout=args.dropout,
        depthwise_conv_kernel_size=args.depthwise_conv_kernel_size,
        activation_fn="swish",
        attn_type=args.attn_type,
        pos_enc_type=args.pos_enc_type,
        use_fp16=args.fp16,  # only used for rope
    )
    block = fsdp_wrap(block)
    if args.checkpoint_activations:
        block = checkpoint_wrapper(block)
    return block
def build_decoder_layer(
    self,
    cfg,
    no_encoder_attn=False,
    positional_embedding: Optional[RelativePositionalEmbedding] = None,
):
    """Create one decoder layer with relative positional embeddings,
    optionally wrapped for activation checkpointing, then for FSDP."""
    block = TransformerWithRelativePositionalEmbeddingDecoderLayerBase(
        cfg,
        no_encoder_attn=no_encoder_attn,
        positional_embedding=positional_embedding,
    )
    use_ckpt = cfg.checkpoint_activations
    if use_ckpt:
        block = checkpoint_wrapper(block, offload_to_cpu=cfg.offload_activations)
    # A checkpointed layer must always be FSDP-wrapped regardless of its
    # parameter count, so the wrap threshold drops to 0.
    return fsdp_wrap(
        block, min_num_params=0 if use_ckpt else cfg.min_params_to_wrap
    )
def build_encoder_layer(
    self, cfg, positional_embedding: Optional[RelativePositionalEmbedding] = None
):
    """Create one encoder layer of the configured type ("transformer"
    or "conformer"), optionally wrapped for activation checkpointing,
    then wrapped for FSDP.

    Raises:
        NotImplementedError: for an unrecognized ``cfg.encoder.layer_type``.
    """
    kind = cfg.encoder.layer_type
    if kind == "transformer":
        layer_factory = TransformerWithRelativePositionalEmbeddingEncoderLayerBase
    elif kind == "conformer":
        layer_factory = ConformerWithRelativePositionalEmbeddingEncoderLayerBase
    else:
        raise NotImplementedError

    block = layer_factory(
        cfg, return_fc=self.return_fc, positional_embedding=positional_embedding
    )
    use_ckpt = cfg.checkpoint_activations
    if use_ckpt:
        block = checkpoint_wrapper(block, offload_to_cpu=cfg.offload_activations)
    # A checkpointed layer must always be FSDP-wrapped regardless of its
    # parameter count, so the wrap threshold drops to 0.
    return fsdp_wrap(
        block, min_num_params=0 if use_ckpt else cfg.min_params_to_wrap
    )
def build_model(cls, args, task): """Build a new model instance.""" # make sure that all args are properly defaulted (in case there are any new ones) base_architecture(args) if args.encoder_layers_to_keep: args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) if getattr(args, "max_source_positions", None) is None: args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS if getattr(args, "offload_activations", False): args.checkpoint_activations = True # offloading implies checkpointing out_channels = speech_utils.eval_str_nested_list_or_tuple( args.encoder_conv_channels, type=int ) kernel_sizes = speech_utils.eval_str_nested_list_or_tuple( args.encoder_conv_kernel_sizes, type=int ) strides = speech_utils.eval_str_nested_list_or_tuple( args.encoder_conv_strides, type=int ) logger.info( "input feature dimension: {}, channels: {}".format( task.feat_dim, task.feat_in_channels ) ) assert task.feat_dim % task.feat_in_channels == 0 conv_layers = ( ConvBNReLU( out_channels, kernel_sizes, strides, in_channels=task.feat_in_channels, ) if out_channels is not None else None ) transformer_encoder_input_size = task.feat_dim // task.feat_in_channels if conv_layers is not None: for stride in strides: if isinstance(stride, (list, tuple)): assert len(stride) > 0 s = stride[1] if len(stride) > 1 else stride[0] else: assert isinstance(stride, int) s = stride transformer_encoder_input_size = ( transformer_encoder_input_size + s - 1 ) // s transformer_encoder_input_size *= out_channels[-1] else: transformer_encoder_input_size = task.feat_dim encoder_transformer_context = speech_utils.eval_str_nested_list_or_tuple( args.encoder_transformer_context, type=int, ) if encoder_transformer_context is not None: assert len(encoder_transformer_context) == 2 for i in range(2): assert encoder_transformer_context[i] is None or ( isinstance(encoder_transformer_context[i], int) and encoder_transformer_context[i] >= 0 ) encoder = cls.build_encoder( args, pre_encoder=conv_layers, 
input_size=transformer_encoder_input_size, transformer_context=encoder_transformer_context, num_targets=getattr( task, "num_targets", None ), # targets for encoder-only model chunk_width=getattr(task, "chunk_width", None), chunk_left_context=getattr(task, "chunk_left_context", 0), training_stage=getattr(task, "training_stage", True), ) # fsdp_wrap is a no-op when --ddp-backend != fully_sharded encoder = fsdp_wrap(encoder, min_num_params=1e8) return cls( args, encoder, state_prior=getattr(task, "initial_state_prior", None) )
def main(cfg: FairseqConfig) -> None:
    """Train a fairseq model per ``cfg``: build the task/model/criterion,
    restore the latest checkpoint, then run the epoch loop until the max
    epoch, a stop condition, or the minimum learning rate is reached."""
    if isinstance(cfg, argparse.Namespace):
        # Legacy entry points still pass an argparse namespace.
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    if distributed_utils.is_master(
            cfg.distributed_training) and "job_logging_cfg" in cfg:
        # make hydra logging work with ddp (see
        # see https://github.com/facebookresearch/hydra/issues/1126)
        logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"
    metrics.reset()

    np.random.seed(cfg.common.seed)
    utils.set_torch_seed(cfg.common.seed)

    if distributed_utils.is_master(cfg.distributed_training):
        checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)

    # Print args
    logger.info(cfg)

    if cfg.checkpoint.write_checkpoints_asynchronously:
        try:
            import iopath  # noqa: F401
        except ImportError:
            logging.exception(
                "Asynchronous checkpoint writing is specified but iopath is "
                "not installed: `pip install iopath`")
            return

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(cfg.task)

    assert cfg.criterion, "Please specify criterion to train a model"

    # Build model and criterion
    if cfg.distributed_training.ddp_backend == "fully_sharded":
        # Model must be constructed inside the FSDP wrap context so that
        # submodules get sharded as they are built.
        with fsdp_enable_wrap(cfg.distributed_training):
            model = fsdp_wrap(task.build_model(cfg.model))
    else:
        model = task.build_model(cfg.model)
    criterion = task.build_criterion(cfg.criterion)
    logger.info(model)
    logger.info("task: {}".format(task.__class__.__name__))
    logger.info("model: {}".format(model.__class__.__name__))
    logger.info("criterion: {}".format(criterion.__class__.__name__))
    # Parameters are split into "shared" vs "expert" (MoE) by the
    # ``expert`` attribute set on each parameter.
    logger.info("num. shared model params: {:,} (num. trained: {:,})".format(
        sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False)),
        sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False) and p.requires_grad)))
    logger.info("num. expert model params: {} (num. trained: {})".format(
        sum(p.numel() for p in model.parameters() if getattr(p, "expert", False)),
        sum(p.numel() for p in model.parameters() if getattr(p, "expert", False) and p.requires_grad),
    ))

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    # We load the valid dataset AFTER building the model
    for valid_sub_split in cfg.dataset.valid_subset.split(","):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # (optionally) Configure quantization
    if cfg.common.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=cfg.common.quantization_config_path,
            max_epoch=cfg.optimization.max_epoch,
            max_update=cfg.optimization.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if cfg.common.model_parallel_size == 1:
        trainer = Trainer(cfg, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(cfg, task, model, criterion)
    logger.info("training on {} devices (GPUs/TPUs)".format(
        cfg.distributed_training.distributed_world_size))
    logger.info(
        "max tokens per device = {} and max sentences per device = {}".format(
            cfg.dataset.max_tokens,
            cfg.dataset.batch_size,
        ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        cfg.checkpoint,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )
    if cfg.common.tpu:
        import torch_xla.core.xla_model as xm

        xm.rendezvous("load_checkpoint")  # wait for all workers

    max_epoch = cfg.optimization.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    while epoch_itr.next_epoch_idx <= max_epoch:
        if lr <= cfg.optimization.stop_min_lr:
            logger.info(
                f"stopping training because current learning rate ({lr}) is smaller "
                "than or equal to minimum learning rate "
                f"(--stop-min-lr={cfg.optimization.stop_min_lr})")
            break

        # train for one epoch
        valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))

    # ioPath implementation to wait for all asynchronous file writes to complete.
    if cfg.checkpoint.write_checkpoints_asynchronously:
        logger.info(
            "ioPath PathManager waiting for all asynchronous checkpoint "
            "writes to finish.")
        PathManager.async_close()
        logger.info("ioPath PathManager finished waiting.")
def build_model(cls, args, task):
    """Build a new model instance.

    Speech encoder-decoder model: an (optional) conv front-end plus a
    transformer encoder sized from the task's features, and a decoder
    with a scheduled-sampling rate scheduler.

    Args:
        args: argparse-style namespace of model hyperparameters; mutated
            in place to default missing fields for older checkpoints.
        task: fairseq task exposing ``target_dictionary``, ``feat_dim``
            and ``feat_in_channels``.

    Returns:
        An instance of ``cls`` wrapping the built encoder and decoder.
    """
    # make sure all arguments are present in older models
    base_architecture(args)

    if args.encoder_layers_to_keep:
        args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
    if args.decoder_layers_to_keep:
        args.decoder_layers = len(args.decoder_layers_to_keep.split(","))
    if getattr(args, "max_source_positions", None) is None:
        args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
    if getattr(args, "max_target_positions", None) is None:
        args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

    tgt_dict = task.target_dictionary

    decoder_embed_tokens = cls.build_embedding(args, tgt_dict,
                                               args.decoder_input_dim,
                                               args.decoder_embed_path)
    if getattr(args, "offload_activations", False):
        args.checkpoint_activations = True  # offloading implies checkpointing

    # Conv front-end hyperparameters are given as string-encoded nested
    # lists/tuples; parse them into ints.
    out_channels = speech_utils.eval_str_nested_list_or_tuple(
        args.encoder_conv_channels, type=int)
    kernel_sizes = speech_utils.eval_str_nested_list_or_tuple(
        args.encoder_conv_kernel_sizes, type=int)
    strides = speech_utils.eval_str_nested_list_or_tuple(
        args.encoder_conv_strides, type=int)
    logger.info("input feature dimension: {}, channels: {}".format(
        task.feat_dim, task.feat_in_channels))
    assert task.feat_dim % task.feat_in_channels == 0
    conv_layers = ConvBNReLU(
        out_channels,
        kernel_sizes,
        strides,
        in_channels=task.feat_in_channels,
    ) if out_channels is not None else None

    # Derive the transformer input size from the feature dim after the
    # conv stack's (ceil-division) striding along the feature axis.
    transformer_encoder_input_size = task.feat_dim // task.feat_in_channels
    if conv_layers is not None:
        for stride in strides:
            if isinstance(stride, (list, tuple)):
                assert len(stride) > 0
                # 2D stride: the second entry strides the feature axis.
                s = stride[1] if len(stride) > 1 else stride[0]
            else:
                assert isinstance(stride, int)
                s = stride
            transformer_encoder_input_size = (
                transformer_encoder_input_size + s - 1) // s
        transformer_encoder_input_size *= out_channels[-1]
    else:
        transformer_encoder_input_size = task.feat_dim

    # Optional (left, right) attention context; each side is None
    # (unlimited) or a non-negative int.
    encoder_transformer_context = speech_utils.eval_str_nested_list_or_tuple(
        args.encoder_transformer_context,
        type=int,
    )
    if encoder_transformer_context is not None:
        assert len(encoder_transformer_context) == 2
        for i in range(2):
            assert (encoder_transformer_context[i] is None
                    or (isinstance(encoder_transformer_context[i], int)
                        and encoder_transformer_context[i] >= 0))

    scheduled_sampling_rate_scheduler = ScheduledSamplingRateScheduler(
        args.scheduled_sampling_probs,
        args.start_scheduled_sampling_epoch,
    )

    encoder = cls.build_encoder(
        args,
        conv_layers_before=conv_layers,
        input_size=transformer_encoder_input_size,
        transformer_context=encoder_transformer_context,
    )
    decoder = cls.build_decoder(
        args,
        tgt_dict,
        decoder_embed_tokens,
        scheduled_sampling_rate_scheduler=scheduled_sampling_rate_scheduler,
    )
    min_params_to_wrap = getattr(args, "min_params_to_wrap",
                                 DEFAULT_MIN_PARAMS_TO_WRAP)
    # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
    encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap)
    decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap)
    return cls(args, encoder, decoder)
def build_model(cls, cfg, task):
    """Build a new model instance.

    Config-object (OmegaConf-style) variant of the speech
    encoder-decoder builder: an (optional) conv front-end plus a
    transformer encoder sized from the task's features, and a decoder.

    Args:
        cfg: hierarchical model config with ``encoder`` and ``decoder``
            sub-configs; mutated in place.
        task: fairseq task exposing ``target_dictionary``, ``feat_dim``
            and ``feat_in_channels``.

    Returns:
        An instance of ``cls`` wrapping the built encoder and decoder.
    """
    if cfg.encoder.layers_to_keep:
        cfg.encoder.layers = len(cfg.encoder.layers_to_keep.split(","))
    if cfg.max_source_positions is None:
        cfg.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
    if cfg.max_target_positions is None:
        cfg.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

    tgt_dict = task.target_dictionary

    decoder_embed_tokens = cls.build_embedding(cfg, tgt_dict,
                                               cfg.decoder.embed_dim,
                                               cfg.decoder.embed_path)
    if cfg.offload_activations:
        cfg.checkpoint_activations = True  # offloading implies checkpointing

    # Conv front-end hyperparameters are given as string-encoded nested
    # lists/tuples; parse them into ints.
    out_channels = speech_utils.eval_str_nested_list_or_tuple(
        cfg.encoder.conv_channels, type=int)
    kernel_sizes = speech_utils.eval_str_nested_list_or_tuple(
        cfg.encoder.conv_kernel_sizes, type=int)
    strides = speech_utils.eval_str_nested_list_or_tuple(
        cfg.encoder.conv_strides, type=int)
    logger.info("input feature dimension: {}, channels: {}".format(
        task.feat_dim, task.feat_in_channels))
    assert task.feat_dim % task.feat_in_channels == 0
    conv_layers = (ConvBNReLU(
        out_channels,
        kernel_sizes,
        strides,
        in_channels=task.feat_in_channels,
    ) if out_channels is not None else None)

    # Derive the transformer input size from the feature dim after the
    # conv stack's (ceil-division) striding along the feature axis.
    transformer_encoder_input_size = task.feat_dim // task.feat_in_channels
    if conv_layers is not None:
        for stride in strides:
            if isinstance(stride, (list, tuple)):
                assert len(stride) > 0
                # 2D stride: the second entry strides the feature axis.
                s = stride[1] if len(stride) > 1 else stride[0]
            else:
                assert isinstance(stride, int)
                s = stride
            transformer_encoder_input_size = (
                transformer_encoder_input_size + s - 1) // s
        transformer_encoder_input_size *= out_channels[-1]
    else:
        transformer_encoder_input_size = task.feat_dim

    # Optional (left, right) attention context; each side is None
    # (unlimited) or a non-negative int.
    encoder_transformer_context = speech_utils.eval_str_nested_list_or_tuple(
        cfg.encoder.transformer_context,
        type=int,
    )
    if encoder_transformer_context is not None:
        assert len(encoder_transformer_context) == 2
        for i in range(2):
            assert encoder_transformer_context[i] is None or (
                isinstance(encoder_transformer_context[i], int)
                and encoder_transformer_context[i] >= 0)

    encoder = cls.build_encoder(
        cfg,
        pre_encoder=conv_layers,
        input_size=transformer_encoder_input_size,
        transformer_context=encoder_transformer_context,
    )
    decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens)
    # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
    encoder = fsdp_wrap(encoder, min_num_params=cfg.min_params_to_wrap)
    decoder = fsdp_wrap(decoder, min_num_params=cfg.min_params_to_wrap)
    return cls(cfg, encoder, decoder)