Example #1
    def __init__(self, args):
        super().__init__()
        self.num_workers = distributed_utils.get_data_parallel_world_size()
        expert_centroids = torch.empty(self.num_workers, args.decoder_embed_dim)
        torch.nn.init.orthogonal_(expert_centroids, gain=0.1)
        self.register_parameter("expert_centroids", torch.nn.Parameter(expert_centroids))
        self.expert_network = nn.Sequential(*([BaseSublayer(args) for _ in range(args.base_sublayers)]))
        self.expert_id = distributed_utils.get_data_parallel_rank()
        self.shuffle = args.base_shuffle
        self.cpp = self.load_assignment()

        # Add a special attribute to the expert parameters, so we know not to sync their gradients
        for param in self.expert_network.parameters():
            param.expert = True
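
Only the constructor is shown above. As a rough, hedged sketch of how such centroids could be used for routing (the features tensor and the greedy top-1 rule below are illustrative assumptions, not the actual BaseLayer forward pass), each token can be scored against every expert centroid and assigned to the best-matching expert:

# Hedged sketch: greedy token-to-expert routing with centroids like the ones above.
# The features tensor and the argmax assignment are illustrative, not fairseq's BaseLayer code.
import torch

num_workers, embed_dim, num_tokens = 4, 8, 16
expert_centroids = torch.empty(num_workers, embed_dim)
torch.nn.init.orthogonal_(expert_centroids, gain=0.1)

features = torch.randn(num_tokens, embed_dim)        # token representations
affinities = features.matmul(expert_centroids.t())   # [num_tokens, num_workers]
assignment = affinities.argmax(dim=-1)               # expert index per token
print(assignment.shape)                              # torch.Size([16])
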
Example #2
    def __init__(self, cfg: TruncatedBPTTLMConfig):
        super().__init__(cfg)

        if cfg.data_parallel_rank is None or cfg.data_parallel_size is None:
            if torch.distributed.is_initialized():
                cfg.data_parallel_rank = dist_utils.get_data_parallel_rank()
                cfg.data_parallel_size = dist_utils.get_data_parallel_world_size()
            else:
                cfg.data_parallel_rank = 0
                cfg.data_parallel_size = 1

        # load the dictionary
        paths = utils.split_paths(cfg.data)
        assert len(paths) > 0
        self.dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
        logger.info("dictionary: {} types".format(len(self.dictionary)))
Example #3
    def data_parallel_world_size(self):
        if self.cfg.distributed_training.distributed_world_size == 1:
            return 1
        return distributed_utils.get_data_parallel_world_size()
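
A common reason to expose the world size like this is to normalize metrics that every worker computed locally. A hedged sketch of that use (assuming a process group may or may not be initialized; with NCCL the tensor would need to live on the current CUDA device):

# Hedged sketch: averaging a locally computed scalar across data-parallel workers,
# guarded so it also works in a single-process run.
import torch
import torch.distributed as dist


def average_across_workers(value: float) -> float:
    if not (dist.is_available() and dist.is_initialized()):
        return value  # single worker: nothing to reduce
    world_size = dist.get_world_size()
    t = torch.tensor([value], dtype=torch.float64)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return (t / world_size).item()
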
Example #4
    def valid_step(self, sample, model, criterion):
        res = model(
            **sample["net_input"],
            dense_x_only=True,
        )

        dense_x = res["logits"]
        padding_mask = res["padding_mask"]

        word_scores = None
        if self.compute_word_score is not None:
            word_scores = self.compute_word_score(dense_x.cpu(),
                                                  padding_mask.cpu())

        z = dense_x.argmax(-1)
        z[padding_mask] = self.target_dictionary.pad()

        vocab_seen = torch.zeros(self.num_symbols, dtype=torch.bool)

        # editdistance is imported lazily; it is only needed at validation time
        import editdistance

        c_err = 0
        c_len = 0
        pred_c_len = 0
        lm_score_sum = 0
        for i, (x, t, id) in enumerate(
            zip(
                z,
                sample["target"] if "target" in sample else [None] * len(z),
                sample["id"],
            )
        ):

            if t is not None:
                t = t[(t >= self.target_dictionary.nspecial)]
            x = x[(x >= self.target_dictionary.nspecial)
                  & (x < (self.num_symbols + self.target_dictionary.nspecial))]
            if self.sil_id >= 0:
                x = x[x != self.sil_id]

            vocab_seen[x - self.target_dictionary.nspecial] = True

            pred_units_arr = x
            if self.cfg.ctc_eval:
                pred_units_arr = pred_units_arr.unique_consecutive()
                pred_units_arr = pred_units_arr[pred_units_arr != 0]

            if id == 0:
                if t is not None:
                    logger.info(f"REF: {self.target_dictionary.string(t)}")
                logger.info(
                    f"HYP: {self.target_dictionary.string(pred_units_arr)}")

                if self.kenlm is not None:
                    if t is not None:
                        ref_lm_s = self.compute_lm_score(
                            self.target_dictionary.string(t))
                        logger.info(
                            f"LM [REF]: {ref_lm_s}, {math.pow(10, -ref_lm_s / (len(t) + 1))}"
                        )

                    hyp_lm_s = self.compute_lm_score(
                        self.target_dictionary.string(pred_units_arr))
                    logger.info(
                        f"LM [HYP]: {hyp_lm_s}, {math.pow(10, -hyp_lm_s / (len(pred_units_arr) + 1))}"
                    )

            pred_units_arr = pred_units_arr.tolist()

            pred_c_len += len(pred_units_arr)

            if t is not None:
                t = t.tolist()
                c_err += editdistance.eval(pred_units_arr, t)
                c_len += len(t)
            else:
                c_len = pred_c_len

            if self.kenlm is not None:
                pred_str = self.target_dictionary.string(pred_units_arr)
                lm_score = self.compute_lm_score(pred_str)
                lm_score_sum += lm_score

        kaldi_score_sum = 0
        word_lm_sum = 0
        num_words = 0
        if word_scores is not None:
            for score, words in word_scores:
                kaldi_score_sum += score
                num_words += len(words)
                if self.word_kenlm is not None:
                    word_lm_sum += self.word_kenlm.score(" ".join(words))

        try:
            world_size = get_data_parallel_world_size()
        except Exception:
            # not running under a distributed process group
            world_size = 1

        logging_output = {
            "loss": c_err,
            "_num_char_errors": c_err,
            "_num_chars": c_len,
            "_num_pred_chars": pred_c_len,
            "ntokens": c_len,
            "nsentences": z.size(0),
            "sample_size": c_len,
            "_world_size": world_size,
            "_lm_score_sum": lm_score_sum,
            "_kaldi_score_sum": kaldi_score_sum,
            "_word_lm_sum": word_lm_sum,
            "_num_words": num_words,
            "_vocab_seen": vocab_seen,
        }

        return c_err, c_len, logging_output
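
The per-batch counters returned here (e.g. _num_char_errors, _num_chars, _world_size) are meant to be summed over batches and workers before being turned into rates. A hedged sketch of that aggregation in plain Python (a stand-in, not the task's actual reduce_metrics):

# Hedged sketch: turning the logging outputs emitted by valid_step into a
# unit error rate. Plain Python stand-in, not fairseq's reduce_metrics.
def summarize(logging_outputs):
    num_char_errors = sum(lo.get("_num_char_errors", 0) for lo in logging_outputs)
    num_chars = sum(lo.get("_num_chars", 0) for lo in logging_outputs)
    uer = 100.0 * num_char_errors / num_chars if num_chars > 0 else float("nan")
    return {"uer": uer, "num_chars": num_chars}


print(summarize([{"_num_char_errors": 3, "_num_chars": 50},
                 {"_num_char_errors": 1, "_num_chars": 40}]))  # uer ~ 4.44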