Example #1
def _get_dataset(reddit_dict, parlai_dict):
    # Pick the dataset matching args.task; `args`, `split`, and
    # `history_len` come from the enclosing training-script scope.
    if args.task == "dailydialog":
        return DDDataset(
            split,
            parlai_dict,
            data_folder=args.dailydialog_folder,
            history_len=history_len,
        )
    elif args.task == "empchat":
        return EmpDataset(
            split,
            parlai_dict,
            data_folder=args.empchat_folder,
            history_len=history_len,
            reactonly=args.reactonly,
            fasttext=args.fasttext,
            fasttext_type=args.fasttext_type,
            fasttext_path=args.fasttext_path,
        )
    elif args.task == "reddit":
        return RedditDataset(
            data_folder=args.reddit_folder,
            chunk_id=999,
            dict_=reddit_dict,
            max_hist_len=history_len,
            rm_blank_sentences=True,
        )
    else:
        raise ValueError("Task unrecognized!")
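
A minimal call-site sketch, assuming the enclosing script defines `args`, `split`, and `history_len` before this helper runs; the flag values below are illustrative, not from the original source:

import argparse

# Hypothetical setup mirroring the flags the helper reads.
args = argparse.Namespace(
    task="empchat",
    empchat_folder="data/empatheticdialogues",
    reactonly=False,
    fasttext=None,
    fasttext_type=None,
    fasttext_path=None,
)
split = "train"
history_len = 4
dataset = _get_dataset(reddit_dict=None, parlai_dict=parlai_dict)  # dictionary built elsewhere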
Example #2

def build_reddit_dataset(self, chunk_id):
    # Method excerpt: builds one Reddit data chunk from options stored
    # on the owning object.
    return RedditDataset(
        self.opt.reddit_folder,
        chunk_id,
        self.dict,
        max_len=self.opt.max_sent_len,
        rm_long_sent=self.opt.rm_long_sent,
        max_hist_len=self.opt.max_hist_len,
        rm_long_contexts=self.opt.rm_long_contexts,
    )
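
A possible call site (illustrative only; assumes `trainer` is the object exposing this method):

# Hypothetical usage: load Reddit chunk 0 through the trainer.
train_chunk = trainer.build_reddit_dataset(chunk_id=0)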
Example #3
import logging

import torch

# Note: NET_PAD_IDX, net_dictionary, and actual_ct are module-level
# globals defined elsewhere in the original script.
def build_candidates(max_cand_length,
                     n_cands=int(1e7),
                     rm_duplicates=True,
                     rm_starting_gt=True):
    global actual_ct
    global args
    # Candidate matrix, prefilled with the network's pad index.
    tensor = torch.LongTensor(n_cands, max_cand_length).fill_(NET_PAD_IDX)
    i = 0
    chunk = 422  # Reddit chunk counter; incremented before each load
    if "bert_tokenizer" in net_dictionary:
        gt_tokens = torch.LongTensor(
            net_dictionary["bert_tokenizer"].convert_tokens_to_ids(
                ["&", "g", "##t"]))
    else:
        gt_index = net_dictionary["words"]["&gt"]
        lt_index = net_dictionary["words"]["&lt"]
    unk_index = net_dictionary["words"]["<UNK>"]
    n_duplicates = n_start_gt = 0
    if rm_duplicates:
        all_sent = set()

    def _has_lts(sentence_) -> bool:
        # True if the sentence contains the HTML-escaped "<" ("&lt").
        if "bert_tokenizer" in net_dictionary:
            tokens = net_dictionary["bert_tokenizer"].convert_ids_to_tokens(
                sentence_.tolist())
            return "& l ##t" in " ".join(tokens)
        else:
            return torch.sum(sentence_ == lt_index).gt(0)

    def _starts_with_gt(sentence_) -> bool:
        # True if the sentence begins with the HTML-escaped ">" ("&gt"),
        # i.e. a quoted line.
        if "bert_tokenizer" in net_dictionary:
            if sentence_.size(0) < 3:
                return False
            else:
                return torch.eq(sentence_[:3], gt_tokens).all()
        else:
            return sentence_[0].item() == gt_index

    parlai_dict = ParlAIDictionary.create_from_reddit_style(net_dictionary)
    if args.empchat_cands:
        dataset = EmpDataset(
            "train",
            parlai_dict,
            data_folder=args.empchat_folder,
            reactonly=False,
            fasttext=args.fasttext,
            fasttext_type=args.fasttext_type,
            fasttext_path=args.fasttext_path,
        )
        sample_index = range(len(dataset))
        for data_idx in sample_index:
            _context, sentence, _ = dataset[data_idx]
            sent_length = sentence.size(0)
            if torch.sum(sentence == unk_index).gt(0):
                continue
            if _has_lts(sentence):
                continue
            if sent_length <= max_cand_length:
                if _starts_with_gt(sentence) and rm_starting_gt:
                    n_start_gt += 1
                    continue
                if rm_duplicates:
                    tuple_sent = tuple(sentence.numpy())
                    if tuple_sent in all_sent:
                        n_duplicates += 1
                        continue
                    all_sent.add(tuple_sent)
                tensor[i, :sentence.size(0)] = sentence
                i += 1
                if i >= n_cands:
                    break
    breakpoint_ = i  # end of the EmpatheticDialogues candidate block
    actual_ct[1] = i
    if args.dailydialog_cands:
        dataset = DDDataset("train",
                            parlai_dict,
                            data_folder=args.dailydialog_folder)
        sample_index = range(len(dataset))
        for data_idx in sample_index:
            _context, sentence = dataset[data_idx]
            sent_length = sentence.size(0)
            if torch.sum(sentence == unk_index).gt(0):
                continue
            if _has_lts(sentence):
                continue
            if sent_length <= max_cand_length:
                if _starts_with_gt(sentence) and rm_starting_gt:
                    n_start_gt += 1
                    continue
                if rm_duplicates:
                    tuple_sent = tuple(sentence.numpy())
                    if tuple_sent in all_sent:
                        n_duplicates += 1
                        continue
                    all_sent.add(tuple_sent)
                tensor[i, :sentence.size(0)] = sentence
                i += 1
                if i >= n_cands:
                    break
    bp2 = i  # end of the DailyDialog candidate block
    actual_ct[2] = i - breakpoint_
    if args.reddit_cands:
        while i < n_cands:
            chunk += 1
            logging.info(f"Loaded {i} / {n_cands} candidates")
            dataset = RedditDataset(args.reddit_folder, chunk, net_dictionary)
            sample_index = range(len(dataset))
            for data_idx in sample_index:
                _context, sentence = dataset[data_idx]
                sent_length = sentence.size(0)
                if sent_length == 0:
                    print(f"Reddit sentence {data_idx} is of length 0.")
                    continue
                if torch.sum(sentence == unk_index).gt(0):
                    continue
                if _has_lts(sentence):
                    continue
                if sent_length <= max_cand_length:
                    if _starts_with_gt(sentence) and rm_starting_gt:
                        n_start_gt += 1
                        continue
                    if rm_duplicates:
                        tuple_sent = tuple(sentence.numpy())
                        if tuple_sent in all_sent:
                            n_duplicates += 1
                            continue
                        all_sent.add(tuple_sent)
                    tensor[i, :sentence.size(0)] = sentence
                    i += 1
                    if i >= n_cands:
                        break
    actual_ct[0] = i - bp2
    logging.info(
        f"Loaded {i} candidates, {n_start_gt} start with >, {n_duplicates} duplicates"
    )
    args.n_candidates = i
    return tensor[:i, :], breakpoint_, bp2
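
A hedged usage sketch (values illustrative): the returned breakpoints delimit the EmpatheticDialogues and DailyDialog regions of the candidate matrix, with Reddit candidates filling the rest:

# Hypothetical call; assumes args.empchat_cands etc. are already set.
cands, emp_end, dd_end = build_candidates(max_cand_length=20, n_cands=500_000)
emp_cands = cands[:emp_end]        # EmpatheticDialogues sentences
dd_cands = cands[emp_end:dd_end]   # DailyDialog sentences
reddit_cands = cands[dd_end:]      # Reddit sentences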