def get_best_ckpt(val_data_dir, opts):
    """Return the training step whose rank-0 validation predictions score best.

    Every ``{opts.output_dir}/results/val_results_<step>_rank0.csv`` file is
    scored with ``judge`` against ``{val_data_dir}/answer.csv``. The step
    encoded in the best-scoring file name is returned. File names whose step
    part is not a number are ignored. On ties the smallest step wins.

    Args:
        val_data_dir: directory containing ``answer.csv``.
        opts: options object; only ``opts.output_dir`` is read.

    Returns:
        The best step as an ``int``, or ``None`` if no usable prediction file
        exists.
    """
    # Escape the dot so it only matches a literal '.'.
    pat = re.compile(r'val_results_(?P<step>\d+)_rank0\.csv$')
    answer_file = os.path.join(val_data_dir, 'answer.csv')

    stepped_files = []
    for f in glob.glob(f'{opts.output_dir}/results/val_results_*_rank0.csv'):
        m = pat.match(os.path.basename(f))
        # The glob '*' also matches non-numeric steps; those are skipped
        # instead of crashing on ``m.group``.
        if m is not None:
            stepped_files.append((int(m.group('step')), f))
    # glob's result order depends on the file system; sorting by step makes
    # the scan (and Counter.most_common tie-breaking) deterministic.
    stepped_files.sort()

    top_files = Counter()
    step_of = {}
    for step, f in stepped_files:
        top_files[f] = judge(f, answer_file)
        step_of[f] = step

    print(top_files)

    if not top_files:
        return None
    best_file, _ = top_files.most_common(1)[0]
    return step_of[best_file]
def validate(opts, model, val_loader, split, global_step):
    """Evaluate ``model`` on ``val_loader`` and return a dict of metrics.

    Each rank writes its ``qid,predicted_option_index`` rows to
    ``{output_dir}/results/{split}_results_{global_step}_rank{rank}.csv``.
    Rank 0 then concatenates the per-rank files into
    ``{split}_results_{global_step}.csv`` (unless it already exists). That
    merged file is scored with ``judge`` against the split's ``answer.csv``.

    NOTE(review): a later ``def validate`` in this file rebinds the name, so
    this definition is not the one callers get at import time.

    Returns:
        dict with ``{split}/loss`` (summed cross-entropy per example),
        ``{split}/acc`` (from ``judge``), ``{split}/mrr`` (mean reciprocal
        rank of the gold idiom in ``over_logits``) and ``{split}/ex_per_s``.
    """
    val_loss = 0
    n_ex = 0
    val_mrr = 0
    st = time()
    results = []
    with tqdm(range(len(val_loader.dataset) // opts.size),
              desc=f'{split}-{opts.rank}') as tq:
        for batch in val_loader:
            qids = batch['qids']
            targets = batch['targets']
            # Bookkeeping entries are removed before ``model(**batch)``.
            del batch['targets']
            del batch['qids']
            del batch['gather_index']

            scores, over_logits = model(**batch,
                                        targets=None,
                                        compute_loss=False)
            loss = F.cross_entropy(scores, targets, reduction='sum')
            val_loss += loss.item()
            _, max_idx = scores.max(dim=-1, keepdim=False)
            answers = max_idx.cpu().tolist()

            # Translate each target option position into the id stored in
            # ``option_ids`` so it can be located in the ``over_logits`` ranking.
            targets = torch.gather(batch['option_ids'],
                                   dim=1,
                                   index=targets.unsqueeze(1)).cpu().numpy()
            for j, (_, target) in enumerate(zip(qids, targets)):
                ranking = np.argsort(-over_logits[j].cpu().numpy())
                val_mrr += 1 / (1 + np.argwhere(ranking == target).item())

            results.extend(zip(qids, answers))
            n_ex += len(qids)
            tq.update(len(qids))

    rank_file = f'{opts.output_dir}/results/{split}_results_{global_step}_rank{opts.rank}.csv'
    with open(rank_file, 'w') as f:
        for id_, ans in results:
            f.write(f'{id_},{ans}\n')

    # all_gather_list is called on every rank; the merge below relies on it
    # also acting as a barrier so all per-rank files exist before merging.
    val_loss = sum(all_gather_list(val_loss))
    val_mrr = sum(all_gather_list(val_mrr))
    n_ex = sum(all_gather_list(n_ex))
    tot_time = time() - st

    val_loss /= n_ex
    val_mrr = val_mrr / n_ex

    out_file = f'{opts.output_dir}/results/{split}_results_{global_step}.csv'
    # Only rank 0 merges: if every rank truncated and rewrote the same file
    # concurrently, rows could be interleaved or duplicated.
    if opts.rank == 0 and not os.path.isfile(out_file):
        with open(out_file, 'wb') as merged:
            for part in sorted(glob.glob(
                    f'{opts.output_dir}/results/{split}_results_{global_step}_rank*.csv'
            )):
                with open(part, 'rb') as src:
                    shutil.copyfileobj(src, merged)

    # Barrier: no rank reads the merged file before rank 0 has written it.
    sum(all_gather_list(opts.rank))

    txt_db = os.path.join('/txt',
                          intermediate_dir(opts.pretrained_model_name_or_path),
                          getattr(opts, f'{split}_txt_db'))
    val_acc = judge(out_file, f'{txt_db}/answer.csv')
    val_log = {
        f'{split}/loss': val_loss,
        f'{split}/acc': val_acc,
        f'{split}/mrr': val_mrr,
        f'{split}/ex_per_s': n_ex / tot_time
    }
    LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
                f"score: {val_acc * 100:.2f}, "
                f"mrr: {val_mrr:.3f}")
    return val_log
