def main():
    """Jointly fine-tune a shared BERT encoder on two GLUE tasks.

    All settings are read from module-level globals defined elsewhere in the
    file: ``task0``/``task1`` (GLUE task names), ``use_gpu``, ``bert_path``,
    ``cache_dir``, ``batch_size``, ``learning_rate``, ``epochs``,
    ``eval_interval`` and ``model_save_dir``.

    Each step draws one training batch per task, encodes both with the shared
    ``BertModel``, feeds the encoder outputs to the task-specific heads, mixes
    the two losses with loss-proportional weights and updates all three
    models. Both tasks are evaluated every ``eval_interval`` steps and once
    more after training, after which the three models are saved.
    """
    data_args_task0 = GlueDataArgs(task_name=task0)
    data_args_task1 = GlueDataArgs(task_name=task1)

    if use_gpu:
        print("Training on GPU.")

    # Logging: INFO to stdout plus a per-run log file under ./log.
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    # exist_ok avoids the race between a separate exists() check and mkdir().
    os.makedirs('./log', exist_ok=True)
    fh = logging.FileHandler(
        os.path.join('log/train-{}{:02}{}'.format(local_time.tm_year % 2000,
                                                  local_time.tm_mon, t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    logger.info("Tasks:" + task0 + "," + task1)

    config_task0 = BertConfig.from_pretrained(
        bert_path,
        num_labels=glue_tasks_num_labels[data_args_task0.task_name],
        finetuning_task=data_args_task0.task_name,
        cache_dir=cache_dir)

    config_task1 = BertConfig.from_pretrained(
        bert_path,
        num_labels=glue_tasks_num_labels[data_args_task1.task_name],
        finetuning_task=data_args_task1.task_name,
        cache_dir=cache_dir)

    # The shared encoder loads pretrained weights; the task heads are
    # initialised randomly.
    # TODO: seed the RNGs for reproducibility (see transformers Trainer.train()).
    if use_gpu:
        model_Bert = BertModel.from_pretrained(bert_path,
                                               return_dict=True).cuda()
        model_task0 = SequenceClassification(config_task0).cuda()
        model_task1 = SequenceClassification(config_task1).cuda()
    else:
        model_Bert = BertModel.from_pretrained(bert_path, return_dict=True)
        model_task0 = SequenceClassification(config_task0)
        model_task1 = SequenceClassification(config_task1)

    # Data: one train and one dev iterator per task, sharing one tokenizer.
    tokenizer = BertTokenizer.from_pretrained(bert_path, cache_dir=cache_dir)
    data_iterator_train_task0 = DataIterator(data_args_task0,
                                             tokenizer=tokenizer,
                                             mode="train",
                                             cache_dir=cache_dir,
                                             batch_size=batch_size)
    data_iterator_train_task1 = DataIterator(data_args_task1,
                                             tokenizer=tokenizer,
                                             mode="train",
                                             cache_dir=cache_dir,
                                             batch_size=batch_size)
    data_iterator_eval_task0 = DataIterator(data_args_task0,
                                            tokenizer=tokenizer,
                                            mode="dev",
                                            cache_dir=cache_dir,
                                            batch_size=batch_size)
    data_iterator_eval_task1 = DataIterator(data_args_task1,
                                            tokenizer=tokenizer,
                                            mode="dev",
                                            cache_dir=cache_dir,
                                            batch_size=batch_size)
    logger.info("*** DataSet Ready ***")

    # One optimizer per model; only the encoder's learning rate is scheduled.
    opt_bert = torch.optim.AdamW(model_Bert.parameters(), lr=learning_rate)
    opt_task0 = torch.optim.AdamW(model_task0.parameters(), lr=learning_rate)
    opt_task1 = torch.optim.AdamW(model_task1.parameters(), lr=learning_rate)

    metrics_task0 = ComputeMetrics(data_args_task0)
    metrics_task1 = ComputeMetrics(data_args_task1)

    # The step count is derived from task1's training iterator length.
    iterations = (epochs * len(data_iterator_train_task1) // batch_size) + 1
    print(iterations)
    # Linear decay of the encoder LR from 1x towards 0 over `iterations` steps.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        opt_bert, lambda step: (1.0 - step / iterations))

    def _unpack(batch):
        """Return (input_ids, attention_mask, token_type_ids, labels) of *batch*,
        moved to the GPU when ``use_gpu`` is set."""
        tensors = tuple(batch[key] for key in ('input_ids', 'attention_mask',
                                               'token_type_ids', 'labels'))
        return tuple(x.cuda() for x in tensors) if use_gpu else tensors

    all_iters = 0
    for i in range(1, iterations + 1):
        all_iters += 1
        model_Bert.train()
        model_task0.train()
        model_task1.train()
        data0 = data_iterator_train_task0.next()
        data1 = data_iterator_train_task1.next()

        input_ids0, attention_mask0, token_type_ids0, label0 = _unpack(data0)
        input_ids1, attention_mask1, token_type_ids1, label1 = _unpack(data1)

        output_inter0 = model_Bert(input_ids=input_ids0,
                                   attention_mask=attention_mask0,
                                   token_type_ids=token_type_ids0,
                                   return_dict=True)
        output_inter1 = model_Bert(input_ids=input_ids1,
                                   attention_mask=attention_mask1,
                                   token_type_ids=token_type_ids1,
                                   return_dict=True)

        loss0 = model_task0(input=output_inter0, labels=label0)[0]
        loss1 = model_task1(input=output_inter1, labels=label1)[0]

        # Balance the sub-task losses: weight0 = 2*l0/(l0+l1) and
        # weight1 = 2*l1/(l0+l1). The weights sum to 2, and the larger loss
        # gets the larger weight.
        # NOTE(review): the weights are built from the live loss tensors, so
        # gradients also flow through them. Use loss0.detach()/loss1.detach()
        # if they are meant to act as constants. loss1 == 0 would make the
        # ratio inf and the weights NaN.
        ratio = loss0 / loss1
        weight0 = (2 * ratio) / (1 + ratio)
        weight1 = 2 - weight0
        loss = loss0 * weight0 + loss1 * weight1

        # get_last_lr() returns the same values get_lr() did, without the
        # "called outside step()" warning.
        printInfo = 'TOTAL/Train {}/{} - lr:{}, sl={:.6f}, l0/w0-{:.6f}/{:.6f}, l1/w1-{:.6f}/{:.6f}'.format(
            all_iters, iterations, scheduler.get_last_lr(), loss, loss0,
            weight0, loss1, weight1)
        logging.info(printInfo)

        opt_bert.zero_grad()
        opt_task0.zero_grad()
        opt_task1.zero_grad()
        loss.backward()

        opt_bert.step()
        opt_task0.step()
        opt_task1.step()

        scheduler.step()

        if i % eval_interval == 0:
            evaluate(model_Bert, model_task0, data_iterator_eval_task0,
                     metrics_task0)
            evaluate(model_Bert, model_task1, data_iterator_eval_task1,
                     metrics_task1)

    evaluate(model_Bert, model_task0, data_iterator_eval_task0, metrics_task0)
    evaluate(model_Bert, model_task1, data_iterator_eval_task1, metrics_task1)

    # Saving models (plain string concatenation: model_save_dir is expected
    # to end with a path separator).
    model_Bert.save_pretrained(model_save_dir + "main")
    model_task0.save_pretrained(model_save_dir + "task0")
    model_task1.save_pretrained(model_save_dir + "task1")
def main():
    """Multi-task fine-tuning of a shared DistilBERT encoder over ``tasks``.

    Configuration comes from module-level globals defined elsewhere:
    ``tasks``, ``bert_path``, ``cache_dir``, ``use_gpu``, ``learning_rate_0``,
    ``epochs``, ``bs``, ``batch_size_train``, ``batch_size_val``, ``frozen``
    and ``eval_interval``.

    A classification head, optimizer, dataset and metric are built for every
    task. Each step draws one batch per task and sums the per-task losses,
    weighting each by that task's training batch size, before
    back-propagating. The shared encoder is frozen (eval mode,
    ``requires_grad=False``) while ``i < frozen``. Its optimizer only steps
    while ``i > frozen``. Evaluation and checkpointing run every
    ``eval_interval`` steps and once more after the loop.
    """
    ntasks = len(tasks)

    data_args = []
    configuration = []
    sub_models = []
    datasets = []
    sub_optimizer = []
    metrics = []
    tokenizer = DistilBertTokenizer.from_pretrained(bert_path,
                                                    cache_dir=cache_dir)

    for i in range(ntasks):
        logger.info("Tasks:" + tasks[i])
        data_args.append(GlueDataArgs(task_name=tasks[i]))
        configuration.append(
            DistilBertConfig.from_pretrained(
                bert_path,
                num_labels=glue_tasks_num_labels[tasks[i].lower()],
                finetuning_task=data_args[i].task_name,
                cache_dir=cache_dir))
        if use_gpu:
            sub_models.append(SequenceClassification(configuration[i]).cuda())
        else:
            sub_models.append(SequenceClassification(configuration[i]))

        datasets.append(
            GlueDataSets(data_args[i],
                         tokenizer=tokenizer,
                         cache_dir=cache_dir))
        sub_optimizer.append(
            torch.optim.AdamW(sub_models[i].parameters(), lr=learning_rate_0))
        metrics.append(ComputeMetrics(data_args[i]))
        logger.info("*** DataSet Ready ***")

    if use_gpu:
        Bert_model = DistilBertModel.from_pretrained(bert_path,
                                                     return_dict=True).cuda()
    else:
        Bert_model = DistilBertModel.from_pretrained(bert_path,
                                                     return_dict=True)

    bert_optimizer = torch.optim.AdamW(Bert_model.parameters(),
                                       lr=learning_rate_0)

    # The step budget is sized by the largest training set.
    train_num = [datasets[i].length("train") for i in range(ntasks)]
    print(train_num)
    iterations = (epochs * max(train_num) // bs) + 1

    # Linear-decay schedulers. NOTE: they are currently never stepped, so
    # every learning rate stays at learning_rate_0 for the whole run.
    sub_scheduler = [
        torch.optim.lr_scheduler.LambdaLR(
            sub_optimizer[i], lambda step: (1.0 - step / iterations))
        for i in range(ntasks)
    ]
    Bert_scheduler = torch.optim.lr_scheduler.LambdaLR(
        bert_optimizer, lambda step: (1.0 - step / iterations))

    train_iter = [
        GlueIterator(datasets[i].dataloader("train", batch_size_train[i]))
        for i in range(ntasks)
    ]

    for i in range(1, iterations + 1):

        if i >= frozen:
            for p in Bert_model.parameters():
                p.requires_grad = True
            Bert_model.train()
            if i == frozen:
                # NOTE(review): at i == frozen the encoder already receives
                # gradients, but bert_optimizer only steps from i > frozen.
                # These grads are discarded by the next zero_grad().
                logging.info("#####################################")
                logging.info("Release the Traing of the Main Model.")
                logging.info("#####################################")
        else:
            for p in Bert_model.parameters():
                p.requires_grad = False
            Bert_model.eval()

        losses = []
        for j in range(ntasks):
            sub_models[j].train()
            data = train_iter[j].next()

            # DistilBERT takes no token_type_ids.
            if use_gpu:
                input_ids = data['input_ids'].cuda()
                attention_mask = data['attention_mask'].cuda()
                label = data['labels'].cuda()
            else:
                input_ids = data['input_ids']
                attention_mask = data['attention_mask']
                label = data['labels']

            output_inter = Bert_model(input_ids=input_ids,
                                      attention_mask=attention_mask,
                                      return_dict=True)
            losses.append(sub_models[j](input=output_inter, labels=label)[0])

        # The previously computed (and unused) per-task loss ratios were
        # dropped: they forced a host sync per task each step and raised
        # ZeroDivisionError whenever the summed loss was exactly 0.
        loss = 0
        printInfo = 'TOTAL/Train {}/{}, lr:{}'.format(
            i, iterations, Bert_scheduler.get_last_lr())
        for j in range(ntasks):
            loss += losses[j] * batch_size_train[j]
            printInfo += ', loss{}-{:.6f}'.format(j, losses[j])
            sub_optimizer[j].zero_grad()

        logging.info(printInfo)

        if i > frozen:
            bert_optimizer.zero_grad()
        loss.backward()

        if i > frozen:
            bert_optimizer.step()

        for j in range(ntasks):
            sub_optimizer[j].step()

        if i % eval_interval == 0:
            evaluate(Bert_model, sub_models, datasets, batch_size_val, metrics,
                     ntasks)
            save_models(Bert_model, sub_models, ntasks, i)

    evaluate(Bert_model, sub_models, datasets, batch_size_val, metrics, ntasks)
    save_models(Bert_model, sub_models, ntasks, iterations)
def main():
    """Multi-task fine-tuning of a shared DistilBERT encoder over ``tasks``.

    Configuration comes from module-level globals defined elsewhere:
    ``tasks``, ``bert_path``, ``cache_dir``, ``use_gpu``, ``learning_rate``,
    ``epochs``, ``bs``, ``batch_size``, ``batch_size_val``, ``frozen``,
    ``eval_interval`` and ``model_save_dir``.

    Each step draws one batch per task and sums the per-task losses,
    weighting each by that task's batch size. Step ``i`` is ``<= frozen``
    during the warm-up phase: the shared encoder is frozen
    (``requires_grad=False``, eval mode) and only the task heads are trained.
    From step ``frozen + 1`` onwards the encoder is trained and scheduled
    too. Every ``eval_interval`` steps, and once at the end, each task is
    evaluated and checkpoints are written to ``model_save_dir``.
    """
    ntasks = len(tasks)

    data_args = []
    configuration = []
    sub_models = []
    train_iter = []
    dev_iter = []
    sub_optimizer = []
    metrics = []
    tokenizer = DistilBertTokenizer.from_pretrained(bert_path,
                                                    cache_dir=cache_dir)

    for i in range(ntasks):
        logger.info("Tasks:" + tasks[i])
        data_args.append(GlueDataArgs(task_name=tasks[i]))
        configuration.append(
            DistilBertConfig.from_pretrained(
                bert_path,
                num_labels=glue_tasks_num_labels[data_args[i].task_name],
                finetuning_task=data_args[i].task_name,
                cache_dir=cache_dir))
        if use_gpu:
            sub_models.append(SequenceClassification(configuration[i]).cuda())
        else:
            sub_models.append(SequenceClassification(configuration[i]))

        train_iter.append(
            DataIterator(data_args[i], tokenizer=tokenizer, mode="train",
                         cache_dir=cache_dir, batch_size=batch_size[i]))
        dev_iter.append(
            DataIterator(data_args[i], tokenizer=tokenizer, mode="dev",
                         cache_dir=cache_dir, batch_size=batch_size_val[i]))

        sub_optimizer.append(
            torch.optim.AdamW(sub_models[i].parameters(), lr=learning_rate))

        metrics.append(ComputeMetrics(data_args[i]))

        logger.info("*** DataSet Ready ***")

    if use_gpu:
        Bert_model = DistilBertModel.from_pretrained(bert_path,
                                                     return_dict=True).cuda()
    else:
        Bert_model = DistilBertModel.from_pretrained(bert_path,
                                                     return_dict=True)

    bert_optimizer = torch.optim.AdamW(Bert_model.parameters(),
                                       lr=learning_rate)

    # The step budget is sized by the largest training iterator.
    train_num = [len(train_iter[i]) for i in range(ntasks)]
    iterations = (epochs * max(train_num) // bs) + 1

    # Linear decay from 1x towards 0 over `iterations` scheduler steps.
    sub_scheduler = [
        torch.optim.lr_scheduler.LambdaLR(
            sub_optimizer[i], lambda step: (1.0 - step / iterations))
        for i in range(ntasks)
    ]
    Bert_scheduler = torch.optim.lr_scheduler.LambdaLR(
        bert_optimizer, lambda step: (1.0 - step / iterations))

    for i in range(1, iterations + 1):
        # Freeze decision uses the current step `i`. The previous code tested
        # the constant `iterations`, so the encoder was either trained from
        # step 1 or never trained at all.
        train_bert = i > frozen

        for p in Bert_model.parameters():
            p.requires_grad = train_bert
        if train_bert:
            Bert_model.train()
        else:
            Bert_model.eval()

        losses = []
        for j in range(ntasks):
            sub_models[j].train()
            data = train_iter[j].next()

            # DistilBERT takes no token_type_ids.
            if use_gpu:
                input_ids = data['input_ids'].cuda()
                attention_mask = data['attention_mask'].cuda()
                label = data['labels'].cuda()
            else:
                input_ids = data['input_ids']
                attention_mask = data['attention_mask']
                label = data['labels']

            output_inter = Bert_model(input_ids=input_ids,
                                      attention_mask=attention_mask,
                                      return_dict=True)
            losses.append(sub_models[j](input=output_inter, labels=label)[0])

        loss = 0
        printInfo = 'TOTAL/Train {}/{}, lr:{}'.format(
            i, iterations, Bert_scheduler.get_last_lr())
        for j in range(ntasks):
            loss += losses[j] * batch_size[j]
            printInfo += ', loss{}-{:.6f}'.format(j, losses[j])
            sub_optimizer[j].zero_grad()

        logging.info(printInfo)

        if train_bert:
            bert_optimizer.zero_grad()
        loss.backward()

        if train_bert:
            bert_optimizer.step()

        for j in range(ntasks):
            sub_optimizer[j].step()
            sub_scheduler[j].step()

        if train_bert:
            Bert_scheduler.step()

        if i % eval_interval == 0:
            for j in range(ntasks):
                evaluate(Bert_model, sub_models[j], dev_iter[j],
                         batch_size_val[j], metrics[j])
                sub_models[j].save_pretrained(
                    os.path.join(model_save_dir,
                                 "{}-checkpoint-{:06}.pth.tar".format(tasks[j], i)))
            Bert_model.save_pretrained(
                os.path.join(model_save_dir,
                             "{}-checkpoint-{:06}.pth.tar".format("main", i)))

    # Final evaluation and checkpoints. Each sub-model is saved under its
    # own task name; the previous code used a stale loop index and wrote
    # every head to the same file.
    for i in range(ntasks):
        evaluate(Bert_model, sub_models[i], dev_iter[i], batch_size_val[i],
                 metrics[i])
        sub_models[i].save_pretrained(
            os.path.join(model_save_dir,
                         "{}-checkpoint-{:06}.pth.tar".format(tasks[i], iterations)))

    Bert_model.save_pretrained(
        os.path.join(model_save_dir,
                     "{}-checkpoint-{:06}.pth.tar".format("main", iterations)))