Example #1
    def build_model(self, args):
        def check_args():
            messages = []
            if set(self.args.lang_pairs) != set(args.lang_pairs):
                messages.append(
                    "--lang-pairs should include all the language pairs {}.".format(
                        args.lang_pairs))
            if self.args.encoder_langtok != args.encoder_langtok:
                messages.append("--encoder-langtok should be {}.".format(
                    args.encoder_langtok))
            if self.args.decoder_langtok != args.decoder_langtok:
                messages.append("--decoder-langtok should {} be set.".format(
                    "" if args.decoder_langtok else "not"))

            if len(messages) > 0:
                raise ValueError(" ".join(messages))

        # Update args -> the args object that the constructor received and
        # modified is not necessarily the same object passed in here
        self.update_args(args)

        # Check if task args are consistent with model args
        check_args()

        from fairseq_stchde import models

        model = models.build_model(args, self)
        if not isinstance(model, FairseqMultiModel):
            raise ValueError(
                "MultilingualTranslationTask requires a FairseqMultiModel architecture"
            )
        return model
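The check_args closure above accumulates every mismatch before raising, so the user sees all inconsistencies in one error instead of fixing them one at a time. A standalone sketch of the same pattern on plain argparse.Namespace objects (the field names mirror the example; nothing here is fairseq_stchde API):

    from argparse import Namespace

    def check_consistent(task_args, model_args):
        # Collect every mismatch first, then raise once with the full list.
        messages = []
        if set(task_args.lang_pairs) != set(model_args.lang_pairs):
            messages.append(
                "--lang-pairs should include all the language pairs {}.".format(
                    model_args.lang_pairs))
        if task_args.encoder_langtok != model_args.encoder_langtok:
            messages.append("--encoder-langtok should be {}.".format(
                model_args.encoder_langtok))
        if messages:
            raise ValueError(" ".join(messages))

    # Raises one ValueError that reports both problems at once:
    check_consistent(
        Namespace(lang_pairs=["de-en"], encoder_langtok="src"),
        Namespace(lang_pairs=["de-en", "en-fr"], encoder_langtok="tgt"),
    )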
Example #2
    def build_model(self, args):
        from fairseq_stchde import models

        model = models.build_model(args, self)

        model.register_classification_head(
            getattr(args, "ranking_head_name", "sentence_classification_head"),
            num_classes=1,
        )

        return model
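Here getattr lets an optional --ranking-head-name argument override the default head name, and num_classes=1 reflects that a ranking head emits a single score per candidate rather than a class distribution. A minimal sketch of the registration pattern, assuming PyTorch and a hypothetical ToyModel (not the fairseq_stchde model API):

    import torch.nn as nn

    class ToyModel(nn.Module):
        def __init__(self, hidden_dim=8):
            super().__init__()
            self.encoder = nn.Linear(16, hidden_dim)
            # Named heads live in a ModuleDict so their parameters are registered.
            self.classification_heads = nn.ModuleDict()
            self.hidden_dim = hidden_dim

        def register_classification_head(self, name, num_classes=1):
            # A ranking head is just a 1-output projection over encoder features.
            self.classification_heads[name] = nn.Linear(self.hidden_dim, num_classes)

    model = ToyModel()
    model.register_classification_head("ranking_head", num_classes=1)
    print(model.classification_heads["ranking_head"])  # Linear(in_features=8, out_features=1, ...)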
Example #3
    def build_model(self, args: Namespace):
        """
        Build the :class:`~fairseq_stchde.models.BaseFairseqModel` instance for this
        task.

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Returns:
            a :class:`~fairseq_stchde.models.BaseFairseqModel` instance
        """
        from fairseq_stchde import models, quantization_utils

        model = models.build_model(args, self)
        model = quantization_utils.quantize_model_scalar(model, args)
        return model
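The quantization call is left in unconditionally because quantize_model_scalar is expected to be a no-op unless scalar quantization is requested in args. A minimal sketch of that gate-by-flag pattern, using PyTorch's dynamic quantization as a stand-in and an assumed quant_noise_scalar field (not a documented fairseq_stchde contract):

    import torch
    import torch.nn as nn

    def maybe_quantize(model: nn.Module, args) -> nn.Module:
        # Return the model untouched unless the flag is set, which is why the
        # call site in the example does not need its own conditional.
        if getattr(args, "quant_noise_scalar", 0) > 0:
            return torch.quantization.quantize_dynamic(
                model, {nn.Linear}, dtype=torch.qint8)
        return model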
Example #4
    def build_model(self, cfg: FairseqDataclass):
        """
        Build the :class:`~fairseq_stchde.models.BaseFairseqModel` instance for this
        task.

        Args:
            cfg (FairseqDataclass): configuration object

        Returns:
            a :class:`~fairseq_stchde.models.BaseFairseqModel` instance
        """
        from fairseq_stchde import models, quantization_utils

        model = models.build_model(cfg, self)
        model = quantization_utils.quantize_model_scalar(model, cfg)
        return model
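Example #4 is the dataclass-based twin of Example #3: the logic is unchanged, but a typed FairseqDataclass configuration replaces the untyped argparse.Namespace. A minimal sketch of what such a typed config buys, with a hypothetical QuantConfig standing in for a real FairseqDataclass:

    from dataclasses import dataclass

    @dataclass
    class QuantConfig:
        # Typed, defaulted fields replace ad-hoc getattr(args, ...) lookups.
        quant_noise_scalar: float = 0.0

    def build(cfg: QuantConfig):
        # Type checkers and IDEs can now verify field access at the call site.
        return "quantized" if cfg.quant_noise_scalar > 0 else "plain"

    print(build(QuantConfig()))                        # plain
    print(build(QuantConfig(quant_noise_scalar=0.5)))  # quantized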
Example #5
    def build_model(self, args):
        from fairseq_stchde import models

        model = models.build_model(args, self)
        if not isinstance(model, FairseqMultiModel):
            raise ValueError(
                "SemisupervisedTranslationTask requires a FairseqMultiModel architecture"
            )

        # create a SequenceGenerator for each model that backtranslation depends on
        self.sequence_generators = {}
        if (self.lambda_otf_bt > 0.0
                or self.lambda_otf_bt_steps is not None) and self.training:
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                key = "{}-{}".format(tgt, src)
                self.sequence_generators[key] = SequenceGenerator(
                    [model.models[key]],
                    tgt_dict=self.dicts[src],
                    beam_size=args.bt_beam_size,
                    max_len_a=args.bt_max_len_a,
                    max_len_b=args.bt_max_len_b,
                )
                decoder_lang_tok_idx = self.get_decoder_langtok(src)

                def backtranslate_fn(
                    sample,
                    model=model.models[key],
                    bos_token=decoder_lang_tok_idx,
                    sequence_generator=self.sequence_generators[key],
                ):
                    return sequence_generator.generate(
                        [model],
                        sample,
                        bos_token=bos_token,
                    )

                self.backtranslators[lang_pair] = backtranslate_fn

        return model
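Note how backtranslate_fn binds model, bos_token, and sequence_generator through default arguments. Python closures capture variables by reference, so a plain closure defined inside the loop would see only the last iteration's values; default arguments are evaluated at definition time and freeze each iteration's bindings. A standalone sketch of the pitfall and the fix:

    # Pitfall: every closure shares the loop variable by reference.
    late = [lambda: key for key in ("a", "b", "c")]
    print([f() for f in late])   # ['c', 'c', 'c']

    # Fix: a default argument snapshots the value at definition time,
    # exactly as model=..., bos_token=... do in the example above.
    bound = [lambda key=key: key for key in ("a", "b", "c")]
    print([f() for f in bound])  # ['a', 'b', 'c']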