def validate(opts, model, val_loader, split, global_step):
    """Evaluate a composition model on ``val_loader`` and return metrics.

    The model returns ``(scores, over_logits, composition)``, where
    ``composition`` unpacks to ``(select_masks, atts, composition_gates)``.
    For every 1000th batch, debug information is printed for each example:
    options, prediction, masks, attention and a rendered tree.

    Each rank writes ``qid,predicted_option_index`` rows to
    ``{split}_results_{global_step}_rank{rank}.csv``. Rank 0 concatenates
    them into ``{split}_results_{global_step}.csv`` (unless it already
    exists), which is scored with ``judge``.

    NOTE(review): a later ``def validate`` in this file rebinds the name, so
    this definition is not the one callers get at import time.

    Returns:
        dict with ``{split}/loss``, ``{split}/acc``, ``{split}/mrr`` and
        ``{split}/ex_per_s``.
    """
    val_loss = 0
    n_ex = 0
    val_mrr = 0
    st = time()
    results = []
    dataset = val_loader.dataset
    with tqdm(range(len(dataset)),
              desc=f'{split}-{opts.rank}') as tq:
        for i, batch in enumerate(val_loader):
            qids = batch['qids']
            targets = batch['targets']
            del batch['targets']
            del batch['qids']

            scores, over_logits, composition = model(**batch,
                                                     targets=None,
                                                     compute_loss=False)
            loss = F.cross_entropy(scores, targets, reduction='sum')
            val_loss += loss.item()
            _, max_idx = scores.max(dim=-1, keepdim=False)

            select_masks, atts, composition_gates = composition

            input_ids = torch.gather(batch['input_ids'],
                                     dim=1,
                                     index=batch['gather_index'])
            # Translate each target option position into the id stored in
            # ``option_ids`` so it can be located in the ``over_logits`` ranking.
            targets = torch.gather(batch['option_ids'],
                                   dim=1,
                                   index=targets.unsqueeze(1)).cpu().numpy()
            for j, (qid, target, _inp, option_ids, _position,
                    answer) in enumerate(
                        zip(qids, targets, input_ids, batch['option_ids'],
                            batch['positions'], max_idx)):
                ranking = np.argsort(-over_logits[j].cpu().numpy())
                val_mrr += 1 / (1 + np.argwhere(ranking == target).item())
                if i % 1000 == 0:
                    options = [
                        dataset.id2idiom[o.item()]
                        for o in option_ids
                    ]
                    idiom = options[answer.item()]
                    print(qid, dataset.id2idiom[target.item()],
                          idiom, options)
                    print(len(select_masks), atts.size())
                    s_masks = [
                        select_mask[j].long().cpu().numpy().tolist()
                        for select_mask in select_masks
                    ]
                    s_att = atts[j].cpu().numpy().tolist()

                    tokens = list(idiom)
                    print(tokens, s_masks, s_att, composition_gates[j].sum())
                    try:
                        tree = Tree(' '.join(tokens),
                                    idiom2tree(tokens, s_masks))
                        print(TreePrettyPrinter(tree).text(unicodelines=True))
                    except Exception:
                        # Tree rendering is best-effort debug output; do not
                        # abort validation, but do not swallow
                        # KeyboardInterrupt/SystemExit either.
                        pass

            results.extend(zip(qids, max_idx.cpu().tolist()))
            n_ex += len(qids)
            tq.update(len(qids))

    rank_file = f'{opts.output_dir}/results/{split}_results_{global_step}_rank{opts.rank}.csv'
    with open(rank_file, 'w') as f:
        for id_, ans in results:
            f.write(f'{id_},{ans}\n')

    # all_gather_list is called on every rank; the merge below relies on it
    # also acting as a barrier so all per-rank files exist before merging.
    val_loss = sum(all_gather_list(val_loss))
    val_mrr = sum(all_gather_list(val_mrr))
    n_ex = sum(all_gather_list(n_ex))
    tot_time = time() - st

    val_loss /= n_ex
    val_mrr = val_mrr / n_ex

    out_file = f'{opts.output_dir}/results/{split}_results_{global_step}.csv'
    # Only rank 0 merges, to avoid concurrent truncate-and-rewrite races.
    if opts.rank == 0 and not os.path.isfile(out_file):
        with open(out_file, 'wb') as merged:
            for part in sorted(glob.glob(
                    f'{opts.output_dir}/results/{split}_results_{global_step}_rank*.csv'
            )):
                with open(part, 'rb') as src:
                    shutil.copyfileobj(src, merged)

    # Barrier: no rank reads the merged file before rank 0 has written it.
    sum(all_gather_list(opts.rank))

    txt_db = os.path.join('/txt',
                          intermediate_dir(opts.pretrained_model_name_or_path),
                          getattr(opts, f'{split}_txt_db'))
    val_acc = judge(out_file, f'{txt_db}/answer.csv')
    val_log = {
        f'{split}/loss': val_loss,
        f'{split}/acc': val_acc,
        f'{split}/mrr': val_mrr,
        f'{split}/ex_per_s': n_ex / tot_time
    }
    LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
                f"score: {val_acc * 100:.2f}, "
                f"mrr: {val_mrr:.3f}")
    return val_log
