Example #1
0
    def _prepare_sample(self, sample):
        if sample is None or len(sample) == 0:
            return None

        if self.cuda:
            sample = utils.move_to_cuda(sample)

        def apply_half(t):
            if t.dtype is torch.float32:
                return t.half()
            return t

        if self.args.fp16:
            sample = utils.apply_to_sample(apply_half, sample)

        return sample
Example #2
0
    def _reduce_and_log_stats(self, logging_outputs, sample_size):
        """Return the latest logging output, copied to CPU and annotated.

        Metric aggregation (``metrics.aggregate`` / ``task.reduce_metrics``)
        is not performed here; the most recent raw logging output is returned
        instead.

        Args:
            logging_outputs: sequence of logging-output dicts; only the last
                entry is used.
            sample_size: value stored under the ``"sample_size"`` key of the
                returned dict.

        Returns:
            dict: CPU copy of ``logging_outputs[-1]`` with ``"sample_size"``
            set and the ``"ppl"``, ``"wps"``, ``"wpb"`` and ``"bsz"`` keys
            removed when present.

        Raises:
            IndexError: if ``logging_outputs`` is empty.
        """
        # Only the newest entry is returned, so copy just that one to the CPU
        # rather than transferring every entry and discarding all but the last.
        # NOTE(review): ``non_blocking=True`` device-to-host copies may not be
        # complete when the values are read; confirm the caller synchronizes.
        logging_output = utils.apply_to_sample(
            lambda t: t.to(device='cpu', non_blocking=True),
            logging_outputs[-1],
        )

        logging_output["sample_size"] = sample_size
        for key_to_delete in ("ppl", "wps", "wpb", "bsz"):
            logging_output.pop(key_to_delete, None)
        return logging_output
Example #3
0
def translate_batch(model, sids, sentences):
    """Translate ``sentences`` and key each translation by its sentence id.

    Args:
        model: hub-style model exposing ``encode``/``decode``, ``task``,
            ``args``, ``models`` and ``device`` (as used below).
        sids: sentence ids indexed by each sentence's position in
            ``sentences`` -- a list, or a dict keyed by int position.
        sentences: raw source sentences.

    Returns:
        OrderedDict mapping ``sids[i]`` to the decoded top hypothesis of
        ``sentences[i]``. Entries follow the order of ``samples['id']``;
        presumably the collater may reorder rows, so this need not be the
        input order -- verify if order matters to callers.
    """
    encoded = [model.encode(sentence) for sentence in sentences]
    lengths = [len(tokens) for tokens in encoded]
    dataset = model.task.build_dataset_for_inference(encoded, lengths)
    samples = dataset.collater(dataset)
    samples = utils.apply_to_sample(
        lambda tensor: tensor.to(model.device),
        samples
    )
    # Use plain ints: a 0-d tensor hashes by identity, so looking it up in a
    # dict ``sids`` could never match an int key.
    ids = samples['id'].tolist()

    generator = model.task.build_generator(model.args)

    translations = model.task.inference_step(generator, model.models, samples)
    translated = [model.decode(translation[0]['tokens']) for translation in translations]
    return OrderedDict((sids[idx], text) for idx, text in zip(ids, translated))
Example #4
0
def build_sample(
    model,
    src_tokens: List[torch.LongTensor],
    tgt_tokens: List[torch.LongTensor],
):
    """Collate source/target token tensors into one batch on ``model.device``.

    A ``LanguagePairDataset`` is built from the token lists (sizes taken from
    each tensor's ``numel()``) with the task's source and target
    dictionaries, collated in full, and every tensor in the result is moved
    to ``model.device``.
    """
    src_sizes = [tokens.numel() for tokens in src_tokens]
    src_dict = model.task.source_dictionary
    tgt_sizes = [tokens.numel() for tokens in tgt_tokens]
    tgt_dict = model.task.target_dictionary

    pair_dataset = LanguagePairDataset(
        src_tokens,
        src_sizes,
        src_dict,
        tgt=tgt_tokens,
        tgt_sizes=tgt_sizes,
        tgt_dict=tgt_dict,
    )
    batch = pair_dataset.collater(pair_dataset)
    return utils.apply_to_sample(lambda tensor: tensor.to(model.device), batch)
Example #5
0
    def start(self, start_with_nothing):
        """Create the root LM state by running the model on a lone EOS token.

        ``start_with_nothing`` is accepted but not used. The model output's
        last-position log-probabilities and, when ``self.save_incremental``
        is set, the incremental state (moved to CPU) are stored in
        ``self.states`` under the new state, which is also appended to
        ``self.stateq``.

        Returns:
            LMState: the newly created root state.
        """
        root = LMState()
        eos_prefix = torch.LongTensor([[self.dictionary.eos()]])
        cache = None
        if self.save_incremental:
            cache = {}

        with torch.no_grad():
            output = self.model(eos_prefix.cuda(), incremental_state=cache)
            log_probs = self.model.get_normalized_probs(
                output, log_probs=True, sample=None
            )

        if cache is not None:
            # Move the filled incremental state from the GPU to the CPU.
            cache = apply_to_sample(lambda x: x.cpu(), cache)

        self.states[root] = FairseqLMState(
            eos_prefix.numpy(), cache, log_probs[0, -1].cpu().numpy()
        )
        self.stateq.append(root)
        return root
Example #6
0
    def _prepare_sample(self, sample, dummy=False):
        if sample is None or len(sample) == 0:
            return None

        if self.args.task == 'doc_translation' and not dummy:
            sample = self._prepare_sample_with_context(sample)

        if self.cuda:
            sample = utils.move_to_cuda(sample)

        def apply_half(t):
            if t.dtype is torch.float32:
                return t.half()
            return t

        if self.args.fp16:
            sample = utils.apply_to_sample(apply_half, sample)

        return sample
def prepare_sample(args, task, sample, use_cuda=True):
    """Ready a collated batch for evaluation or inference.

    Args:
        args: options namespace; ``args.task`` and ``args.fp16`` are read.
        task: passed to ``prepare_sample_with_context`` for the
            ``doc_translation`` task.
        sample: the collated batch.
        use_cuda: when true, the batch is moved to CUDA via
            ``utils.move_to_cuda``.

    Returns:
        The prepared batch, or ``None`` when ``sample`` is ``None`` or empty.
    """
    if sample is None or not len(sample):
        return None

    if args.task == 'doc_translation':
        sample = prepare_sample_with_context(task, sample)

    if use_cuda:
        sample = utils.move_to_cuda(sample)

    if not args.fp16:
        return sample

    def to_half(tensor):
        # Downcast float32 only; any other dtype is returned unchanged.
        if tensor.dtype is not torch.float32:
            return tensor
        return tensor.half()

    return utils.apply_to_sample(to_half, sample)
