def save_text():
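    """Dump the detokenized raw text of every (language, split) pair to data/raw_text/<lang>.<split>.txt."""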
    args = parse_args()
    args.data_dir = 'data/multi-lingual/'
    args.few_shot = -1
    args.device = torch.device('cpu')
    # all_langs='en zh ca de fr il jp sp sw'.split()
    all_langs = 'zh'.split()
    all_tasks = 'val test train'.split()
    for lang in all_langs:
        for task in all_tasks:
            args.lang = lang
            tokenizer = BertTokenizer.from_pretrained(
                'bert-base-multilingual-uncased',
                do_lower_case=True,
                tokenize_chinese_chars=False,
                split_punct=False)
            dataset = supervised_load_datasets(args, tokenizer, task=task)
            eval_dataloader = DataLoader(dataset,
                                         batch_size=1,
                                         collate_fn=collate_fn)
            all_sents = []
            with open(f'data/raw_text/{args.lang}.{task}.txt', 'w') as f:
                for batch in eval_dataloader:
                    tokens = [[
                        tokenizer.convert_ids_to_tokens(id.item())
                        for id in ids[1:len(torch.nonzero(masks)) - 1]
                    ] for ids, masks in zip(batch[0], batch[1])]
                    strings = [
                        tokenizer.convert_tokens_to_string(tks) + '\n'
                        for tks in tokens
                    ]
                    all_sents.extend(strings)
                f.writelines(all_sents)

def test(args, model, tokenizer, prefix=''):
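    """Parse the args.task_name split with the trained model and append the evalb scores to parse_results.txt next to the checkpoint."""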
    # test parsing performance
    test_dataset = supervised_load_datasets(args, tokenizer, task=args.task_name)
    args.per_gpu_eval_batch_size = 64

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    # Note that DistributedSampler samples randomly
    eval_sampler = SequentialSampler(test_dataset)
    eval_dataloader = DataLoader(test_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate_fn)

    # Eval!
    logger.info("***** Running testing {} *****".format(prefix))
    logger.info("  Num examples = %d", len(test_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    nb_eval_steps = 0
    model.eval()
    all_pred_trees = []
    all_std_trees = []
    for batch in tqdm(eval_dataloader, desc="Evaluate parsing"):
        input_ids, attention_mask, bpe_ids, _, trees, nltk_trees = batch
        inputs = input_ids.to(args.device)
        all_std_trees.extend(trees)
        
        with torch.no_grad():
            sents = [[tokenizer.convert_ids_to_tokens(i.item()) for i in ids[1:len(torch.nonzero(masks)) - 1]]
                     for ids, masks in zip(inputs, attention_mask)]  # detokenize
            sents = [tokenizer.convert_tokens_to_string(s).split()
                     for s in sents]  # remove bpe
            pred_trees, all_attens, all_keys, all_querys = model.parse(
                inputs, attention_mask, bpe_ids, sents, args.rm_bpe_from_a, args.decoding, inner_only=args.inner_only)
            all_pred_trees.extend(pred_trees)
            # for i,(a,s) in enumerate(zip(all_attens, sents)):
            #     if ' '.join(s).startswith('under an agreement signed'):
            #         visual_attention([np.exp(a),],[s,],'attention-frozen.svg')
            #         pass
            # visual_hiddens(query, key, sents)
            # visual_hiddens(all_querys, all_keys, sents)
            # visual_attention(all_attens, sents)
            # print(trees)
        # eval step
        f1_list = []
        for pred_tree, std_tree in zip(pred_trees, trees):
            prec, reca, f1 = comp_tree(pred_tree, std_tree)
            f1_list.append(f1)
        print(sum(f1_list) / len(f1_list))
        nb_eval_steps += 1
    eval_res = evalb(all_pred_trees, all_std_trees)  # eval all

    print(eval_res)
    checkpoint_dir = os.path.dirname(args.checkpoint_path)
    output_eval_file = os.path.join(checkpoint_dir, "parse_results.txt")
    with open(output_eval_file, "a") as writer:
        writer.write("***** parse results {} *****\n".format(prefix))
        for key in sorted(eval_res.keys()):
            writer.write("%s = %s\n" % (key, str(eval_res[key])))

    return eval_res

def eval(args, model, tokenizer, prefix=""):
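    """Validate on the 'val' split: average loss/accuracy/F1 over all batches and append the result to eval_results.txt."""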
    eval_output_dir = args.output_dir
    eval_dataset = supervised_load_datasets(args, tokenizer, 'val')

    if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(eval_output_dir)
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate_fn)

    # multi-gpu evaluate
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    avg_acc, avg_f1, eval_loss = 0., 0., 0.
    nb_eval_steps = 0
    model.eval()
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        input_ids, attention_mask, bpe_ids, labels, _, _ = batch
        inputs = input_ids.to(args.device)
        with torch.no_grad():
            if args.is_supervised:
                loss, acc, f1 = model(inputs, attention_mask, bpe_ids, labels, inner_only=args.inner_only)
            else:
                loss, acc = model(inputs, attention_mask)
                f1 = 0.  # the unsupervised model does not report F1
            eval_loss += loss.mean().item()
            avg_acc += acc
            avg_f1 += f1
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    avg_acc = avg_acc / nb_eval_steps
    avg_f1 = avg_f1 / nb_eval_steps

    result = {
        "loss": eval_loss,
        "acc": avg_acc,
        'f1': avg_f1
    }

    output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
    with open(output_eval_file, "a") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        writer.write("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result

def unsupervised_parsing(args, model, tokenizer, prefix=""):
    """ 
    return attention of every batch
        -get input
        -get output and ground truth
        -compute evaluate score
    """
    # load data
    eval_outputs_dirs = args.output_dir

    eval_dataset = supervised_load_datasets(args, tokenizer, args.task_name)
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_dataloader = DataLoader(eval_dataset,
                                 batch_size=args.eval_batch_size,
                                 collate_fn=collate_fn)
    # positional_mask = get_pos_mask(
    #     args.max_seq_length, scale=0.1).to(args.device)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Eval!
    pred_tree_list, targ_tree_list, prec_list, reca_list, f1_list = [], [], [], [], []
    corpus_sys, corpus_ref = {}, {}
    sample_count = 0
    mean = torch.zeros(12, 12)
    for i, batch in enumerate(tqdm(eval_dataloader, desc="parsing")):
        model.eval()
        with torch.no_grad():
            input_ids, attention_mask, bpe_ids, _, tgt_trees, nltk_trees = batch
            tokens = [[
                tokenizer.convert_ids_to_tokens(id.item())
                for id in ids[1:len(torch.nonzero(masks)) - 1]
            ] for ids, masks in zip(input_ids, attention_mask)]
            if args.remove_bpe:
                strings = [
                    tokenizer.convert_tokens_to_string(tks).split()
                    for tks in tokens
                ]
            else:
                strings = tokens
            inputs = {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
            }
            outputs = model(**inputs)
            _, _, hiddens, attentions = outputs  # (layer_nb,bsz,head,m,m)
            hiddens = hiddens[args.layer_nb]
            # attentions=attentions[args.layer_nb][:,args.head_nb]
            # random_layer,random_head=random.randint(0,11),random.randint(0,11)
            # attentions=attentions[random_layer][:,random_head]
            # attentions=(attentions[5][:,10]+attentions[10][:,6]+attentions[11][:,0])/3
            # attentions2 = attentions[7][:,10]
            # scores=[(s1+s2)/2 for s1,s2 in zip(scores1,scores2)]

            pred_trees = []
            bsz = len(hiddens)
            for j in range(bsz):
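                # fixed (layer, head) attention pairs whose split scores are averaged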
                heads = ((5, 10), (10, 6), (11, 0))
                avg_scores = []
                for layer, head in heads:
                    a = attentions[layer][:, head][j]
                    s = strings[j]
                    if args.remove_bpe:
                        a = remove_bpe_from_attention(bpe_ids[j], a)
                        seq_len = len(attention_mask[j].nonzero()) - 2 - len(
                            bpe_ids[j])
                    else:
                        seq_len = len(attention_mask[j].nonzero()) - 2
                    bpe_mask = torch.tensor(['##' in w
                                             for w in s]).to(args.device)
                    a = a[1:1 + seq_len, 1:1 + seq_len]
                    scores = split_score(None, a, bpe_mask,
                                         args.relevance_type, args.norm,
                                         args.inner_only)
                    avg_scores.append(scores)
                scores = sum(avg_scores) / len(avg_scores)
                if args.decoding == 'cky':
                    tree = parse_cyk(scores, strings[j])
                else:
                    tree = parse_greedy(scores, strings[j])
                pred_trees.append(tree)
                # visual_attention([a.cpu().numpy(),],[strings[i],],'attention-unsupervised.png')

            # evaluate
            for j, (pred_tree, tgt_tree, nltk_tree) in enumerate(
                    zip(pred_trees, tgt_trees, nltk_trees)):
                prec, reca, f1 = comp_tree(pred_tree, tgt_tree)
                prec_list.append(prec)
                reca_list.append(reca)
                f1_list.append(f1)
                # if f1<0.45 and len(str(tgt_tree))<100:
                #     logger.info(f'f1:{f1}\nstd tree:{tgt_tree}\npred tree:{pred_tree}')
                corpus_sys[sample_count] = MRG(pred_tree)
                corpus_ref[sample_count] = MRG_labeled(nltk_tree)
                sample_count += 1
            pred_tree_list += pred_trees
            targ_tree_list += tgt_trees
            # print(f'f1 score:{sum(f1_list)/len(f1_list)}')

    logger.info('-' * 80)
    np.set_printoptions(precision=4)
    print(
        '-' * 20, 'model:{}---layer nb:{}---head nb:{}'.format(
            args.model_name_or_path, args.layer_nb, args.head_nb) + '-' * 20)
    print('Mean Prec:',
          sum(prec_list) / len(prec_list), ', Mean Reca:',
          sum(reca_list) / len(reca_list), ', Mean F1:',
          sum(f1_list) / len(f1_list))
    print('Number of sentence: %i' % sample_count)
    # correct, total = corpus_stats_labeled(corpus_sys, corpus_ref)
    # print(correct)
    # print(total)
    # print('ADJP:', correct['ADJP'], total['ADJP'])
    # print('NP:', correct['NP'], total['NP'])
    # print('PP:', correct['PP'], total['PP'])
    # print('INTJ:', correct['INTJ'], total['INTJ'])
    # print(corpus_average_depth(corpus_sys))

    result = evalb(pred_tree_list, targ_tree_list)
    print(
        f'task:{args.task_name} model:{args.model_name_or_path}, layer nb:{args.layer_nb}, '
        f'head nb:{args.head_nb}, seed: {args.seed}, f1:{result["f1"]}, f1<10:{result["f1_10"]}',
        file=file_to_print,
        flush=True)

def train(args, model, tokenizer, checkpoint=None):
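    """Train the model (optionally resuming optimizer, scheduler and step counter from `checkpoint`) and return (global_step, mean training loss)."""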
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter(args.tensorboard_dir)
    dataset = supervised_load_datasets(args, tokenizer)
    args.batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(dataset)
    dataloader = DataLoader(
        dataset, batch_size=args.batch_size, collate_fn=collate_fn, sampler=train_sampler)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // max(len(dataloader) // args.gradient_accumulation_steps, 1) + 1
    else:
        t_total = len(dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    if checkpoint is not None:
        if isinstance(checkpoint['optimizer'], dict):
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            optimizer = checkpoint['optimizer']
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
    if checkpoint is not None:
        scheduler.load_state_dict(checkpoint['schedule'])
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                   args.batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
    global_step = checkpoint['step'] if checkpoint is not None else 0
    tr_loss, logging_loss, tr_acc, tr_f1 = 0.0, 0.0, 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)

    for _ in train_iterator:
        epoch_iterator = tqdm(dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        iter_samples = 0
        for step, batch in enumerate(epoch_iterator):
            model.train()
            if args.is_supervised:
                input_ids, attention_mask, bpe_ids, labels = batch
                iter_samples += len(input_ids)
                if args.few_shot > 0 and iter_samples > args.few_shot:
                    break  # few-shot training: stop after args.few_shot samples per epoch
                inputs = input_ids.to(args.device)
                loss, acc, f1 = model(inputs, attention_mask, bpe_ids, labels, inner_only=args.inner_only)
            else:
                # assumes the unsupervised collate_fn also yields (input_ids, attention_mask) first
                input_ids, attention_mask = batch[0], batch[1]
                inputs = input_ids.to(args.device)
                loss, acc = model(inputs, attention_mask)
                f1 = 0.  # no F1 is reported by the unsupervised objective
            
            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
                acc = acc / args.gradient_accumulation_steps
                f1 = f1 / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            tr_acc += acc
            tr_f1 += f1
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics TODO eval
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = eval(args, model, tokenizer, prefix=str(global_step))
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
                    logger.info('step: {} ,loss: {}, acc: {}, f1: {}'.format(
                        global_step, (tr_loss - logging_loss)/args.logging_steps, tr_acc/args.logging_steps, tr_f1/args.logging_steps
                    ))
                    logging_loss = tr_loss
                    tr_acc = 0.
                    tr_f1 = 0.
                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = 'checkpoint'
                    # Save model checkpoint
                    output_path = os.path.join(args.output_dir, '{}-{}'.format(checkpoint_prefix, global_step))
                    model_to_save = model.module if hasattr(model, 'module') else model  # unwrap DataParallel
                    checkpoint = {'key': model_to_save.key_proj.state_dict(),
                                  'query': model_to_save.query_proj.state_dict(),
                                  'schedule': scheduler.state_dict(),
                                  'optimizer': optimizer.state_dict(),
                                  'step': global_step,
                                  'args': args}
                    if (not args.frozen_bert) or (not args.is_supervised):
                        checkpoint['bert'] = model_to_save.bert.state_dict()
                    if hasattr(model_to_save, 'label_predictor'):
                        checkpoint['label_predictor'] = model_to_save.label_predictor.state_dict()
                    torch.save(checkpoint, output_path)
                    logger.info("Saving model checkpoint to %s", output_path)

                if args.max_steps > 0 and global_step > args.max_steps:
                    epoch_iterator.close()
                    break


        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step