def validate(opts, model, val_loader, split, global_step):
    """Evaluate ``model`` on ``val_loader`` with selectable answer logits.

    ``opts.candidates`` chooses which logits decide the answers:

    * ``'original'`` - the model's first output;
    * ``'enlarged'`` - its third output (``cond_logits``);
    * ``'combined'`` - their sum.

    The reported loss always uses the first output. Per-question answer
    logits are grouped under ``id2eid[qid]`` (read from
    ``{dataset.db_dir}/id2eid.json``). The groups are turned into
    ``id,answer`` rows by ``optimize_answer`` and written to
    ``{split}_results_{global_step}_rank{rank}.csv``. Rank 0 concatenates
    these into ``{split}_results_{global_step}.csv`` (unless it already
    exists), which is scored with ``judge``.

    Raises:
        AssertionError: if ``opts.candidates`` is not one of the values above.

    Returns:
        dict with ``{split}/loss``, ``{split}/acc``, ``{split}/mrr`` and
        ``{split}/ex_per_s``.
    """
    # Validate the option once up front instead of once per batch.
    if opts.candidates not in ('original', 'enlarged', 'combined'):
        raise AssertionError(f"No such candidates option: {opts.candidates!r}")

    val_loss = 0
    n_ex = 0
    val_mrr = 0
    st = time()
    example_logits = {}
    with open(f'{val_loader.dataset.db_dir}/id2eid.json', 'r') as f:
        id2eid = json.load(f)

    with tqdm(range(len(val_loader.dataset)), desc=split) as tq:
        for batch in val_loader:
            qids = batch['qids']
            targets = batch['targets']
            del batch['targets']
            del batch['gather_index']
            del batch['qids']

            logits, over_logits, cond_logits = model(**batch,
                                                     targets=None,
                                                     compute_loss=False)
            # Loss is computed on the original logits regardless of
            # ``opts.candidates``.
            loss = F.cross_entropy(logits, targets, reduction='sum')
            val_loss += loss.item()

            if opts.candidates == 'enlarged':
                logits = cond_logits
            elif opts.candidates == 'combined':
                logits = logits + cond_logits

            # Translate each target option position into the id stored in
            # ``option_ids`` so it can be located in the ``over_logits`` ranking.
            targets = torch.gather(batch['option_ids'],
                                   dim=1,
                                   index=targets.unsqueeze(1)).cpu().numpy()
            for qid, target, score, over_logit in zip(qids, targets, logits,
                                                      over_logits):
                ranking = np.argsort(-over_logit.cpu().numpy())
                val_mrr += 1 / (1 + np.argwhere(ranking == target).item())

                eid = id2eid[qid]
                example_logits.setdefault(eid, {})
                example_logits[eid][qid] = score.cpu().numpy()

            n_ex += len(qids)
            tq.update(len(qids))

    rank_file = f'{opts.output_dir}/results/{split}_results_{global_step}_rank{opts.rank}.csv'
    with open(rank_file, 'w') as f:
        for id_, ans in optimize_answer(example_logits):
            f.write(f'{id_},{ans}\n')

    # all_gather_list is called on every rank; the merge below relies on it
    # also acting as a barrier so all per-rank files exist before merging.
    val_loss = sum(all_gather_list(val_loss))
    # MRR must be summed across ranks like ``n_ex``; dividing the local sum
    # by the global count under-reports MRR in multi-process runs.
    val_mrr = sum(all_gather_list(val_mrr))
    n_ex = sum(all_gather_list(n_ex))
    tot_time = time() - st
    val_loss /= n_ex
    val_mrr = val_mrr / n_ex

    out_file = f'{opts.output_dir}/results/{split}_results_{global_step}.csv'
    # Only rank 0 merges, to avoid concurrent truncate-and-rewrite races.
    if opts.rank == 0 and not os.path.isfile(out_file):
        with open(out_file, 'wb') as merged:
            for part in sorted(glob.glob(
                    f'{opts.output_dir}/results/{split}_results_{global_step}_rank*.csv'
            )):
                with open(part, 'rb') as src:
                    shutil.copyfileobj(src, merged)

    # Barrier: no rank reads the merged file before rank 0 has written it.
    sum(all_gather_list(opts.rank))

    txt_db = os.path.join('/txt',
                          intermediate_dir(opts.pretrained_model_name_or_path),
                          getattr(opts, f'{split}_txt_db'))
    val_acc = judge(out_file, f'{txt_db}/answer.csv')

    val_log = {
        f'{split}/loss': val_loss,
        f'{split}/acc': val_acc,
        f'{split}/mrr': val_mrr,
        f'{split}/ex_per_s': n_ex / tot_time
    }
    LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
                f"score: {val_acc * 100:.2f}, "
                f"mrr: {val_mrr:.3f}")
    return val_log