Example #8
0
 def _build_sample(self,
                   src_tokens: List[torch.LongTensor],
                   src_sent_ids=None,
                   chains_dataset=None):
     """Collate ``src_tokens`` into one inference batch on ``self.device``.

     When ``src_sent_ids`` is given, it and ``chains_dataset`` are forwarded
     to ``build_dataset_for_inference``, with ``explicit_str_att`` set to
     whether a ``chains_dataset`` was supplied.
     """
     if src_sent_ids is None:
         dataset = self.task.build_dataset_for_inference(
             src_tokens,
             [tokens.numel() for tokens in src_tokens],
         )
     else:
         has_chains = chains_dataset is not None
         dataset = self.task.build_dataset_for_inference(
             src_tokens,
             [tokens.numel() for tokens in src_tokens],
             src_sent_ids=src_sent_ids,
             chains_dataset=chains_dataset,
             explicit_str_att=has_chains,
         )
     batch = dataset.collater(dataset)
     return utils.apply_to_sample(lambda tensor: tensor.to(self.device), batch)
Example #9
0
 def _build_sample(self, src_tokens: List[torch.LongTensor], src_tokens2=None):
     """Collate ``src_tokens`` (and optionally ``src_tokens2``) into a batch.

     Args:
         src_tokens: list of token tensors; their ``numel()`` values are
             passed as the lengths.
         src_tokens2: optional second list of token tensors. When given, it
             and its lengths are passed to ``build_dataset_for_inference``
             after ``src_tokens`` and its lengths.

     Returns:
         The collated batch with every tensor moved to ``self.device``.
     """
     # ``is None`` instead of ``== None`` (PEP 8): ``==`` can be overloaded,
     # e.g. numpy arrays compare element-wise, which makes ``if`` ambiguous.
     if src_tokens2 is None:
         dataset = self.task.build_dataset_for_inference(
             src_tokens,
             [x.numel() for x in src_tokens],
         )
     else:
         dataset = self.task.build_dataset_for_inference(
             src_tokens,
             [x.numel() for x in src_tokens],
             src_tokens2,
             [x.numel() for x in src_tokens2],
         )
     sample = dataset.collater(dataset)
     sample = utils.apply_to_sample(
         lambda tensor: tensor.to(self.device),
         sample
     )
     return sample
