Example #1
def main(opts):
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    opts.size = hvd.size()
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
        device, n_gpu, hvd.rank(), opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
            opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)
    if opts.rank == 0:
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        add_log_to_file(join(opts.output_dir, 'log', f'{opts.mode}.log'))

    # data loaders
    DatasetCls = DATA_REGISTRY[opts.dataset_cls]
    EvalDatasetCls = DATA_REGISTRY[opts.eval_dataset_cls]
    splits, dataloaders = create_dataloaders(DatasetCls, EvalDatasetCls, opts)

    # Prepare model
    model = build_model(opts)
    model.to(device)

    if opts.mode == 'train':
        best_ckpt = train(model, dataloaders, opts)
    elif opts.mode == 'eval':
        best_ckpt = None
        if opts.rank == 0:
            os.makedirs(join(opts.output_dir, 'results'), exist_ok=True)  # store val predictions
    else:
        best_ckpt = get_best_ckpt(dataloaders['val'].dataset.db_dir, opts)

    # all_gather_list doubles as a synchronization barrier across ranks
    sum(all_gather_list(opts.rank))

    if best_ckpt is not None:
        best_pt = f'{opts.output_dir}/ckpt/model_step_{best_ckpt}.pt'
        model.load_state_dict(torch.load(best_pt), strict=False)

    sum(all_gather_list(opts.rank))

    log = evaluation(model,
                     dict(filter(lambda x: x[0] != 'train', dataloaders.items())),
                     opts, best_ckpt)
    splits = ['val', 'test', 'ran', 'sim', 'out']
    LOGGER.info('\t'.join(splits))
    LOGGER.info('\t'.join(chain(
        [format(log[f'{split}/acc'], "0.6f") for split in splits],
        [format(log[f'{split}/mrr'], "0.6f") for split in splits]
    )))
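Note: main() assumes Horovod is already initialized and that opts comes from an argument parser. A minimal, hypothetical launcher sketch (parse_opts and the flag names below are illustrative, not part of the original code):

# hypothetical entry point; hvd.init() must run before main() queries
# hvd.rank()/hvd.local_rank(); typically launched via `horovodrun -np <gpus> python train.py ...`
import argparse
import horovod.torch as hvd

def parse_opts():
    # illustrative subset of the options main() reads
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['train', 'eval', 'test'], default='train')
    parser.add_argument('--output_dir', required=True)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--fp16', action='store_true')
    return parser.parse_args()

if __name__ == '__main__':
    hvd.init()  # one process per GPU
    main(parse_opts())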
Example #2
def train(model, dataloaders, opts):
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    scaler = GradScaler()

    global_step = 0
    if opts.rank == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps, desc=opts.model)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        os.makedirs(join(opts.output_dir, 'results'),
                    exist_ok=True)  # store val predictions
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    LOGGER.info(f"***** Running training with {opts.n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d", len(dataloaders['train'].dataset))
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()
    n_examples = 0
    n_epoch = 0
    best_ckpt = 0
    best_eval = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        for step, batch in enumerate(dataloaders['train']):
            targets = batch['targets']
            del batch['gather_index']
            n_examples += targets.size(0)

            with autocast():
                loss = model(**batch, compute_loss=True)
                loss = loss.mean()

            delay_unscale = (step + 1) % opts.gradient_accumulation_steps != 0
            scaler.scale(loss).backward()
            if not delay_unscale:
                # gather gradients from every process
                # do this before unscaling to make sure every process uses
                # the same gradient scale
                grads = [
                    p.grad.data for p in model.parameters()
                    if p.requires_grad and p.grad is not None
                ]
                all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())

            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss
                losses = all_gather_list(running_loss)
                running_loss = RunningMeter(
                    'loss',
                    sum(l.val for l in losses) / len(losses))
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    # Unscales the gradients of optimizer's assigned params in-place
                    scaler.unscale_(optimizer)
                    grad_norm = clip_grad_norm_(model.parameters(),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)

                # scaler.step() first unscales gradients of the optimizer's params.
                # If gradients don't contain infs/NaNs, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                scaler.step(optimizer)

                # Updates the scale for next iteration.
                scaler.update()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time() - start))
                    LOGGER.info(f'{opts.model}: {n_epoch}-{global_step}: '
                                f'{tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s '
                                f'best_acc-{best_eval * 100:.2f}')
                    TB_LOGGER.add_scalar('perf/ex_per_s', ex_per_sec,
                                         global_step)

                if global_step % opts.valid_steps == 0:
                    log = evaluation(
                        model,
                        dict(
                            filter(lambda x: x[0].startswith('val'),
                                   dataloaders.items())), opts, global_step)
                    log_eval = log['val/acc']
                    if log_eval > best_eval:
                        best_ckpt = global_step
                        best_eval = log_eval
                        pbar.set_description(
                            f'{opts.model}: {n_epoch}-{best_ckpt} best_acc-{best_eval * 100:.2f}'
                        )
                    model_saver.save(model, global_step)
            if global_step >= opts.num_train_steps:
                break
        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        LOGGER.info(f"Step {global_step}: finished {n_epoch} epochs")
        # if n_epoch >= opts.num_train_epochs:
        #     break
    return best_ckpt
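Note: train() combines torch.cuda.amp mixed precision, gradient accumulation, and a manual Horovod gradient all-reduce. A self-contained, single-process sketch of just the AMP/accumulation pattern (toy model and data; the loss is divided by the accumulation steps here, a normalization the code above handles through its own gradient rescaling helper):

import torch
from torch.cuda.amp import GradScaler, autocast
from torch.nn.utils import clip_grad_norm_

model = torch.nn.Linear(10, 2).cuda()          # toy model purely for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()
accum_steps = 4

for step in range(100):
    x = torch.randn(8, 10, device='cuda')
    y = torch.randint(0, 2, (8,), device='cuda')
    with autocast():
        loss = torch.nn.functional.cross_entropy(model(x), y) / accum_steps
    scaler.scale(loss).backward()               # gradients stay scaled until unscale_/step
    if (step + 1) % accum_steps == 0:
        scaler.unscale_(optimizer)              # so clipping sees true gradient magnitudes
        clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)                  # skipped internally if inf/NaN gradients
        scaler.update()
        optimizer.zero_grad()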
Example #3
def validate(opts, model, val_loader, split, global_step):
    val_loss = 0
    tot_score = 0
    n_ex = 0
    val_mrr = 0
    st = time()
    results = []
    with tqdm(range(len(val_loader.dataset) // opts.size),
              desc=f'{split}-{opts.rank}') as tq:
        for i, batch in enumerate(val_loader):
            qids = batch['qids']
            targets = batch['targets']
            del batch['targets']
            del batch['qids']
            del batch['gather_index']

            scores, over_logits = model(**batch,
                                        targets=None,
                                        compute_loss=False)
            loss = F.cross_entropy(scores, targets, reduction='sum')
            val_loss += loss.item()
            tot_score += (scores.max(
                dim=-1, keepdim=False)[1] == targets).sum().item()
            max_prob, max_idx = scores.max(dim=-1, keepdim=False)
            answers = max_idx.cpu().tolist()

            targets = torch.gather(batch['option_ids'],
                                   dim=1,
                                   index=targets.unsqueeze(1)).cpu().numpy()
            for j, (qid, target) in enumerate(zip(qids, targets)):
                g = over_logits[j].cpu().numpy()
                top_k = np.argsort(-g)
                val_mrr += 1 / (1 + np.argwhere(top_k == target).item())

            results.extend(zip(qids, answers))
            n_ex += len(qids)
            tq.update(len(qids))

    out_file = f'{opts.output_dir}/results/{split}_results_{global_step}_rank{opts.rank}.csv'
    with open(out_file, 'w') as f:
        for id_, ans in results:
            f.write(f'{id_},{ans}\n')

    val_loss = sum(all_gather_list(val_loss))
    val_mrr = sum(all_gather_list(val_mrr))
    # tot_score = sum(all_gather_list(tot_score))
    n_ex = sum(all_gather_list(n_ex))
    tot_time = time() - st

    val_loss /= n_ex
    val_mrr = val_mrr / n_ex

    out_file = f'{opts.output_dir}/results/{split}_results_{global_step}.csv'
    if not os.path.isfile(out_file):
        with open(out_file, 'wb') as g:
            for part in glob.glob(
                    f'{opts.output_dir}/results/{split}_results_{global_step}_rank*.csv'
            ):
                with open(part, 'rb') as fin:
                    shutil.copyfileobj(fin, g)

    sum(all_gather_list(opts.rank))

    txt_db = os.path.join('/txt',
                          intermediate_dir(opts.pretrained_model_name_or_path),
                          getattr(opts, f'{split}_txt_db'))
    val_acc = judge(out_file, f'{txt_db}/answer.csv')
    val_log = {
        f'{split}/loss': val_loss,
        f'{split}/acc': val_acc,
        f'{split}/mrr': val_mrr,
        f'{split}/ex_per_s': n_ex / tot_time
    }
    LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
                f"score: {val_acc * 100:.2f}, "
                f"mrr: {val_mrr:.3f}")
    return val_log
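Note: the val_mrr update is mean reciprocal rank over the whole candidate set: sort the logits descending and accumulate the reciprocal of the gold candidate's 1-based rank. A tiny worked sketch with made-up numbers:

import numpy as np

over_logit = np.array([0.1, 2.3, 0.7, 1.5])  # scores over 4 candidates (illustrative)
target = 2                                    # index of the gold candidate
top_k = np.argsort(-over_logit)               # descending order: [1, 3, 2, 0]
rank = np.argwhere(top_k == target).item()    # 0-based rank of the gold candidate -> 2
reciprocal_rank = 1 / (1 + rank)              # 1/3
# summing reciprocal_rank over all examples and dividing by n_ex gives the reported MRR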
Example #4
def validate(opts, model, val_loader, split, global_step):
    val_loss = 0
    tot_score = 0
    n_ex = 0
    val_mrr = 0
    st = time()
    results = []
    with tqdm(range(len(val_loader.dataset)),
              desc=f'{split}-{opts.rank}') as tq:
        for i, batch in enumerate(val_loader):
            qids = batch['qids']
            targets = batch['targets']
            del batch['targets']
            del batch['qids']

            scores, over_logits, composition = model(**batch,
                                                     targets=None,
                                                     compute_loss=False)
            loss = F.cross_entropy(scores, targets, reduction='sum')
            val_loss += loss.item()
            tot_score += (scores.max(
                dim=-1, keepdim=False)[1] == targets).sum().item()
            max_prob, max_idx = scores.max(dim=-1, keepdim=False)

            select_masks, atts, composition_gates = composition

            input_ids = torch.gather(batch['input_ids'],
                                     dim=1,
                                     index=batch['gather_index'])
            targets = torch.gather(batch['option_ids'],
                                   dim=1,
                                   index=targets.unsqueeze(1)).cpu().numpy()
            for j, (qid, target, inp, option_ids, position,
                    answer) in enumerate(
                        zip(qids, targets, input_ids, batch['option_ids'],
                            batch['positions'], max_idx)):
                g = over_logits[j].cpu().numpy()
                top_k = np.argsort(-g)
                val_mrr += 1 / (1 + np.argwhere(top_k == target).item())
                if i % 1000 == 0:
                    options = [
                        val_loader.dataset.id2idiom[o.item()]
                        for o in option_ids
                    ]
                    idiom = options[answer.item()]
                    print(qid, val_loader.dataset.id2idiom[target.item()],
                          idiom, options)
                    print(len(select_masks), atts.size())
                    s_masks = [
                        select_mask[j].long().cpu().numpy().tolist()
                        for select_mask in select_masks
                    ]
                    s_att = atts[j].cpu().numpy().tolist()

                    # tokens = val_loader.dataset.tokenizer.convert_ids_to_tokens(inp)
                    # start = tokens.index(val_loader.dataset.tokenizer.mask_token)
                    # tokens[position:position + len(idiom)] = list(idiom)
                    tokens = list(idiom)
                    print(tokens, s_masks, s_att, composition_gates[j].sum())
                    try:
                        tree = Tree(' '.join(tokens),
                                    idiom2tree(tokens, s_masks))
                        print(TreePrettyPrinter(tree).text(unicodelines=True))
                    except Exception:
                        # tree rendering is best-effort; skip malformed structures
                        pass

            answers = max_idx.cpu().tolist()
            results.extend(zip(qids, answers))
            n_ex += len(qids)
            tq.update(len(qids))

    out_file = f'{opts.output_dir}/results/{split}_results_{global_step}_rank{opts.rank}.csv'
    with open(out_file, 'w') as f:
        for id_, ans in results:
            f.write(f'{id_},{ans}\n')

    val_loss = sum(all_gather_list(val_loss))
    val_mrr = sum(all_gather_list(val_mrr))
    # tot_score = sum(all_gather_list(tot_score))
    n_ex = sum(all_gather_list(n_ex))
    tot_time = time() - st

    val_loss /= n_ex
    val_mrr = val_mrr / n_ex

    out_file = f'{opts.output_dir}/results/{split}_results_{global_step}.csv'
    if not os.path.isfile(out_file):
        with open(out_file, 'wb') as g:
            for part in glob.glob(
                    f'{opts.output_dir}/results/{split}_results_{global_step}_rank*.csv'
            ):
                with open(part, 'rb') as fin:
                    shutil.copyfileobj(fin, g)

    sum(all_gather_list(opts.rank))

    txt_db = os.path.join('/txt',
                          intermediate_dir(opts.pretrained_model_name_or_path),
                          getattr(opts, f'{split}_txt_db'))
    val_acc = judge(out_file, f'{txt_db}/answer.csv')
    val_log = {
        f'{split}/loss': val_loss,
        f'{split}/acc': val_acc,
        f'{split}/mrr': val_mrr,
        f'{split}/ex_per_s': n_ex / tot_time
    }
    LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
                f"score: {val_acc * 100:.2f}, "
                f"mrr: {val_mrr:.3f}")
    return val_log
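Note: the debugging branch renders the induced idiom structure with nltk's Tree and TreePrettyPrinter. A minimal rendering sketch with a hand-built tree instead of the idiom2tree output (import paths may vary across nltk versions):

from nltk import Tree
from nltk.treeprettyprinter import TreePrettyPrinter

# hand-built binary structure over a 4-character idiom, for illustration only
tree = Tree('idiom', [Tree('left', ['守', '株']), Tree('right', ['待', '兔'])])
print(TreePrettyPrinter(tree).text(unicodelines=True))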
Example #5
def main(opts):
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    opts.size = hvd.size()
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    # data loaders
    DatasetCls = DATA_REGISTRY[opts.dataset_cls]
    EvalDatasetCls = DATA_REGISTRY[opts.eval_dataset_cls]
    splits, dataloaders = create_dataloaders(DatasetCls, EvalDatasetCls, opts)

    # Prepare model
    model = build_model(opts)
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    scaler = GradScaler()

    global_step = 0
    if rank == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps, desc=opts.model)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        os.makedirs(join(opts.output_dir, 'results'),
                    exist_ok=True)  # store val predictions
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d", len(dataloaders['train'].dataset))
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()
    n_examples = 0
    n_epoch = 0
    best_ckpt = 0
    best_eval = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        for step, batch in enumerate(dataloaders['train']):
            targets = batch['targets']
            del batch['gather_index']
            n_examples += targets.size(0)

            with autocast():
                original_loss, enlarged_loss = model(**batch,
                                                     compute_loss=True)
                if opts.candidates == 'original':
                    loss = original_loss
                elif opts.candidates == 'enlarged':
                    loss = enlarged_loss
                elif opts.candidates == 'combined':
                    loss = original_loss + enlarged_loss
                else:
                    raise AssertionError("No such loss!")

                loss = loss.mean()

            delay_unscale = (step + 1) % opts.gradient_accumulation_steps != 0
            scaler.scale(loss).backward()
            if not delay_unscale:
                # gather gradients from every process
                # do this before unscaling to make sure every process uses
                # the same gradient scale
                grads = [
                    p.grad.data for p in model.parameters()
                    if p.requires_grad and p.grad is not None
                ]
                all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())

            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss
                losses = all_gather_list(running_loss)
                running_loss = RunningMeter(
                    'loss',
                    sum(l.val for l in losses) / len(losses))
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    # Unscales the gradients of optimizer's assigned params in-place
                    scaler.unscale_(optimizer)
                    grad_norm = clip_grad_norm_(model.parameters(),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)

                # scaler.step() first unscales gradients of the optimizer's params.
                # If gradients don't contain infs/NaNs, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                scaler.step(optimizer)

                # Updates the scale for next iteration.
                scaler.update()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time() - start))
                    LOGGER.info(f'{opts.model}: {n_epoch}-{global_step}: '
                                f'{tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s '
                                f'best_acc-{best_eval * 100:.2f}')
                    TB_LOGGER.add_scalar('perf/ex_per_s', ex_per_sec,
                                         global_step)

                if global_step % opts.valid_steps == 0:
                    log = evaluation(
                        model,
                        dict(
                            filter(lambda x: x[0].startswith('val'),
                                   dataloaders.items())), opts, global_step)
                    if log['val/acc'] > best_eval:
                        best_ckpt = global_step
                        best_eval = log['val/acc']
                        pbar.set_description(
                            f'{opts.model}: {n_epoch}-{best_ckpt} best_acc-{best_eval * 100:.2f}'
                        )
                    model_saver.save(model, global_step)
            if global_step >= opts.num_train_steps:
                break
        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        LOGGER.info(f"Step {global_step}: finished {n_epoch} epochs")

    sum(all_gather_list(opts.rank))

    best_pt = f'{opts.output_dir}/ckpt/model_step_{best_ckpt}.pt'
    model.load_state_dict(torch.load(best_pt), strict=False)
    evaluation(model,
               dict(filter(lambda x: x[0] != 'train', dataloaders.items())),
               opts, best_ckpt)
Example #6
def validate(opts, model, val_loader, split, global_step):
    val_loss = 0
    tot_score = 0
    n_ex = 0
    val_mrr = 0
    st = time()
    example_logits = {}
    with open(f'{val_loader.dataset.db_dir}/id2eid.json', 'r') as f:
        id2eid = json.load(f)

    with tqdm(range(len(val_loader.dataset)), desc=split) as tq:
        for i, batch in enumerate(val_loader):
            qids = batch['qids']
            targets = batch['targets']
            del batch['targets']
            del batch['gather_index']
            del batch['qids']

            logits, over_logits, cond_logits = model(**batch,
                                                     targets=None,
                                                     compute_loss=False)
            loss = F.cross_entropy(logits, targets, reduction='sum')
            val_loss += loss.item()

            if opts.candidates == 'original':
                logits = logits
            elif opts.candidates == 'enlarged':
                logits = cond_logits
            elif opts.candidates == 'combined':
                logits = logits + cond_logits
            else:
                raise AssertionError("No such loss!")

            # scores, over_logits = model(**batch, targets=None, compute_loss=False)
            # loss = F.cross_entropy(scores, targets, reduction='sum')
            # val_loss += loss.item()
            max_prob, max_idx = logits.max(dim=-1, keepdim=False)
            tot_score += torch.eq(max_idx, targets).sum().item()
            # tot_score += (scores.max(dim=-1, keepdim=False)[1] == targets).sum().item()

            targets = torch.gather(batch['option_ids'],
                                   dim=1,
                                   index=targets.unsqueeze(1)).cpu().numpy()
            for j, (qid, target, score, over_logit) in enumerate(
                    zip(qids, targets, logits, over_logits)):
                g = over_logit.cpu().numpy()
                top_k = np.argsort(-g)
                val_mrr += 1 / (1 + np.argwhere(top_k == target).item())

                eid = id2eid[qid]
                example_logits.setdefault(eid, {})
                example_logits[eid][qid] = score.cpu().numpy()

            n_ex += len(qids)
            tq.update(len(qids))

    out_file = f'{opts.output_dir}/results/{split}_results_{global_step}_rank{opts.rank}.csv'
    with open(out_file, 'w') as f:
        for id_, ans in optimize_answer(example_logits):
            f.write(f'{id_},{ans}\n')

    val_loss = sum(all_gather_list(val_loss))
    n_ex = sum(all_gather_list(n_ex))
    tot_time = time() - st
    val_loss /= n_ex
    val_mrr = val_mrr / n_ex

    out_file = f'{opts.output_dir}/results/{split}_results_{global_step}.csv'
    if not os.path.isfile(out_file):
        with open(out_file, 'wb') as g:
            for part in glob.glob(
                    f'{opts.output_dir}/results/{split}_results_{global_step}_rank*.csv'
            ):
                with open(part, 'rb') as fin:
                    shutil.copyfileobj(fin, g)

    sum(all_gather_list(opts.rank))

    txt_db = os.path.join('/txt',
                          intermediate_dir(opts.pretrained_model_name_or_path),
                          getattr(opts, f'{split}_txt_db'))
    val_acc = judge(out_file, f'{txt_db}/answer.csv')

    val_log = {
        f'{split}/loss': val_loss,
        f'{split}/acc': val_acc,
        f'{split}/mrr': val_mrr,
        f'{split}/ex_per_s': n_ex / tot_time
    }
    LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
                f"score: {val_acc * 100:.2f}, "
                f"mrr: {val_mrr:.3f}")
    return val_log
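Note: every validate() variant follows the same distributed-results convention: each rank writes {split}_results_{step}_rank{rank}.csv, and the merged {split}_results_{step}.csv is produced by concatenating the per-rank files byte-for-byte. A stripped-down sketch of that convention (function names are illustrative, not from the original code):

import glob
import shutil

def write_rank_results(results, output_dir, split, step, rank):
    # each rank dumps its own (qid, answer) pairs
    path = f'{output_dir}/results/{split}_results_{step}_rank{rank}.csv'
    with open(path, 'w') as f:
        for qid, ans in results:
            f.write(f'{qid},{ans}\n')
    return path

def merge_rank_results(output_dir, split, step):
    # concatenate the per-rank files into the final results file
    merged = f'{output_dir}/results/{split}_results_{step}.csv'
    with open(merged, 'wb') as out:
        for part in sorted(glob.glob(
                f'{output_dir}/results/{split}_results_{step}_rank*.csv')):
            with open(part, 'rb') as fin:
                shutil.copyfileobj(fin, out)
    return merged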
Example #7
def validate_slide(opts, model, val_loader, split, global_step):
    val_loss = 0
    sentiment_score = 0
    n_ex = 0
    val_mrr = 0
    st = time()
    results = []

    def get_header(key):
        d = idioms_inverse_mapping[key]
        return [
            f'{key}_{d[v]}_{v}' if isinstance(d[v], str) else f'{key}_{d[v][-1]}_{v}'
            for v in range(len(d))
        ]

    affection_results = []
    with tqdm(range(len(val_loader.dataset) // opts.size), desc=f'{split}-{opts.rank}') as tq:
        for i, batch in enumerate(val_loader):
            qids = batch['qids']
            targets = batch['targets']
            del batch['targets']
            del batch['qids']

            # select_masks, atts, composition_gates = composition
            if batch['input_ids'].dim() == 3:
                input_ids = torch.gather(batch['input_ids'][1], dim=1, index=batch['gather_index'][0])
            else:
                input_ids = torch.gather(batch['input_ids'], dim=1, index=batch['gather_index'])

            _, over_logits, select_masks, sentiment_logits = model(
                **batch, targets=None, compute_loss=False)

            idiom_targets = targets[:, 0]
            sentiment_targets = targets[:, 1]

            sentiment_score += (
                    sentiment_logits.max(dim=-1, keepdim=False)[1] == sentiment_targets).sum().item()

            if over_logits is not None:
                loss = F.cross_entropy(over_logits, idiom_targets, reduction='sum')
                val_loss += loss.item()
                # tot_score += (scores.max(dim=-1, keepdim=False)[1] == idiom_targets).sum().item()
                max_prob, max_idx = over_logits.max(dim=-1, keepdim=False)

                options = [val_loader.dataset.id2idiom[o] for o in val_loader.dataset.enlarged_candidates]
                for j, (qid, inp, position, answer) in enumerate(zip(qids,
                                                                     # idiom_targets,
                                                                     input_ids,
                                                                     # batch['option_ids'],
                                                                     batch['positions'],
                                                                     max_idx)):
                    # g = over_logits[j].cpu().numpy()
                    # top_k = np.argsort(-g)
                    # val_mrr += 1 / (1 + np.argwhere(top_k == target.item()).item())

                    example = val_loader.dataset.db[qid]
                    idiom = val_loader.dataset.id2idiom[example['idiom']]
                    # idiom = options[target.item()]
                    affection_results.append(
                        [idiom] + sentiment_logits[j].cpu().numpy().tolist()
                    )
                    if i % 1000 == 0:
                        g = over_logits[j].cpu().numpy()
                        top_k = np.argsort(-g)[:5]
                        print(qid,
                              [options[k] for k in top_k],
                              idiom)
                        # print(len(select_masks), atts.size())
                        if select_masks is not None:
                            s_masks = [select_mask[j].long().cpu().numpy().tolist() for select_mask in select_masks]
                            # s_att = atts[j].cpu().numpy().tolist()

                            # tokens = val_loader.dataset.tokenizer.convert_ids_to_tokens(inp)
                            # start = tokens.index(val_loader.dataset.tokenizer.mask_token)
                            # tokens[position:position + len(idiom)] = list(idiom)
                            tokens = val_loader.dataset.tokenizer.convert_ids_to_tokens(
                                val_loader.dataset.idiom_input_ids[qid])
                            # print(tokens, s_masks, s_att, composition_gates[j].sum())
                            print(tokens, s_masks)
                            try:
                                tree = Tree(' '.join(tokens), idiom2tree(tokens, s_masks))
                                print(TreePrettyPrinter(tree).text(unicodelines=True))
                            except Exception:
                                pass

                        predictions = {
                            # "coarse emotion": {
                            #     "target": calo_inverse_mapping['coarse_emotion'].get(coarse_emotion_targets[j].item(),
                            #                                                          '无'),
                            #     "predictions": {calo_inverse_mapping['coarse_emotion'][k]: v for k, v in
                            #                     enumerate(coarse_emotion_logits[j].cpu().numpy().tolist())}
                            # },
                            "sentiment": {
                                "target": idioms_inverse_mapping['sentiment'].get(sentiment_targets[j].item(), '无'),
                                "predictions": {idioms_inverse_mapping['sentiment'][k]: v for k, v in
                                                enumerate(sentiment_logits[j].cpu().numpy().tolist())}
                            }
                        }
                        pprint(predictions)

                answers = max_idx.cpu().tolist()
                results.extend(zip(qids, answers))
            else:
                for j, (qid, inp, position) in enumerate(zip(qids, input_ids,
                                                             # batch['option_ids'],
                                                             batch['positions'],
                                                             )):
                    # options = [val_loader.dataset.id2idiom[o.item()] for o in option_ids]
                    example = val_loader.dataset.db[qid]
                    idiom = val_loader.dataset.id2idiom[example['idiom']]
                    affection_results.append(
                        [idiom] + sentiment_logits[j].cpu().numpy().tolist()
                    )
                    if i % 1000 == 0:
                        print(qid,
                              idiom)
                        if select_masks is not None:
                            s_masks = [select_mask[j].long().cpu().numpy().tolist() for select_mask in select_masks]
                            tokens = val_loader.dataset.tokenizer.convert_ids_to_tokens(
                                val_loader.dataset.idiom_input_ids[qid])
                            # print(tokens, s_masks, s_att, composition_gates[j].sum())
                            print(tokens, s_masks)
                            try:
                                tree = Tree(' '.join(tokens), idiom2tree(tokens, s_masks))
                                print(TreePrettyPrinter(tree).text(unicodelines=True))
                            except Exception:
                                pass

                        predictions = {
                            "sentiment": {
                                "target": idioms_inverse_mapping['sentiment'].get(sentiment_targets[j].item(), '无'),
                                "predictions": {idioms_inverse_mapping['sentiment'][k]: v for k, v in
                                                enumerate(sentiment_logits[j].cpu().numpy().tolist())}
                            }
                        }
                        pprint(predictions)

            n_ex += len(qids)
            tq.update(len(qids))

    if results:
        out_file = f'{opts.output_dir}/results/{split}_results_{global_step}_rank{opts.rank}.csv'
        with open(out_file, 'w') as f:
            for id_, ans in results:
                f.write(f'{id_},{ans}\n')

    header = ['idiom'] + get_header('sentiment')
    if affection_results:
        out_file = f'{opts.output_dir}/results/{split}_affection_results_{global_step}_rank{opts.rank}.csv'
        pd.DataFrame(affection_results, columns=header).to_csv(out_file)

    val_loss = sum(all_gather_list(val_loss))
    val_mrr = sum(all_gather_list(val_mrr))

    val_sentiment_score = sum(all_gather_list(sentiment_score))

    n_ex = sum(all_gather_list(n_ex))
    tot_time = time() - st

    val_loss /= n_ex
    val_mrr = val_mrr / n_ex
    val_sentiment_score = val_sentiment_score / n_ex

    if results:
        out_file = f'{opts.output_dir}/results/{split}_results_{global_step}.csv'
        if not os.path.isfile(out_file):
            with open(out_file, 'wb') as g:
                for part in glob.glob(f'{opts.output_dir}/results/{split}_results_{global_step}_rank*.csv'):
                    with open(part, 'rb') as fin:
                        shutil.copyfileobj(fin, g)

        sum(all_gather_list(opts.rank))

        txt_db = os.path.join('/txt',
                              intermediate_dir(opts.pretrained_model_name_or_path),
                              getattr(opts, f'{split}_txt_db'))
        val_acc = judge(out_file, f'{txt_db}/answer.csv')

    if opts.rank == 0:
        results_files = glob.glob(f'{opts.output_dir}/results/{split}_affection_results_{global_step}_rank*.csv')
        new_affection_results_df = pd.concat(map(pd.read_csv, results_files))
        idiom_num = new_affection_results_df['idiom'].unique().size
        idiom_wise_accs = {}
        for item in new_affection_results_df.groupby('idiom').mean().reset_index().to_dict(orient='records'):
            idiom = item['idiom']
            idiom_id = val_loader.dataset.vocab[idiom]
            for sub_type in ['sentiment']:
                d = {k: v for k, v in item.items() if k.startswith(sub_type)}
                key = max(d, key=d.get)
                _, pred = key.rsplit('_', 1)
                target = val_loader.dataset.sentiments[idiom_id]
                idiom_wise_accs.setdefault(sub_type, 0)
                idiom_wise_accs[sub_type] += (int(pred) == target) / idiom_num * 100

        val_acc = val_sentiment_score

        val_log = {f'{split}/loss': val_loss,
                   f'{split}/acc': val_acc,
                   f'{split}/sentiment': val_sentiment_score * 100,
                   f'{split}/mrr': val_mrr,
                   f'{split}/ex_per_s': n_ex / tot_time}

        for k, v in idiom_wise_accs.items():
            val_log[f'{split}/{k}'] = v

        LOGGER.info(f"validation finished in {int(tot_time)} seconds, \n"
                    # f"coarse emotion score: {val_coarse_emotion_score * 100:.2f}, \n"
                    f"sentiment score: {val_sentiment_score * 100:.2f}, \n"
                    f"score: {val_acc * 100:.2f}, \n"
                    f"idiom-wise score: {idiom_wise_accs}, "
                    f"mrr: {val_mrr:.3f}")
        return val_log
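Note: the idiom-wise sentiment accuracy at the end averages the per-example sentiment logits for each idiom and scores the argmax column against the idiom's gold label. A compact sketch of that aggregation on a toy DataFrame (the idioms, gold labels, and sentiment_<label>_<id> column names below are illustrative):

import pandas as pd

df = pd.DataFrame([
    ['idiom_a', 0.9, 0.1],
    ['idiom_a', 0.7, 0.3],
    ['idiom_b', 0.2, 0.8],
], columns=['idiom', 'sentiment_neg_0', 'sentiment_pos_1'])
gold = {'idiom_a': 0, 'idiom_b': 1}            # idiom -> gold sentiment id

mean_logits = df.groupby('idiom').mean().reset_index()
correct = 0
for item in mean_logits.to_dict(orient='records'):
    scores = {k: v for k, v in item.items() if k.startswith('sentiment')}
    pred = int(max(scores, key=scores.get).rsplit('_', 1)[1])
    correct += int(pred == gold[item['idiom']])
acc = correct / len(gold) * 100                # idiom-wise accuracy in percent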