def validate_slide(opts, model, val_loader, split, global_step):
    """Evaluate the sentiment head (and, when present, the idiom head).

    The model returns ``(_, over_logits, select_masks, sentiment_logits)``.
    ``targets[:, 0]`` holds the idiom target and ``targets[:, 1]`` the
    sentiment target. ``over_logits`` and ``select_masks`` may be ``None``.

    Side effects:
      * When ``over_logits`` is not None, each rank writes
        ``qid,predicted_idiom_index`` rows to
        ``{split}_results_{global_step}_rank{rank}.csv``. Rank 0 merges them
        into ``{split}_results_{global_step}.csv``, which is scored with
        ``judge``.
      * Each rank writes per-example sentiment logits to
        ``{split}_affection_results_{global_step}_rank{rank}.csv``.
      * For every 1000th batch, debug information is printed per example.

    Returns:
        On rank 0, a metrics dict. ``{split}/acc`` is the example-level
        sentiment accuracy. ``{split}/idiom_acc`` (the ``judge`` score) is
        present only when idiom predictions were produced. Other ranks
        return ``None``.
    """
    val_loss = 0
    sentiment_score = 0
    n_ex = 0
    val_mrr = 0  # MRR is not computed here; it is reported as 0.
    st = time()
    results = []
    affection_results = []
    dataset = val_loader.dataset
    candidate_names = None  # enlarged-candidate idiom strings, built on first use

    def get_header(key):
        # One column name '<key>_<label>_<index>' per class in
        # idioms_inverse_mapping[key].
        d = idioms_inverse_mapping[key]
        return [f'{key}_{d[v]}_{v}' if isinstance(d[v], str) else f'{key}_{d[v][-1]}_{v}' for v in range(len(d))]

    def print_composition(qid, j, select_masks):
        # Debug: print the idiom tokens with their select masks and, if
        # possible, the induced tree.
        if select_masks is None:
            return
        s_masks = [select_mask[j].long().cpu().numpy().tolist() for select_mask in select_masks]
        tokens = dataset.tokenizer.convert_ids_to_tokens(dataset.idiom_input_ids[qid])
        print(tokens, s_masks)
        try:
            tree = Tree(' '.join(tokens), idiom2tree(tokens, s_masks))
            print(TreePrettyPrinter(tree).text(unicodelines=True))
        except Exception:
            # Rendering is best-effort debug output; never abort validation
            # over it, but let KeyboardInterrupt/SystemExit through.
            pass

    def print_sentiment(j, sentiment_targets, sentiment_logits):
        # Debug: pretty-print the gold sentiment label and per-label logits.
        names = idioms_inverse_mapping['sentiment']
        pprint({
            "sentiment": {
                "target": names.get(sentiment_targets[j].item(), '无'),
                "predictions": {names[k]: v for k, v in
                                enumerate(sentiment_logits[j].cpu().numpy().tolist())}
            }
        })

    with tqdm(range(len(dataset) // opts.size), desc=f'{split}-{opts.rank}') as tq:
        for i, batch in enumerate(val_loader):
            qids = batch['qids']
            targets = batch['targets']
            del batch['targets']
            del batch['qids']

            if batch['input_ids'].dim() == 3:
                input_ids = torch.gather(batch['input_ids'][1], dim=1, index=batch['gather_index'][0])
            else:
                input_ids = torch.gather(batch['input_ids'], dim=1, index=batch['gather_index'])

            _, over_logits, select_masks, sentiment_logits = model(
                **batch, targets=None, compute_loss=False)

            idiom_targets = targets[:, 0]
            sentiment_targets = targets[:, 1]

            sentiment_score += (
                    sentiment_logits.max(dim=-1, keepdim=False)[1] == sentiment_targets).sum().item()

            # Per-example iteration is bounded by the shortest of these
            # sequences (zip semantics).
            columns = [qids, input_ids, batch['positions']]
            if over_logits is not None:
                loss = F.cross_entropy(over_logits, idiom_targets, reduction='sum')
                val_loss += loss.item()
                _, max_idx = over_logits.max(dim=-1, keepdim=False)
                columns.append(max_idx)
                if candidate_names is None:
                    # Loop-invariant: built once rather than on every batch.
                    candidate_names = [dataset.id2idiom[o] for o in dataset.enlarged_candidates]

            for j, (qid, *_) in enumerate(zip(*columns)):
                example = dataset.db[qid]
                idiom = dataset.id2idiom[example['idiom']]
                affection_results.append(
                    [idiom] + sentiment_logits[j].cpu().numpy().tolist()
                )
                if i % 1000 == 0:
                    if over_logits is not None:
                        top_k = np.argsort(-over_logits[j].cpu().numpy())[:5]
                        print(qid,
                              [candidate_names[k] for k in top_k],
                              idiom)
                    else:
                        print(qid,
                              idiom)
                    print_composition(qid, j, select_masks)
                    print_sentiment(j, sentiment_targets, sentiment_logits)

            if over_logits is not None:
                results.extend(zip(qids, max_idx.cpu().tolist()))

            n_ex += len(qids)
            tq.update(len(qids))

    if results:
        rank_file = f'{opts.output_dir}/results/{split}_results_{global_step}_rank{opts.rank}.csv'
        with open(rank_file, 'w') as f:
            for id_, ans in results:
                f.write(f'{id_},{ans}\n')

    header = ['idiom'] + get_header('sentiment')
    if affection_results:
        out_file = f'{opts.output_dir}/results/{split}_affection_results_{global_step}_rank{opts.rank}.csv'
        pd.DataFrame(affection_results, columns=header).to_csv(out_file)

    # all_gather_list is called on every rank; the file merges below rely on
    # it also acting as a barrier so all per-rank files exist first.
    val_loss = sum(all_gather_list(val_loss))
    val_mrr = sum(all_gather_list(val_mrr))

    val_sentiment_score = sum(all_gather_list(sentiment_score))

    n_ex = sum(all_gather_list(n_ex))
    tot_time = time() - st

    val_loss /= n_ex
    val_mrr = val_mrr / n_ex
    val_sentiment_score = val_sentiment_score / n_ex

    idiom_acc = None
    if results:
        out_file = f'{opts.output_dir}/results/{split}_results_{global_step}.csv'
        # Only rank 0 merges, to avoid concurrent truncate-and-rewrite races.
        if opts.rank == 0 and not os.path.isfile(out_file):
            with open(out_file, 'wb') as merged:
                for part in sorted(glob.glob(f'{opts.output_dir}/results/{split}_results_{global_step}_rank*.csv')):
                    with open(part, 'rb') as src:
                        shutil.copyfileobj(src, merged)

        # Barrier: no rank reads the merged file before rank 0 has written it.
        sum(all_gather_list(opts.rank))

        txt_db = os.path.join('/txt',
                              intermediate_dir(opts.pretrained_model_name_or_path),
                              getattr(opts, f'{split}_txt_db'))
        idiom_acc = judge(out_file, f'{txt_db}/answer.csv')

    if opts.rank != 0:
        return None

    # Idiom-wise accuracy: average each idiom's sentiment logits over all of
    # its examples, take the arg-max label and compare it with the idiom's
    # gold label; the percentage is taken over distinct idioms.
    idiom_wise_accs = {}
    results_files = glob.glob(f'{opts.output_dir}/results/{split}_affection_results_{global_step}_rank*.csv')
    if results_files:  # pd.concat raises on an empty sequence
        new_affection_results_df = pd.concat(map(pd.read_csv, results_files))
        idiom_num = new_affection_results_df['idiom'].unique().size
        for item in new_affection_results_df.groupby('idiom').mean().reset_index().to_dict(orient='records'):
            idiom = item['idiom']
            idiom_id = dataset.vocab[idiom]
            for sub_type in ['sentiment']:
                d = {k: v for k, v in item.items() if k.startswith(sub_type)}
                key = max(d, key=d.get)
                _, pred = key.rsplit('_', 1)
                target = dataset.sentiments[idiom_id]
                idiom_wise_accs.setdefault(sub_type, 0)
                idiom_wise_accs[sub_type] += (int(pred) == target) / idiom_num * 100

    # The headline accuracy reported for this function is the sentiment one.
    val_acc = val_sentiment_score

    val_log = {f'{split}/loss': val_loss,
               f'{split}/acc': val_acc,
               f'{split}/sentiment': val_sentiment_score * 100,
               f'{split}/mrr': val_mrr,
               f'{split}/ex_per_s': n_ex / tot_time}

    # NOTE(review): for k == 'sentiment' this overwrites the example-level
    # '{split}/sentiment' value above with the idiom-wise accuracy.
    for k, v in idiom_wise_accs.items():
        val_log[f'{split}/{k}'] = v

    if idiom_acc is not None:
        # Previously computed by ``judge`` and then discarded.
        val_log[f'{split}/idiom_acc'] = idiom_acc

    LOGGER.info(f"validation finished in {int(tot_time)} seconds, \n"
                f"sentiment score: {val_sentiment_score * 100:.2f}, \n"
                f"score: {val_acc * 100:.2f}, \n"
                f"idiom-wise score: {idiom_wise_accs}, "
                f"mrr: {val_mrr:.3f}")
    return val_log