Example #10
0
    def _prepare_sample(self, sample):
        if sample == "DUMMY":
            raise Exception(
                "Trying to use an uninitialized 'dummy' batch. This usually indicates "
                "that the total number of batches is smaller than the number of "
                "participating GPUs. Try reducing the batch size or using fewer GPUs."
            )

        if sample is None or len(sample) == 0:
            return None

        if self.cuda:
            sample = utils.move_to_cuda(sample)

        def apply_half(t):
            if t.dtype is torch.float32:
                return t.half()
            return t

        if self.args.fp16:
            sample = utils.apply_to_sample(apply_half, sample)

        return sample
    def batch_augments(self, sentences, batch_size=30, progress_bar=True):
        """Round-trip translate ``sentences`` with ``from_model`` then ``to_model``.

        Sentences are processed in chunks of ``batch_size``. If a chunk raises
        ``RuntimeError`` (typically CUDA out-of-memory), the CUDA cache is
        emptied, that chunk is retried with half the batch size, and the
        remaining sentences are processed with the original batch size.

        Args:
            sentences: sliceable sequence of source sentences.
            batch_size: number of sentences translated per step.
            progress_bar: wrap the chunk loop in ``tqdm`` when true.

        Returns:
            list: back-translated sentences in the same order as ``sentences``.

        Raises:
            RuntimeError: re-raised when a chunk still fails at
                ``batch_size`` 1, instead of recursing with a batch size of 0
                (which would surface as an unrelated ``ZeroDivisionError``).
        """
        self.from_model.eval()
        self.to_model.eval()

        result = []
        oom = False
        batch_ind = 0

        batch_range = range(len(sentences) // batch_size + 1)
        iterator = tqdm(batch_range) if progress_bar else batch_range

        try:
            for batch_ind in iterator:
                chunk = sentences[batch_ind * batch_size:(batch_ind + 1) * batch_size]
                # The final range step yields an empty chunk whenever
                # len(sentences) is a multiple of batch_size.
                if len(chunk) == 0:
                    continue
                translations = self._translate_in_order(
                    self.from_model, chunk, self.from_num_beam)
                back_translations = self._translate_in_order(
                    self.to_model, translations, self.to_num_beam)
                result.extend(back_translations)

        except RuntimeError:
            # A single sentence cannot be split any further; halving would
            # give batch_size 0 and hide this error behind ZeroDivisionError.
            if batch_size <= 1:
                raise
            torch.cuda.empty_cache()
            gc.collect()
            oom = True

        if oom:
            failed_chunk = sentences[batch_ind * batch_size:(batch_ind + 1) * batch_size]
            result.extend(
                self.batch_augments(
                    failed_chunk,
                    batch_size=batch_size // 2,
                    progress_bar=False))

            result.extend(
                self.batch_augments(sentences[(batch_ind + 1) * batch_size:],
                                    batch_size=batch_size))

        return result

    @staticmethod
    def _translate_in_order(model, texts, num_beam):
        """Beam-search ``texts`` with ``model``; return decoded top hypotheses in input order."""
        encoded = [model.encode(text) for text in texts]
        dataset = model.task.build_dataset_for_inference(
            encoded, [tokens.numel() for tokens in encoded])
        sample = dataset.collater(dataset)
        sample = utils.apply_to_sample(
            lambda tensor: tensor.to(model.device), sample)
        gen_args = copy.copy(model.args)
        gen_args.beam = num_beam
        generator = model.task.build_generator(model.models, args=gen_args)
        hypos = model.task.inference_step(generator, model.models, sample)
        decoded = [model.decode(hypo[0]['tokens']) for hypo in hypos]
        # sample['id'][row] is the input position of batch row ``row``; invert
        # that mapping once (instead of a list .index() per output) so that
        # output i corresponds to texts[i].
        row_of = {input_idx: row for row, input_idx in enumerate(sample['id'].tolist())}
        return [decoded[row_of[i]] for i in range(len(decoded))]
Example #12
0
    def score(self, state: LMState, token_index: int, no_cache: bool = False):
        """
        Score ``token_index`` as the next word after ``state`` and return the child state.

        If the log-probabilities for ``state`` are not cached yet, the model is
        run on the state's prefix (reusing its stored incremental state when
        there is one) and, unless ``no_cache`` is set, the result is stored in
        ``self.states`` and queued in ``self.stateq``. The cache is then
        trimmed to ``self.max_cache`` queued entries; evicted entries keep
        only their prefix. Unless ``no_cache`` is set, an unscored child state
        for ``token_index`` is registered if it does not exist yet.

        Parameters:
        -----------
        state: current lm state (key into ``self.states``)
        token_index: index of the word
                     (can be lexicon index then you should store inside LM the
                      mapping between indices of lexicon and lm, or lm index of a word)
        no_cache: when True, newly computed probabilities and the child state
                  are not stored in the cache

        Returns:
        --------
        (LMState, float): pair of (new state, score for the current word).
        The score is the model's log-probability of ``token_index``, or
        ``-inf`` when ``token_index == self.unk``.
        """
        curr_state = self.states[state]

        # Evict queued states oldest-first (popleft) until at most targ_size
        # remain; an evicted entry keeps its prefix but drops its incremental
        # state and probabilities, so it is recomputed on its next use.
        def trim_cache(targ_size):
            while len(self.stateq) > targ_size:
                rem_k = self.stateq.popleft()
                rem_st = self.states[rem_k]
                rem_st = FairseqLMState(rem_st.prefix, None, None)
                self.states[rem_k] = rem_st

        # Probabilities not available for this state yet: run the model.
        if curr_state.probs is None:
            # Shallow-copy the stored incremental state, if any.
            new_incremental_state = (
                curr_state.incremental_state.copy()
                if curr_state.incremental_state is not None
                else None
            )
            with torch.no_grad():
                if new_incremental_state is not None:
                    # Stored incremental state lives on CPU; move it to GPU.
                    new_incremental_state = apply_to_sample(
                        lambda x: x.cuda(), new_incremental_state
                    )
                elif self.save_incremental:
                    # No stored state but incremental decoding is enabled:
                    # start from an empty one.
                    new_incremental_state = {}

                res = self.model(
                    torch.from_numpy(curr_state.prefix).cuda(),
                    incremental_state=new_incremental_state,
                )
                probs = self.model.get_normalized_probs(
                    res, log_probs=True, sample=None
                )

                if new_incremental_state is not None:
                    # Move the updated incremental state back to CPU for caching.
                    new_incremental_state = apply_to_sample(
                        lambda x: x.cpu(), new_incremental_state
                    )

                # Keep only the log-probabilities of the last position, as numpy.
                curr_state = FairseqLMState(
                    curr_state.prefix, new_incremental_state, probs[0, -1].cpu().numpy()
                )

            # With no_cache the freshly computed state is used for this call
            # only and is not stored.
            if not no_cache:
                self.states[state] = curr_state
                self.stateq.append(state)

        score = curr_state.probs[token_index].item()

        trim_cache(self.max_cache)

        # Register the child state (prefix + token_index) without
        # probabilities; it inherits the parent's incremental state and is
        # scored lazily when first used as ``state``.
        outstate = state.child(token_index)
        if outstate not in self.states and not no_cache:
            # NOTE(review): mixes a numpy prefix with a torch tensor here;
            # np.concatenate relies on the tensor converting to an array.
            prefix = np.concatenate(
                [curr_state.prefix, torch.LongTensor([[token_index]])], -1
            )
            incr_state = curr_state.incremental_state

            self.states[outstate] = FairseqLMState(prefix, incr_state, None)

        # The unknown token is never allowed, whatever the model's score.
        if token_index == self.unk:
            score = float("-inf")

        return outstate, score
Example #13
0
    def generate(
        self,
        tokenized_sentences: List[torch.LongTensor],
        beam: int = 5,
        verbose: bool = False,
        skip_invalid_size_inputs=False,
        inference_step_args=None,
        **kwargs
    ) -> List[List[Dict[str, torch.Tensor]]]:
        """Run generation over tokenized sentences and return their hypotheses.

        Args:
            tokenized_sentences: list of 1-D token tensors, or a single 1-D
                tensor, in which case the hypotheses for that one sentence
                are returned directly instead of wrapped in a list.
            beam: beam size set on a copy of ``self.cfg.generation``.
            verbose: log source, hypotheses and positional scores, and the
                alignment when ``print_alignment`` is enabled.
            skip_invalid_size_inputs: forwarded to ``self._build_batches``.
            inference_step_args: extra keyword arguments for
                ``self.task.inference_step``.
            **kwargs: further overrides set on the copied generation config.

        Returns:
            One list of hypothesis dicts per generated sentence, sorted by the
            batch ``id`` so that they follow the input order.
        """
        if torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1:
            # Forward every option so that a single sentence is generated
            # exactly like a batch of one (these two used to be dropped).
            return self.generate(
                tokenized_sentences.unsqueeze(0),
                beam=beam,
                verbose=verbose,
                skip_invalid_size_inputs=skip_invalid_size_inputs,
                inference_step_args=inference_step_args,
                **kwargs
            )[0]

        # build generator using current args as well as any kwargs
        gen_args = copy.deepcopy(self.cfg.generation)
        with open_dict(gen_args):
            gen_args.beam = beam
            for k, v in kwargs.items():
                setattr(gen_args, k, v)
        generator = self.task.build_generator(self.models, gen_args)

        inference_step_args = inference_step_args or {}
        results = []
        for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
            batch = utils.apply_to_sample(lambda t: t.to(self.device), batch)
            translations = self.task.inference_step(
                generator, self.models, batch, **inference_step_args
            )
            results.extend(zip(batch["id"].tolist(), translations))

        # sort output to match input order
        outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]

        if verbose:

            def getarg(name, default):
                return getattr(gen_args, name, getattr(self.cfg, name, default))

            for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs):
                src_str_with_unk = self.src_dict.string(source_tokens)
                logger.info("S\t{}".format(src_str_with_unk))
                for hypo in target_hypotheses:
                    hypo_str = self.decode(hypo["tokens"])
                    logger.info("H\t{}\t{}".format(hypo["score"], hypo_str))
                    positional = " ".join(
                        "{:.4f}".format(x) for x in hypo["positional_scores"].tolist()
                    )
                    logger.info("P\t{}".format(positional))
                    if hypo["alignment"] is not None and getarg("print_alignment", False):
                        alignment = " ".join(
                            "{}-{}".format(src_idx, tgt_idx)
                            for src_idx, tgt_idx in hypo["alignment"]
                        )
                        logger.info("A\t{}".format(alignment))
        return outputs
