Example No. 1
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_architecture(args)

        if args.encoder_layers_to_keep:
            args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
        if args.decoder_layers_to_keep:
            args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

        if getattr(args, "max_source_positions", None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = cls.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = cls.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = cls.build_embedding(
                args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            )
        if getattr(args, "offload_activations", False):
            args.checkpoint_activations = True  # offloading implies checkpointing

        # build domain embeddings from pretrained graph embeddings
        # (here a freshly initialized table: 49 entries, encoder embed dim, padding_idx=0)
        tag_embedding = Embedding(49, args.encoder_embed_dim, 0)

        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens, tag_embedding)
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
        if not args.share_all_embeddings:
            min_params_to_wrap = getattr(
                args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
            )
            # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
            encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap)
            decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap)
        return cls(args, encoder, decoder)
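
A quick note on the helper these build_model examples call: cls.build_embedding is not shown in any snippet here. The sketch below is modeled on the standard fairseq transformer helper and is an assumption about what sits behind these calls, not the verbatim code of each model.

import torch.nn as nn

from fairseq import utils


def Embedding(num_embeddings, embed_dim, padding_idx):
    # fairseq-style embedding: normal init, zeroed padding row
    m = nn.Embedding(num_embeddings, embed_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embed_dim ** -0.5)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m


def build_embedding(args, dictionary, embed_dim, path=None):
    emb = Embedding(len(dictionary), embed_dim, dictionary.pad())
    if path:
        # optionally overwrite rows with pretrained vectors parsed from a text file
        embed_dict = utils.parse_embedding(path)
        utils.load_embedding(embed_dict, dictionary, emb)
    return emb
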
Example No. 2
    def build_model(cls, cfg, task):
        """Build a new model instance."""

        # --  TODO T96535332
        #  bug caused by interaction between OmegaConf II and argparsing
        cfg.decoder.input_dim = int(cfg.decoder.input_dim)
        cfg.decoder.output_dim = int(cfg.decoder.output_dim)
        # --

        if cfg.encoder.layers_to_keep:
            cfg.encoder.layers = len(cfg.encoder.layers_to_keep.split(","))
        if cfg.decoder.layers_to_keep:
            cfg.decoder.layers = len(cfg.decoder.layers_to_keep.split(","))

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        if cfg.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError(
                    "--share-all-embeddings requires a joined dictionary")
            if cfg.encoder.embed_dim != cfg.decoder.embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if cfg.decoder.embed_path and (cfg.decoder.embed_path !=
                                           cfg.encoder.embed_path):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = cls.build_embedding(cfg, src_dict,
                                                       cfg.encoder.embed_dim,
                                                       cfg.encoder.embed_path)
            decoder_embed_tokens = encoder_embed_tokens
            cfg.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = cls.build_embedding(cfg, src_dict,
                                                       cfg.encoder.embed_dim,
                                                       cfg.encoder.embed_path)
            decoder_embed_tokens = cls.build_embedding(cfg, tgt_dict,
                                                       cfg.decoder.embed_dim,
                                                       cfg.decoder.embed_path)
        if cfg.offload_activations:
            cfg.checkpoint_activations = True  # offloading implies checkpointing
        encoder = cls.build_encoder(cfg, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens)
        if not cfg.share_all_embeddings:
            # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
            encoder = fsdp_wrap(encoder, min_num_params=cfg.min_params_to_wrap)
            decoder = fsdp_wrap(decoder, min_num_params=cfg.min_params_to_wrap)
        return cls(cfg, encoder, decoder)
Example No. 3
    def build_decoder_layer(self, cfg, no_encoder_attn=False):
        if self.cfg.shared_layer_qkv_conv == 1 and self.compress_layer is None:
            target_dim = cfg.compressed_dim
            compress_layer = nn.Linear(
                self.cfg.max_positions,
                target_dim,
            )

            nn.init.xavier_uniform_(compress_layer.weight,
                                    gain=1 / math.sqrt(2))
            if self.cfg.freeze_conv == 1:
                compress_layer.weight.requires_grad = False
            self.compress_layer = compress_layer

        #return PrimerTransformerEncoderLayer(cfg, self.compress_layer)
        layer = primer_layer.PrimerDecoderLayerBase(cfg, self.compress_layer,
                                                    no_encoder_attn)
        checkpoint = cfg.checkpoint_activations
        if checkpoint:
            offload_to_cpu = cfg.offload_activations
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
        # if we are checkpointing, enforce that FSDP always wraps the
        # checkpointed layer, regardless of layer size
        min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
        layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
        return layer
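
The shared compress_layer built above is a single nn.Linear over the time axis, reused by every decoder layer. The toy snippet below only illustrates that projection; the dimensions are made up and the (batch, embed, time) layout is an assumption for demonstration, not necessarily what PrimerDecoderLayerBase does internally.

import math

import torch
import torch.nn as nn

max_positions, compressed_dim, embed_dim = 512, 64, 8
compress = nn.Linear(max_positions, compressed_dim)
nn.init.xavier_uniform_(compress.weight, gain=1 / math.sqrt(2))

x = torch.randn(2, embed_dim, max_positions)  # (batch, embed, time)
print(compress(x).shape)                      # torch.Size([2, 8, 64])
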
Example No. 4
 def build_encoder_layer(self, args: Wav2Vec2Config):
     if args.layer_type == "transformer":
         layer = TransformerSentenceEncoderLayer(
             embedding_dim=self.embedding_dim,
             ffn_embedding_dim=args.encoder_ffn_embed_dim,
             num_attention_heads=args.encoder_attention_heads,
             dropout=self.dropout,
             attention_dropout=args.attention_dropout,
             activation_dropout=args.activation_dropout,
             activation_fn=args.activation_fn,
             layer_norm_first=args.layer_norm_first,
         )
     elif args.layer_type == "conformer":
         layer = ConformerWav2Vec2EncoderLayer(
             embed_dim=self.embedding_dim,
             ffn_embed_dim=args.encoder_ffn_embed_dim,
             attention_heads=args.encoder_attention_heads,
             dropout=args.dropout,
             depthwise_conv_kernel_size=args.depthwise_conv_kernel_size,
             activation_fn="swish",
             attn_type=args.attn_type,
             use_fp16=args.fp16,
             pos_enc_type="abs",
          )
      else:
          raise NotImplementedError(f"unsupported layer_type: {args.layer_type}")
      layer = fsdp_wrap(layer)
     if args.checkpoint_activations:
         layer = checkpoint_wrapper(layer)
     return layer
Example No. 5
    def build_encoder_layer(self, args):
        layer = TransformerEncoderLayer(args)

        # if we are checkpointing, enforce that FSDP always wraps the
        # checkpointed layer, regardless of layer size
        min_params_to_wrap = (getattr(args, "min_params_to_wrap",
                                      DEFAULT_MIN_PARAMS_TO_WRAP))
        layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
        return layer
Example No. 6
 def build_encoder_layer(self, cfg):
     layer = transformer_layer.TransformerEncoderLayerBase(cfg)
     checkpoint = cfg.checkpoint_activations
     if checkpoint:
         offload_to_cpu = cfg.offload_activations
         layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
     # if we are checkpointing, enforce that FSDP always wraps the
     # checkpointed layer, regardless of layer size
     min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
     layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
     return layer
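
Several of these layer builders pass a min_num_params threshold to fsdp_wrap and note that it is a no-op when --ddp-backend != fully_sharded. Roughly, the helper behaves like the sketch below: modules under the parameter threshold, and modules wrapped outside an FSDP enable_wrap() context, come back unchanged. This is a paraphrase of the fairseq utility, not its exact source.

def fsdp_wrap_sketch(module, min_num_params=None, **kwargs):
    # Approximate behavior of fairseq's fsdp_wrap.
    try:
        from fairscale.nn import wrap  # a no-op outside an enable_wrap() context
    except ImportError:
        return module
    if min_num_params is not None:
        num_params = sum(p.numel() for p in module.parameters())
        if num_params < min_num_params:
            return module
    return wrap(module, **kwargs)
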
Example No. 7
 def build_decoder_layer(self, cfg, no_encoder_attn=False):
     layer = UniLMDecoderLayer(cfg, no_encoder_attn)
     checkpoint = cfg.checkpoint_activations
     if checkpoint:
         offload_to_cpu = cfg.offload_activations
         layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
     # if we are checkpointing, enforce that FSDP always wraps the
     # checkpointed layer, regardless of layer size
     min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
     layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
     return layer
Example No. 8
 def build_encoder_layer(self, args):
     layer = TransformerEncoderLayer(args)
     checkpoint = getattr(args, "checkpoint_activations", False)
     if checkpoint:
         offload_to_cpu = getattr(args, "offload_activations", False)
         layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
     # if we are checkpointing, enforce that FSDP always wraps the
     # checkpointed layer, regardless of layer size
     min_params_to_wrap = (getattr(args, "min_params_to_wrap",
                                   DEFAULT_MIN_PARAMS_TO_WRAP)
                           if not checkpoint else 0)
     layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
     return layer
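
The checkpoint-then-shard idiom above recurs almost verbatim across the encoder and decoder layer builders in this collection. A hypothetical helper like the one below (not part of fairseq) could factor it out; it assumes the checkpoint_wrapper and fsdp_wrap utilities already used by these modules, and the common 1e8 default for min_params_to_wrap.

from fairseq.distributed import fsdp_wrap
from fairseq.modules.checkpoint_activations import checkpoint_wrapper

DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)  # assumed default


def wrap_layer(layer, checkpoint=False, offload_to_cpu=False,
               min_params_to_wrap=DEFAULT_MIN_PARAMS_TO_WRAP):
    if checkpoint:
        # activation checkpointing wraps the raw layer first
        layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
    # when checkpointing, force FSDP to wrap the layer regardless of its size
    min_params = 0 if checkpoint else min_params_to_wrap
    return fsdp_wrap(layer, min_num_params=min_params)
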
Example No. 9
    def __init__(self, args):
        super().__init__()

        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim

        self.pos_conv = nn.Conv1d(
            self.embedding_dim,
            self.embedding_dim,
            kernel_size=args.conv_pos,
            padding=args.conv_pos // 2,
            groups=args.conv_pos_groups,
        )
        dropout = 0  # used only in the init std formula below; separate from self.dropout
        std = math.sqrt(
            (4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        nn.init.constant_(self.pos_conv.bias, 0)

        self.pos_conv = nn.utils.weight_norm(self.pos_conv,
                                             name="weight",
                                             dim=2)
        self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos),
                                      nn.GELU())

        layers = []
        for _ in range(args.encoder_layers):
            layer = TransformerSentenceEncoderLayer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=args.encoder_ffn_embed_dim,
                num_attention_heads=args.encoder_attention_heads,
                dropout=self.dropout,
                attention_dropout=args.attention_dropout,
                activation_dropout=args.activation_dropout,
                activation_fn=args.activation_fn,
                layer_norm_first=args.layer_norm_first,
            )
            if args.checkpoint_activations:
                layer = fsdp_wrap(layer)
                layer = checkpoint_wrapper(layer)
            layers.append(layer)
        self.layers = nn.ModuleList(layers)

        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

        self.apply(init_bert_params)
Example No. 10
 def build_encoder_layer(self, args):
     layer = ConformerWav2Vec2EncoderLayer(
         embed_dim=self.embedding_dim,
         ffn_embed_dim=args.encoder_ffn_embed_dim,
         attention_heads=args.encoder_attention_heads,
         dropout=args.dropout,
         depthwise_conv_kernel_size=args.depthwise_conv_kernel_size,
         activation_fn="swish",
         attn_type=args.attn_type,
         pos_enc_type=args.pos_enc_type,
         use_fp16=args.fp16,  # only used for rope
     )
     layer = fsdp_wrap(layer)
     if args.checkpoint_activations:
         layer = checkpoint_wrapper(layer)
     return layer
Example No. 11
 def build_decoder_layer(
     self,
     cfg,
     no_encoder_attn=False,
     positional_embedding: Optional[RelativePositionalEmbedding] = None,
 ):
     layer = TransformerWithRelativePositionalEmbeddingDecoderLayerBase(
         cfg,
         no_encoder_attn=no_encoder_attn,
         positional_embedding=positional_embedding,
     )
     checkpoint = cfg.checkpoint_activations
     if checkpoint:
         offload_to_cpu = cfg.offload_activations
         layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
     # if we are checkpointing, enforce that FSDP always wraps the
     # checkpointed layer, regardless of layer size
     min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
     layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
     return layer
Example No. 12
 def build_encoder_layer(
     self, cfg, positional_embedding: Optional[RelativePositionalEmbedding] = None
 ):
     if cfg.encoder.layer_type == "transformer":
         layer_cls = TransformerWithRelativePositionalEmbeddingEncoderLayerBase
     elif cfg.encoder.layer_type == "conformer":
         layer_cls = ConformerWithRelativePositionalEmbeddingEncoderLayerBase
     else:
         raise NotImplementedError
     layer = layer_cls(
         cfg, return_fc=self.return_fc, positional_embedding=positional_embedding
     )
     checkpoint = cfg.checkpoint_activations
     if checkpoint:
         offload_to_cpu = cfg.offload_activations
         layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
     # if we are checkpointing, enforce that FSDP always wraps the
     # checkpointed layer, regardless of layer size
     min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
     layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
     return layer
Example No. 13
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure that all args are properly defaulted (in case there are any new ones)
        base_architecture(args)

        if args.encoder_layers_to_keep:
            args.encoder_layers = len(args.encoder_layers_to_keep.split(","))

        if getattr(args, "max_source_positions", None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS

        if getattr(args, "offload_activations", False):
            args.checkpoint_activations = True  # offloading implies checkpointing

        out_channels = speech_utils.eval_str_nested_list_or_tuple(
            args.encoder_conv_channels, type=int
        )
        kernel_sizes = speech_utils.eval_str_nested_list_or_tuple(
            args.encoder_conv_kernel_sizes, type=int
        )
        strides = speech_utils.eval_str_nested_list_or_tuple(
            args.encoder_conv_strides, type=int
        )
        logger.info(
            "input feature dimension: {}, channels: {}".format(
                task.feat_dim, task.feat_in_channels
            )
        )
        assert task.feat_dim % task.feat_in_channels == 0
        conv_layers = (
            ConvBNReLU(
                out_channels,
                kernel_sizes,
                strides,
                in_channels=task.feat_in_channels,
            )
            if out_channels is not None
            else None
        )

        transformer_encoder_input_size = task.feat_dim // task.feat_in_channels
        if conv_layers is not None:
            for stride in strides:
                if isinstance(stride, (list, tuple)):
                    assert len(stride) > 0
                    s = stride[1] if len(stride) > 1 else stride[0]
                else:
                    assert isinstance(stride, int)
                    s = stride
                transformer_encoder_input_size = (
                    transformer_encoder_input_size + s - 1
                ) // s
            transformer_encoder_input_size *= out_channels[-1]
        else:
            transformer_encoder_input_size = task.feat_dim

        encoder_transformer_context = speech_utils.eval_str_nested_list_or_tuple(
            args.encoder_transformer_context,
            type=int,
        )
        if encoder_transformer_context is not None:
            assert len(encoder_transformer_context) == 2
            for i in range(2):
                assert encoder_transformer_context[i] is None or (
                    isinstance(encoder_transformer_context[i], int)
                    and encoder_transformer_context[i] >= 0
                )

        encoder = cls.build_encoder(
            args,
            pre_encoder=conv_layers,
            input_size=transformer_encoder_input_size,
            transformer_context=encoder_transformer_context,
            num_targets=getattr(
                task, "num_targets", None
            ),  # targets for encoder-only model
            chunk_width=getattr(task, "chunk_width", None),
            chunk_left_context=getattr(task, "chunk_left_context", 0),
            training_stage=getattr(task, "training_stage", True),
        )
        # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
        encoder = fsdp_wrap(encoder, min_num_params=1e8)
        return cls(
            args, encoder, state_prior=getattr(task, "initial_state_prior", None)
        )
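
The stride loop above shrinks the per-channel feature dimension by ceiling division for each conv stride, then multiplies by the last output channel count. The standalone check below uses made-up numbers and assumes flat integer strides (the model code also accepts stride pairs and then takes the second entry).

def downsampled_input_size(feat_dim, strides, last_out_channels):
    for s in strides:
        feat_dim = (feat_dim + s - 1) // s  # ceil(feat_dim / s)
    return feat_dim * last_out_channels


# e.g. 83 frequency bins through two stride-2 convs with 128 output channels
assert downsampled_input_size(83, [2, 2], 128) == 21 * 128
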
Example No. 14
def main(cfg: FairseqConfig) -> None:
    if isinstance(cfg, argparse.Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    if distributed_utils.is_master(
            cfg.distributed_training) and "job_logging_cfg" in cfg:
        # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
        logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))

    assert (
        cfg.dataset.max_tokens is not None
        or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"
    metrics.reset()

    np.random.seed(cfg.common.seed)
    utils.set_torch_seed(cfg.common.seed)

    if distributed_utils.is_master(cfg.distributed_training):
        checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)

    # Print args
    logger.info(cfg)

    if cfg.checkpoint.write_checkpoints_asynchronously:
        try:
            import iopath  # noqa: F401
        except ImportError:
            logging.exception(
                "Asynchronous checkpoint writing is specified but iopath is "
                "not installed: `pip install iopath`")
            return

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(cfg.task)

    assert cfg.criterion, "Please specify criterion to train a model"

    # Build model and criterion
    if cfg.distributed_training.ddp_backend == "fully_sharded":
        with fsdp_enable_wrap(cfg.distributed_training):
            model = fsdp_wrap(task.build_model(cfg.model))
    else:
        model = task.build_model(cfg.model)
    criterion = task.build_criterion(cfg.criterion)
    logger.info(model)
    logger.info("task: {}".format(task.__class__.__name__))
    logger.info("model: {}".format(model.__class__.__name__))
    logger.info("criterion: {}".format(criterion.__class__.__name__))
    logger.info("num. shared model params: {:,} (num. trained: {:,})".format(
        sum(p.numel() for p in model.parameters()
            if not getattr(p, "expert", False)),
        sum(p.numel() for p in model.parameters()
            if not getattr(p, "expert", False) and p.requires_grad)))

    logger.info("num. expert model params: {} (num. trained: {})".format(
        sum(p.numel() for p in model.parameters()
            if getattr(p, "expert", False)),
        sum(p.numel() for p in model.parameters()
            if getattr(p, "expert", False) and p.requires_grad),
    ))

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    # We load the valid dataset AFTER building the model
    for valid_sub_split in cfg.dataset.valid_subset.split(","):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # (optionally) Configure quantization
    if cfg.common.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=cfg.common.quantization_config_path,
            max_epoch=cfg.optimization.max_epoch,
            max_update=cfg.optimization.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if cfg.common.model_parallel_size == 1:
        trainer = Trainer(cfg, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(cfg, task, model, criterion)
    logger.info("training on {} devices (GPUs/TPUs)".format(
        cfg.distributed_training.distributed_world_size))
    logger.info(
        "max tokens per device = {} and max sentences per device = {}".format(
            cfg.dataset.max_tokens,
            cfg.dataset.batch_size,
        ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        cfg.checkpoint,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )
    if cfg.common.tpu:
        import torch_xla.core.xla_model as xm
        xm.rendezvous("load_checkpoint")  # wait for all workers

    max_epoch = cfg.optimization.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    while epoch_itr.next_epoch_idx <= max_epoch:
        if lr <= cfg.optimization.stop_min_lr:
            logger.info(
                f"stopping training because current learning rate ({lr}) is smaller "
                "than or equal to minimum learning rate "
                f"(--stop-min-lr={cfg.optimization.stop_min_lr})")
            break

        # train for one epoch
        valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))

    # ioPath implementation to wait for all asynchronous file writes to complete.
    if cfg.checkpoint.write_checkpoints_asynchronously:
        logger.info(
            "ioPath PathManager waiting for all asynchronous checkpoint "
            "writes to finish.")
        PathManager.async_close()
        logger.info("ioPath PathManager finished waiting.")
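
The two logging blocks in main() compute four parameter counts with near-identical generator expressions. A hypothetical helper (not part of fairseq) that yields the same numbers is sketched below; it assumes the `expert` attribute that fairseq's MoE code attaches to expert parameters.

def count_params(model, expert=False, trainable_only=False):
    # shared params: expert=False; expert (MoE) params: expert=True
    return sum(
        p.numel()
        for p in model.parameters()
        if bool(getattr(p, "expert", False)) == expert
        and (p.requires_grad or not trainable_only)
    )


# e.g. count_params(model) and count_params(model, trainable_only=True)
# reproduce the "num. shared model params" line above.
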
Example No. 15
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_architecture(args)

        if args.encoder_layers_to_keep:
            args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
        if args.decoder_layers_to_keep:
            args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

        if getattr(args, "max_source_positions", None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        tgt_dict = task.target_dictionary

        decoder_embed_tokens = cls.build_embedding(args, tgt_dict,
                                                   args.decoder_input_dim,
                                                   args.decoder_embed_path)
        if getattr(args, "offload_activations", False):
            args.checkpoint_activations = True  # offloading implies checkpointing

        out_channels = speech_utils.eval_str_nested_list_or_tuple(
            args.encoder_conv_channels, type=int)
        kernel_sizes = speech_utils.eval_str_nested_list_or_tuple(
            args.encoder_conv_kernel_sizes, type=int)
        strides = speech_utils.eval_str_nested_list_or_tuple(
            args.encoder_conv_strides, type=int)
        logger.info("input feature dimension: {}, channels: {}".format(
            task.feat_dim, task.feat_in_channels))
        assert task.feat_dim % task.feat_in_channels == 0
        conv_layers = ConvBNReLU(
            out_channels,
            kernel_sizes,
            strides,
            in_channels=task.feat_in_channels,
        ) if out_channels is not None else None

        transformer_encoder_input_size = task.feat_dim // task.feat_in_channels
        if conv_layers is not None:
            for stride in strides:
                if isinstance(stride, (list, tuple)):
                    assert len(stride) > 0
                    s = stride[1] if len(stride) > 1 else stride[0]
                else:
                    assert isinstance(stride, int)
                    s = stride
                transformer_encoder_input_size = (
                    transformer_encoder_input_size + s - 1) // s
            transformer_encoder_input_size *= out_channels[-1]
        else:
            transformer_encoder_input_size = task.feat_dim

        encoder_transformer_context = speech_utils.eval_str_nested_list_or_tuple(
            args.encoder_transformer_context,
            type=int,
        )
        if encoder_transformer_context is not None:
            assert len(encoder_transformer_context) == 2
            for i in range(2):
                assert (encoder_transformer_context[i] is None
                        or (isinstance(encoder_transformer_context[i], int)
                            and encoder_transformer_context[i] >= 0))

        scheduled_sampling_rate_scheduler = ScheduledSamplingRateScheduler(
            args.scheduled_sampling_probs,
            args.start_scheduled_sampling_epoch,
        )

        encoder = cls.build_encoder(
            args,
            conv_layers_before=conv_layers,
            input_size=transformer_encoder_input_size,
            transformer_context=encoder_transformer_context,
        )
        decoder = cls.build_decoder(
            args,
            tgt_dict,
            decoder_embed_tokens,
            scheduled_sampling_rate_scheduler=scheduled_sampling_rate_scheduler,
        )
        min_params_to_wrap = getattr(args, "min_params_to_wrap",
                                     DEFAULT_MIN_PARAMS_TO_WRAP)
        # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
        encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap)
        decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap)
        return cls(args, encoder, decoder)
Example No. 16
    def build_model(cls, cfg, task):
        """Build a new model instance."""

        if cfg.encoder.layers_to_keep:
            cfg.encoder.layers = len(cfg.encoder.layers_to_keep.split(","))

        if cfg.max_source_positions is None:
            cfg.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if cfg.max_target_positions is None:
            cfg.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        tgt_dict = task.target_dictionary

        decoder_embed_tokens = cls.build_embedding(cfg, tgt_dict,
                                                   cfg.decoder.embed_dim,
                                                   cfg.decoder.embed_path)
        if cfg.offload_activations:
            cfg.checkpoint_activations = True  # offloading implies checkpointing

        out_channels = speech_utils.eval_str_nested_list_or_tuple(
            cfg.encoder.conv_channels, type=int)
        kernel_sizes = speech_utils.eval_str_nested_list_or_tuple(
            cfg.encoder.conv_kernel_sizes, type=int)
        strides = speech_utils.eval_str_nested_list_or_tuple(
            cfg.encoder.conv_strides, type=int)
        logger.info("input feature dimension: {}, channels: {}".format(
            task.feat_dim, task.feat_in_channels))
        assert task.feat_dim % task.feat_in_channels == 0
        conv_layers = (ConvBNReLU(
            out_channels,
            kernel_sizes,
            strides,
            in_channels=task.feat_in_channels,
        ) if out_channels is not None else None)

        transformer_encoder_input_size = task.feat_dim // task.feat_in_channels
        if conv_layers is not None:
            for stride in strides:
                if isinstance(stride, (list, tuple)):
                    assert len(stride) > 0
                    s = stride[1] if len(stride) > 1 else stride[0]
                else:
                    assert isinstance(stride, int)
                    s = stride
                transformer_encoder_input_size = (
                    transformer_encoder_input_size + s - 1) // s
            transformer_encoder_input_size *= out_channels[-1]
        else:
            transformer_encoder_input_size = task.feat_dim

        encoder_transformer_context = speech_utils.eval_str_nested_list_or_tuple(
            cfg.encoder.transformer_context,
            type=int,
        )
        if encoder_transformer_context is not None:
            assert len(encoder_transformer_context) == 2
            for i in range(2):
                assert encoder_transformer_context[i] is None or (
                    isinstance(encoder_transformer_context[i], int)
                    and encoder_transformer_context[i] >= 0)

        encoder = cls.build_encoder(
            cfg,
            pre_encoder=conv_layers,
            input_size=transformer_encoder_input_size,
            transformer_context=encoder_transformer_context,
        )
        decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens)
        # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
        encoder = fsdp_wrap(encoder, min_num_params=cfg.min_params_to_wrap)
        decoder = fsdp_wrap(decoder, min_num_params=cfg.min_params_to_wrap)
        return cls(cfg, encoder, decoder)