Example #14
0
def variable_beam_stream_fast(sg,
                              model,
                              tokenized_sentences,
                              k=5,
                              max_length=100,
                              rp=0.6,
                              ap=2.5,
                              rpl=0.02,
                              mc=3,
                              find_top_z=1,
                              max_indices=32,
                              encode_batch_size=64,
                              max_si_tokens=7168,
                              bos_token=None,
                              len_penalty=1,
                              one_batch=False):
    ensemble_size = len(model.models)

    BOS_ID = sg.eos if bos_token is None else bos_token
    EOS_ID = sg.eos

    if one_batch:
        full_data_size = tokenized_sentences['net_input']['src_tokens'].shape[
            0]
    else:
        full_data_size = len(tokenized_sentences)
        batch_iterator = model._build_batches(tokenized_sentences,
                                              False)  # not streaming
    master_done_beams = [[] for _ in range(full_data_size)]
    master_batch_ids = [None for _ in range(full_data_size)]

    parent_model = model
    model = model.models

    master_decoded_indices = torch.zeros(1, 0, k).long().to(
        parent_model.device)  # seq, batch, k
    master_log_probs = torch.zeros(0, k).to(parent_model.device)  # batch x k
    master_enc_out = []
    master_state = IncrementalState(
        0, k, ensemble_size, parent_model.device)  # init incremental state

    master_valid_beam_mask = torch.zeros(0, k).to(
        parent_model.device)  # batch x k
    master_num_valid_beams = torch.zeros(0).long().to(
        parent_model.device)  # batch
    master_index = torch.zeros(0).long().to(parent_model.device)  # batch
    master_src_lengths = torch.zeros(0).long().to(parent_model.device)
    master_progress = torch.zeros(0).long().to(parent_model.device)  # batch
    master_end_found = torch.zeros(0, k).long().to(
        parent_model.device)  # batch x k
    master_done_lengths = torch.zeros(0).long().to(
        parent_model.device)  # batch
    master_best_finished_log_probs = torch.zeros(0).to(
        parent_model.device) - 1e8  # batch

    current_idx = 0
    has_more_batches = True
    decode_calls = 0
    n_expansions = 0
    master_remove_indices = torch.zeros(0).long().to(parent_model.device)
    num_pad = 0
    reselect = True
    while True:
        while has_more_batches and master_src_lengths.sum(
        ) <= max_si_tokens - parent_model.args.max_tokens:  # token-based limit
            assert reselect
            if one_batch:  # not streaming
                batch = tokenized_sentences
                has_more_batches = False
            else:
                try:
                    batch = next(batch_iterator)
                except StopIteration:
                    has_more_batches = False
                    break
            batch = utils.apply_to_sample(lambda t: t.to(parent_model.device),
                                          batch)
            for i, id in enumerate(batch['id'].tolist()):
                master_batch_ids[current_idx + i] = id
            net_input = batch["net_input"]
            src_tokens = net_input["src_tokens"]
            num_new_sources = len(src_tokens)

            # encode add the next batch of source infos; update the index
            encoder_outs = sg.model.forward_encoder(net_input)
            # concatenate to the current master tensors
            # decoded_indices; note these are left padded
            current_seqlen = master_decoded_indices.size(0)
            master_decoded_indices = torch.cat([
                master_decoded_indices,
                pad_to_length(torch.zeros(1, num_new_sources, k) + BOS_ID,
                              current_seqlen,
                              0,
                              side='left',
                              value=0).long().to(parent_model.device)
            ],
                                               dim=1)
            # log_probs
            master_log_probs = torch.cat([
                master_log_probs,
                torch.cat([
                    torch.zeros(num_new_sources, 1),
                    torch.zeros(num_new_sources, k - 1) - 1e8
                ],
                          dim=1).to(parent_model.device)
            ],
                                         dim=0)

            if len(master_enc_out) == 0:
                assert current_idx == 0
                master_enc_out = encoder_outs
            else:
                assert len(master_enc_out) == len(encoder_outs)
                for i in range(len(master_enc_out)):
                    meo, eo = master_enc_out[i], encoder_outs[i]
                    max_seq = max(meo.encoder_out.shape[0],
                                  eo.encoder_out.shape[0])
                    new_eo = EncoderOut(encoder_out=torch.cat([
                        pad_to_length(
                            meo.encoder_out, max_seq, 0, side='left', value=0),
                        pad_to_length(
                            eo.encoder_out, max_seq, 0, side='left', value=0)
                    ],
                                                              dim=1),
                                        encoder_padding_mask=torch.cat([
                                            pad_to_length(
                                                meo.encoder_padding_mask,
                                                max_seq,
                                                1,
                                                side='left',
                                                value=True),
                                            pad_to_length(
                                                eo.encoder_padding_mask,
                                                max_seq,
                                                1,
                                                side='left',
                                                value=True)
                                        ],
                                                                       dim=0),
                                        encoder_embedding=torch.cat([
                                            pad_to_length(
                                                meo.encoder_embedding,
                                                max_seq,
                                                1,
                                                side='left',
                                                value=0),
                                            pad_to_length(eo.encoder_embedding,
                                                          max_seq,
                                                          1,
                                                          side='left',
                                                          value=0)
                                        ],
                                                                    dim=0),
                                        encoder_states=None,
                                        src_tokens=None,
                                        src_lengths=None)
                    master_enc_out[i] = new_eo
            if not one_batch:
                # get the encoder attention keys
                sg.model.incremental_states = [{}
                                               for _ in range(ensemble_size)]
                sg.model.forward_decoder(
                    (torch.zeros(num_new_sources) + BOS_ID).long().to(
                        parent_model.device).unsqueeze(1), encoder_outs,
                    sg.temperature)
                dummy_state = sg.model.incremental_states
                master_state.append_new_incremental_state(
                    num_new_sources, dummy_state,
                    torch.arange(num_new_sources).long().to(
                        parent_model.device) + current_idx)

            master_valid_beam_mask = torch.cat([
                master_valid_beam_mask,
                torch.cat([
                    torch.ones(num_new_sources, 1),
                    torch.zeros(num_new_sources, k - 1)
                ],
                          dim=1).to(parent_model.device)
            ],
                                               dim=0)
            # print(net_input['src_lengths'].max())
            master_src_lengths = torch.cat(
                [master_src_lengths, net_input['src_lengths']], dim=0)
            # num_valid_beams
            master_num_valid_beams = torch.cat([
                master_num_valid_beams,
                torch.ones(num_new_sources).long().to(parent_model.device)
            ],
                                               dim=0)
            # index
            master_index = torch.cat([
                master_index, current_idx +
                torch.arange(num_new_sources).to(parent_model.device)
            ],
                                     dim=0)
            # progress
            master_progress = torch.cat([
                master_progress,
                torch.zeros(num_new_sources).long().to(parent_model.device)
            ],
                                        dim=0)
            # end_found
            master_end_found = torch.cat([
                master_end_found,
                torch.zeros(num_new_sources, k).long().to(parent_model.device)
            ],
                                         dim=0)
            # done lengths
            master_done_lengths = torch.cat([
                master_done_lengths,
                torch.zeros(num_new_sources).long().to(parent_model.device)
            ],
                                            dim=0)
            # best done log probs
            master_best_finished_log_probs = torch.cat([
                master_best_finished_log_probs,
                torch.zeros(num_new_sources).to(parent_model.device) - 1e8
            ],
                                                       dim=0)

            current_idx += num_new_sources
            # break # for debugging

        # break if none left
        if not has_more_batches and len(master_index) == 0:
            break

        # based on max_bs and source_info, select which indices to use (sort source_info), then create:
        selected_indices, unselected_indices, prog_min = select_source_indices(
            master_num_valid_beams,
            master_progress,
            master_index,
            max_indices,
            reverse=False,
            sort=False)
        if one_batch:
            assert len(unselected_indices) == 0  # for debugging
        selected_master_indices = master_index[selected_indices]
        batch_size = len(selected_indices)
        selected_enc_out = sg.model.reorder_encoder_out(
            master_enc_out,
            selected_indices.unsqueeze(1).expand(-1, k).flatten())
        # if decode_calls % 50 == 0:
        #     print(decode_calls)

        valid_beam_mask = master_valid_beam_mask[selected_indices]
        valid_beam_indices = valid_beam_mask.flatten().nonzero().flatten(
        )  # idk why need to flatten again
        reverse_idx = (torch.cumsum(valid_beam_mask.flatten(
        ), dim=0) * valid_beam_mask.flatten()).long(
        ) - 1  # it's fine to select whatever position for padding as they'll be removed later
        if num_pad > 0:
            if num_pad >= len(
                    master_decoded_indices
            ):  # edge case: we previously ran out of beams, and we are starting fresh now
                assert num_pad == len(master_decoded_indices)
                num_pad -= 1
            master_decoded_indices = master_decoded_indices[num_pad:]
            master_state.clean_padding(num_pad)

        if reselect:
            selected_state_master_indices, selected_state = master_state.select_incremental_state(
                selected_master_indices, master_remove_indices, prog_min)
            master_state.num_sources -= len(master_remove_indices)
        sg.model.incremental_states = selected_state
        log_probs = master_log_probs[selected_indices]
        progress = master_progress[selected_indices]
        decoded_indices = master_decoded_indices[-progress.max() - 1:,
                                                 selected_indices, :]
        end_found = master_end_found[selected_indices]
        done_lengths = master_done_lengths[selected_indices]
        best_finished_log_probs = master_best_finished_log_probs[
            selected_indices]

        # flattened_indices = last_indices.flatten().unsqueeze(0) # 1 x batch*k
        # create valid beam indices from valid beam mask
        if one_batch and decode_calls == 0:
            selected_state_master_indices = master_index.clone()
        assert len(selected_state_master_indices) == len(valid_beam_indices)
        decode_calls += 1
        n_expansions += len(valid_beam_indices)

        # use valid_beam_mask to select valid indices out of decoded_indices, encoder_outs, model incremental state
        decoding_selected_indices = decoded_indices.flatten(
            1)[:, valid_beam_indices]  # seq x selected
        selected_enc_out = sg.model.reorder_encoder_out(
            selected_enc_out, valid_beam_indices)

        assert torch.all(
            decoding_selected_indices.flatten(1).permute(1, 0)[:, 0] == 2)
        next_log_probs, _ = sg.model.forward_decoder(
            decoding_selected_indices.flatten(1).permute(
                1, 0)[:, :master_progress.max() + 1], selected_enc_out,
            sg.temperature)

        # remake next_scores, state with dummies
        next_log_probs = next_log_probs[reverse_idx].view(1, batch_size, k, -1)
        # reorder incremental model state
        reorder_idx = reverse_idx

        next_log_probs = next_log_probs.view(1, batch_size, k, -1)

        # for edge case where EOS_ID appears later down in the beam but still needs to be dealt with correctly on the next step!
        end_found = end_found.unsqueeze(0).unsqueeze(
            3
        )  # batch_size x k x 1 of whether end index is in tgt_idx already; if so, make prob of padding 1
        end_found = (
            end_found +
            (progress + 1 == max_length).long().view(1, -1, 1, 1)).clamp(max=1)
        end_found_scores = torch.zeros_like(next_log_probs).to(
            parent_model.device) - 1e8
        end_found_scores[:, :, :,
                         EOS_ID] = 0  # make it so you only pick eos for the sequences that are already done, and don't duplicate them, by making other probs -inf
        next_log_probs = end_found * end_found_scores + (
            1 - end_found) * next_log_probs  # ~ is for inverting the mask

        next_log_probs = next_log_probs - 1e8 * (
            1 - valid_beam_mask.unsqueeze(0).unsqueeze(3)
        )  # get rid of padding positions
        next_log_probs = next_log_probs + log_probs.unsqueeze(0).unsqueeze(
            3)  # 1, batch, k, vocab
        mc_probs, mc_indices = next_log_probs.topk(mc,
                                                   dim=3)  # 1, batch, k, mc
        top_log_probs, top_indices = mc_probs.flatten(2).topk(
            k, dim=2)  # 1, batch, k
        mc_vocab_indices = top_indices % mc
        beam_indices = top_indices // mc  # 1, batch, k
        vocab_indices = torch.gather(
            mc_indices.flatten(2).flatten(0, 1),
            1, (mc_vocab_indices + beam_indices * mc).flatten(0, 1)).unsqueeze(
                0)  # 1, batch, k
        # check which vocab_indices are done (in the first beam position), and add the corresponding beam to an array of done predictions
        newly_done_all = (vocab_indices == EOS_ID).long()  # 1, batch, k
        newly_done = torch.cumprod(
            newly_done_all, dim=2
        )  # keep on beam if there's something above it that's not done yet
        done_lengths += newly_done.sum(dim=2).flatten(
        )  # update this one before others since we'll need it earlier
        newly_done_indices = newly_done.flatten().nonzero()  # batch*k
        for j in newly_done_indices:
            source_idx = j // k
            # add to some master list with an entry for each source
            if len(master_done_beams[
                    selected_master_indices[source_idx]]) < find_top_z:
                finished_cand = decoded_indices[:, source_idx,
                                                beam_indices[0, source_idx,
                                                             j % k]].flatten()
                finished_cand_length = progress[source_idx] + 1
                while len(finished_cand) > 0 and finished_cand[-1] == EOS_ID:
                    finished_cand = finished_cand[:-1]
                    finished_cand_length -= 1
                if len(finished_cand) > 0:  # avoid length 0
                    master_done_beams[selected_master_indices[source_idx]].append( \
                            {'tokens': finished_cand.cpu(), 'score': (top_log_probs.flatten()[j] / ((finished_cand_length)**len_penalty)).item() })
                    best_finished_log_probs[source_idx] = max(
                        best_finished_log_probs[source_idx],
                        top_log_probs.flatten()[j])
                else:  # rarely with greedy search (beam size k = 1) you get stuff with length 0... so avoid crashing but give it low score
                    master_done_beams[selected_master_indices[source_idx]].append( \
                            {'tokens': finished_cand.cpu(), 'score': -1e8 })

        # then, shift log_probs and beam_indices for those beams and delete that beam(s); put in placeholder beam and log_prob at the k^th position
        # need to shift top_log_probs, beam_indices, vocab_indices accordingly
        top_log_probs = torch.cat([
            top_log_probs,
            torch.zeros_like(top_log_probs).to(parent_model.device) - 1e8
        ],
                                  dim=2)  # 1, batch, 2k
        shift_indices = newly_done.sum(
            dim=2).unsqueeze(2) + torch.arange(k).to(
                parent_model.device).unsqueeze(0).unsqueeze(1)  # 1, batch, k
        top_log_probs = torch.gather(top_log_probs, 2, shift_indices)
        shift_indices = shift_indices.clamp(max=k - 1)
        beam_indices = torch.gather(beam_indices, 2, shift_indices)
        vocab_indices = torch.gather(vocab_indices, 2, shift_indices)
        newly_done_all = torch.gather(newly_done_all, 2, shift_indices)

        log_probs = top_log_probs.squeeze(0)
        state_indices = (beam_indices + k * torch.arange(batch_size).to(
            parent_model.device).unsqueeze(1).repeat(1, k)).flatten()
        reorder_idx = reorder_idx[state_indices]

        # update valid beam mask
        ap_thresholds = (torch.max(log_probs[:, 0], best_finished_log_probs) -
                         ap).unsqueeze(1)  # batch x 1
        valid_beam_mask = (log_probs > ap_thresholds).float()  # batch x k
        # update valid beam mask based on how many beams are left for each source
        done_mask = pad_mask(
            k - done_lengths, parent_model.device, max_seqlen=k).permute(
                1, 0)  # batch x k of beams to keep, up to k - num done already
        all_low_prob_mask = 1 - valid_beam_mask.max(
            dim=1
        )[0]  # NOTE since we filter out by the absolute threshold including previously finished beams, we could get < k finished candidates, but always at least 1
        found_z_mask = (all_low_prob_mask.bool() |
                        (done_lengths >= find_top_z)).unsqueeze(1)
        valid_beam_mask = valid_beam_mask * done_mask * (1 -
                                                         found_z_mask.long())
        # filter the done ones out of all the master tensors
        keep_indices = (~found_z_mask).flatten().nonzero().flatten().long()
        remove_indices = (found_z_mask).flatten().nonzero().flatten().long()
        keep_indices = torch.cat(
            [selected_indices[keep_indices], unselected_indices], dim=0)
        master_remove_indices = master_index[selected_indices[remove_indices]]

        # update these quantities in their respective source_info objects after computing them
        # just deleting/concatenating to a single master tensor
        # master_decoded_indices seq x batch x k
        new_master_indices = torch.zeros(
            1, master_decoded_indices.size(1),
            k).long().to(parent_model.device)  # 1 x batch x k
        new_master_indices[:, selected_indices] = vocab_indices
        master_decoded_indices[:, selected_indices] = torch.gather(
            master_decoded_indices[:, selected_indices], 2,
            beam_indices.expand(
                master_decoded_indices[:, selected_indices].shape))
        master_decoded_indices = torch.cat(
            [master_decoded_indices, new_master_indices], dim=0)
        if prog_min + 2 >= master_decoded_indices.shape[0]:
            master_decoded_indices = torch.cat([
                torch.zeros(1, master_decoded_indices.size(1), k).long().to(
                    parent_model.device), master_decoded_indices
            ],
                                               dim=0)
        master_decoded_indices[:, selected_indices] = torch.roll(
            master_decoded_indices[:, selected_indices], -1, 0)
        master_decoded_indices = master_decoded_indices[:-1]
        # master_log_probs batch x k
        master_log_probs[selected_indices] = log_probs
        # master_valid_beam_mask batch x k
        master_valid_beam_mask[selected_indices] = valid_beam_mask
        # master_num_valid_beams batch
        master_num_valid_beams = master_valid_beam_mask.sum(dim=1).long()
        # master_progress batch
        master_progress[selected_indices] += 1
        # master_end_found batch x k
        master_end_found[selected_indices] = (
            torch.gather(end_found.squeeze(3), 2, beam_indices)
            | newly_done_all[0, :, :]).squeeze(0)
        # master_done_lengths batch
        master_done_lengths[selected_indices] = done_lengths
        # master_best_finished_log_probs batch
        master_best_finished_log_probs[
            selected_indices] = best_finished_log_probs
        # update master versions of sg.model state
        reorder_idx = reorder_idx[
            valid_beam_mask.flatten().nonzero().flatten()]
        selected_state_master_indices = selected_state_master_indices[
            reorder_idx]
        reorder_incremental_state(sg.model, reorder_idx)

        master_src_lengths = master_src_lengths[keep_indices]

        if master_src_lengths.sum(
        ) <= max_si_tokens - parent_model.args.max_tokens:
            reselect = True
        elif len(progress) < (master_progress == prog_min + 1).sum():
            reselect = True
        else:
            reselect = False
        if reselect:
            # if not one_batch:
            #     print('reselect', decode_calls)
            master_state.recache(selected_state_master_indices,
                                 sg.model.incremental_states)

        master_decoded_indices = master_decoded_indices[:, keep_indices, :]
        master_log_probs = master_log_probs[keep_indices]
        master_enc_out = sg.model.reorder_encoder_out(master_enc_out,
                                                      keep_indices)
        master_valid_beam_mask = master_valid_beam_mask[keep_indices]
        master_num_valid_beams = master_num_valid_beams[keep_indices]
        master_index = master_index[keep_indices]
        master_progress = master_progress[keep_indices]
        master_end_found = master_end_found[keep_indices]
        master_done_lengths = master_done_lengths[keep_indices]
        master_best_finished_log_probs = master_best_finished_log_probs[
            keep_indices]

        # delete any unnecessary padding so we don't keep increasing padding
        num_pad = (master_decoded_indices.sum(dim=1).sum(dim=1) == 0).sum(
            dim=0)
        if not reselect:
            assert num_pad == 0

    assert all([bid is not None for bid in master_batch_ids])
    for i in range(len(master_done_beams)):
        master_done_beams[i] = sorted(master_done_beams[i],
                                      key=lambda x: x['score'],
                                      reverse=True)
    if one_batch:
        return master_done_beams, decode_calls, n_expansions
    else:
        return master_batch_ids, master_done_beams, decode_calls, n_expansions
Example #15
0
def main(args, task=None, model_state=None):
    """Run inference over ``args.gen_subset``, or dump emissions/features.

    Depending on ``args`` this function either:

    * saves per-utterance normalized log-prob emissions (``args.dump_emissions``),
    * saves per-utterance encoder outputs and padding masks (``args.dump_features``),
    * or decodes with a flashlight (w2l) decoder and accumulates WER.

    Args:
        args: parsed command-line namespace.
        task: currently ignored; it is always replaced by
            ``tasks.setup_task(args)`` (and then by the task returned from
            checkpoint loading when models are loaded).
        model_state: optional checkpoint state forwarded to
            ``checkpoint_utils.load_model_ensemble_and_task``.

    Returns:
        ``(task, wer)``. ``wer`` is ``errs * 100 / length`` over all scored
        targets, or ``None`` when dumping or when no target length was scored.
    """
    check_args(args)

    use_fp16 = args.fp16
    if args.max_tokens is None and args.batch_size is None:
        args.max_tokens = 4000000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    logger.info("| decoding with criterion {}".format(args.criterion))

    # NOTE(review): this unconditionally replaces the ``task`` argument; kept
    # as-is to preserve existing behaviour.
    task = tasks.setup_task(args)

    # Load ensemble
    if args.load_emissions:
        # Pre-computed emissions are decoded directly, so no model is loaded.
        models = []
        task.load_dataset(args.gen_subset)
    else:
        logger.info("| loading model(s) from {}".format(args.path))
        models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
            utils.split_paths(args.path, separator="\\"),
            arg_overrides=ast.literal_eval(args.model_overrides),
            task=task,
            suffix=args.checkpoint_suffix,
            strict=(args.checkpoint_shard_count == 1),
            num_shards=args.checkpoint_shard_count,
            state=model_state,
        )
        optimize_models(args, use_cuda, models)
        task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task)

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info("| {} {} {} examples".format(
        args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # hack to pass transitions to W2lDecoder
    if args.criterion == "asg_loss":
        raise NotImplementedError("asg_loss is currently not supported")

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task, models)

    # Initialize generator
    gen_timer = StopwatchMeter()

    def build_generator(args):
        """Return the flashlight decoder selected by ``args.w2l_decoder``.

        For any other value a message is printed and ``None`` is returned.
        That is harmless when only dumping emissions/features, because the
        generator is never called in that case.
        """
        w2l_decoder = getattr(args, "w2l_decoder", None)
        if w2l_decoder == "viterbi":
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(args, task.target_dictionary)
        elif w2l_decoder == "kenlm":
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            return W2lKenLMDecoder(args, task.target_dictionary)
        elif w2l_decoder == "fairseqlm":
            from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

            return W2lFairseqLMDecoder(args, task.target_dictionary)
        else:
            print(
                "only flashlight decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
            )

    # please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
    generator = build_generator(args)

    if args.load_emissions:
        generator = ExistingEmissionsDecoder(
            generator, np.load(args.load_emissions, allow_pickle=True))
        logger.info("loaded emissions from " + args.load_emissions)

    num_sentences = 0

    if args.results_path is not None and not os.path.exists(args.results_path):
        os.makedirs(args.results_path)

    if args.dump_emissions:
        emissions = {}
    if args.dump_features:
        features = {}
        models[0].bert.proj = None
    else:
        res_files = prepare_result_files(args)
    errs_t = 0
    lengths_t = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if use_fp16:
                sample = utils.apply_to_sample(apply_half, sample)
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, :args.prefix_size]

            gen_timer.start()
            if args.dump_emissions:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    emm = models[0].get_normalized_probs(encoder_out,
                                                         log_probs=True)
                    # Transpose dims 0/1 so the leading axis indexes the
                    # utterances of the batch (matched with sample["id"]).
                    emm = emm.transpose(0, 1).cpu().numpy()
                    for i, utt_id in enumerate(sample["id"]):
                        emissions[utt_id.item()] = emm[i]
                    continue
            elif args.dump_features:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    feat = encoder_out["encoder_out"].transpose(
                        0, 1).cpu().numpy()
                    for i, utt_id in enumerate(sample["id"]):
                        padding = (encoder_out["encoder_padding_mask"][i].cpu(
                        ).numpy() if encoder_out["encoder_padding_mask"]
                                   is not None else None)
                        features[utt_id.item()] = (feat[i], padding)
                    continue
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                toks = (sample["target"][i, :] if "target_label" not in sample
                        else sample["target_label"][i, :])
                target_tokens = utils.strip_pad(toks,
                                                tgt_dict.pad()).int().cpu()
                # Process top predictions (no speaker information here)
                errs, length = process_predictions(
                    args,
                    hypos[i],
                    None,
                    tgt_dict,
                    target_tokens,
                    res_files,
                    None,
                    sample_id,
                )
                errs_t += errs
                lengths_t += length

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += (sample["nsentences"] if "nsentences" in sample
                              else sample["id"].numel())

    wer = None
    if args.dump_emissions:
        # Assumes utterance ids are the contiguous range 0..len-1.
        emm_arr = [emissions[i] for i in range(len(emissions))]
        np.save(args.dump_emissions, emm_arr)
        logger.info(
            f"saved {len(emissions)} emissions to {args.dump_emissions}")
    elif args.dump_features:
        # Assumes utterance ids are the contiguous range 0..len-1.
        feat_arr = [features[i] for i in range(len(features))]
        np.save(args.dump_features, feat_arr)
        logger.info(f"saved {len(features)} features to {args.dump_features}")
    else:
        if lengths_t > 0:
            wer = errs_t * 100.0 / lengths_t
            logger.info(f"WER: {wer}")

        # Guard the rates: with no generated output the timer totals are 0.
        sentences_per_sec = (num_sentences / gen_timer.sum
                             if gen_timer.sum > 0 else 0.0)
        tokens_per_sec = 1.0 / gen_timer.avg if gen_timer.avg > 0 else 0.0
        logger.info("| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f} "
                    "sentences/s, {:.2f} tokens/s)".format(
                        num_sentences,
                        gen_timer.n,
                        gen_timer.sum,
                        sentences_per_sec,
                        tokens_per_sec,
                    ))
        logger.info("| Generate {} with beam={}".format(
            args.gen_subset, args.beam))
    return task, wer
def custom_eval(model,
                src,
                trg,
                beam=5,
                ap=math.inf,
                eps=1. / 6,
                mc=None,
                method=None):
    """Translate ``src`` with ``model`` and return corpus BLEU against ``trg``.

    Args:
        model: hub-style model wrapper exposing ``encode``/``decode``,
            ``task``, ``models``, ``args``, ``device`` and ``_build_batches``.
        src: list of source sentences.
        trg: list of reference sentences, aligned with ``src``.
        beam: beam size copied into the generator args.
        ap: pruning threshold forwarded to the search methods that take it.
        eps: forwarded only to ``variable_beam_stream``.
        mc: forwarded to the generator args and to the variable-beam methods.
        method: one of ``None`` (``generator.generate``), ``'greedy'``,
            ``'variable_beam'`` or ``'variable_stream'``.

    Returns:
        The sacrebleu corpus BLEU score of the top hypotheses.

    Raises:
        ValueError: if ``method`` is not one of the supported values.
    """
    supported_methods = (None, 'greedy', 'variable_beam', 'variable_stream')
    if method not in supported_methods:
        # Previously an unknown method fell through every branch and crashed
        # later with an UnboundLocalError; fail early with a clear message.
        raise ValueError(
            "unsupported decoding method {!r}; expected one of {}".format(
                method, supported_methods))

    model.eval()
    with torch.no_grad():
        tokenized_sentences = [model.encode(sentence) for sentence in src]
        gen_args = copy.copy(model.args)
        gen_args.beam = beam
        gen_args.mc = mc
        generator = build_generator(model.task, model.models, gen_args)
        results = []
        total_loops, total_expansions = 0, 0
        if method == 'variable_stream':
            # TODO adjust other parameters; adjust batching params
            ids, translations, total_loops, total_expansions = generator.variable_beam_stream(
                model,
                tokenized_sentences,
                bos_token=model.task.target_dictionary.eos(),
                ap=ap,
                mc=mc,
                eps=eps)
            results.extend(zip(ids, translations))
        else:
            for batch in model._build_batches(tokenized_sentences, False):
                batch = utils.apply_to_sample(lambda t: t.to(model.device),
                                              batch)
                if method is None:
                    translations, n_loops, n_expansions = generator.generate(
                        model.models,
                        batch,
                        bos_token=model.task.target_dictionary.eos(),
                        ap=ap)
                elif method == 'greedy':
                    translations, n_loops, n_expansions = generator.greedy(
                        model.models,
                        batch,
                        bos_token=model.task.target_dictionary.eos())
                else:  # method == 'variable_beam' (validated above)
                    translations, n_loops, n_expansions = generator.variable_beam(
                        model,
                        batch,
                        bos_token=model.task.target_dictionary.eos(),
                        ap=ap,
                        mc=mc)
                total_loops += n_loops
                total_expansions += n_expansions
                results.extend(zip(batch["id"].tolist(), translations))

        # sort output to match input order
        outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]
        predictions = [model.decode(hypos[0]['tokens']) for hypos in outputs]
        bleu = sacrebleu.corpus_bleu(predictions, [trg]).score
        print('loops', total_loops)
        print('expansions', total_expansions)
        print(bleu)
        return bleu
Example #17
0
def move_to_device(sample, device):
    """Return ``sample`` with each tensor leaf sent to ``device``.

    Walking ``sample`` is left to ``utils.apply_to_sample``. Presumably that
    helper handles nested dicts/lists of tensors; confirm against fairseq.
    Each leaf it visits is converted with ``tensor.to(device=device)``.
    """
    return utils.apply_to_sample(
        lambda tensor: tensor.to(device=device), sample)
Example #18
0
def main(args, override_args=None):
    """Evaluate the checkpoint at ``args.path`` on each subset in
    ``args.valid_subset`` (comma-separated) and print aggregated metrics.

    Args:
        args: parsed command-line namespace.
        override_args: optional namespace of arg overrides, or a dict holding
            such a namespace under the ``'override_args'`` key. Its
            ``model_overrides`` string, if present, is merged in as well.
    """
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    use_fp16 = args.fp16
    use_cuda = torch.cuda.is_available() and not args.cpu

    if override_args is not None:
        # Unwrap ``{'override_args': namespace}``; a plain namespace is not
        # subscriptable (TypeError) and is used as-is.
        try:
            override_args = override_args['override_args']
        except TypeError:
            pass
        overrides = vars(override_args)
        # NOTE(review): ``eval`` runs arbitrary code from the command line;
        # ast.literal_eval would be safer if overrides are always literals.
        overrides.update(eval(getattr(override_args, 'model_overrides', '{}')))
    else:
        overrides = None

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [args.path],
        arg_overrides=overrides,
        suffix=getattr(args, "checkpoint_suffix", ""),
    )
    model = models[0]

    # Move models to GPU. A distinct loop variable keeps ``model`` bound to
    # the first model, which is the one evaluated below.
    for m in models:
        if use_fp16:
            m.half()
        if use_cuda:
            m.cuda()

    # Print args
    logger.info(model_args)

    # Build criterion
    criterion = task.build_criterion(model_args)
    if use_fp16:
        criterion.half()
    if use_cuda:
        criterion.cuda()
    criterion.eval()

    for subset in args.valid_subset.split(','):
        try:
            task.load_dataset(subset, combine=False, epoch=1)
            dataset = task.dataset(subset)
        except KeyError:
            raise Exception('Cannot find dataset: ' + subset)

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_workers=args.num_workers,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank).next_epoch_itr(shuffle=False)

        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=('tqdm'
                                if not args.no_progress_bar else 'simple'),
        )

        log_outputs = []
        # Reset per subset so an empty subset neither raises NameError at
        # ``progress.print`` nor reports the previous subset's last step.
        i = None
        for i, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            sample = utils.apply_to_sample(
                lambda t: t.half() if t.dtype is torch.float32 else t,
                sample) if use_fp16 else sample
            try:
                with torch.no_grad():  # do not save backward passes
                    # Split very large ray batches into chunks of at most
                    # 900*900 along dim 3 of sample['uv'].
                    max_num_rays = 900 * 900
                    if sample['uv'].shape[3] > max_num_rays:
                        sample['ray_split'] = sample['uv'].shape[
                            3] // max_num_rays
                    _loss, _sample_size, log_output = task.valid_step(
                        sample, model, criterion)

                progress.log(log_output, step=i)
                log_outputs.append(log_output)

            except TypeError:
                # NOTE(review): presumably signals an empty/None sample; the
                # rest of this subset is skipped.
                break

        with metrics.aggregate() as agg:
            task.reduce_metrics(log_outputs, criterion)
            log_output = agg.get_smoothed_values()

        # summarize all the gpus
        if args.distributed_world_size > 1:
            all_log_output = list(
                zip(*distributed_utils.all_gather_list([log_output])))[0]
            log_output = {
                key: np.mean([log[key] for log in all_log_output])
                for key in all_log_output[0]
            }

        progress.print(log_output, tag=subset, step=i)