Example #1
def train(args):
    processor = data_utils.AscProcessor()
    label_list = processor.get_labels()
    tokenizer = ABSATokenizer.from_pretrained(
        modelconfig.MODEL_ARCHIVE_MAP[args.bert_model])
    train_examples = processor.get_train_examples(args.data_dir)
    num_train_steps = int(
        len(train_examples) / args.train_batch_size) * args.num_train_epochs

    train_features = data_utils.convert_examples_to_features(
        train_examples, label_list, args.max_seq_length, tokenizer, "asc")
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_examples))
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Num steps = %d", num_train_steps)

    all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                 dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                   dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                  dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in train_features],
                                 dtype=torch.long)

    train_data = TensorDataset(all_input_ids, all_segment_ids, all_input_mask,
                               all_label_ids)

    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    # >>>>> validation
    if args.do_valid:
        valid_examples = processor.get_dev_examples(args.data_dir)
        valid_features = data_utils.convert_examples_to_features(
            valid_examples, label_list, args.max_seq_length, tokenizer, "asc")
        valid_all_input_ids = torch.tensor(
            [f.input_ids for f in valid_features], dtype=torch.long)
        valid_all_segment_ids = torch.tensor(
            [f.segment_ids for f in valid_features], dtype=torch.long)
        valid_all_input_mask = torch.tensor(
            [f.input_mask for f in valid_features], dtype=torch.long)
        valid_all_label_ids = torch.tensor(
            [f.label_id for f in valid_features], dtype=torch.long)
        valid_data = TensorDataset(valid_all_input_ids, valid_all_segment_ids,
                                   valid_all_input_mask, valid_all_label_ids)

        logger.info("***** Running validations *****")
        logger.info("  Num orig examples = %d", len(valid_examples))
        logger.info("  Num split examples = %d", len(valid_features))
        logger.info("  Batch size = %d", args.train_batch_size)

        valid_sampler = SequentialSampler(valid_data)
        valid_dataloader = DataLoader(valid_data,
                                      sampler=valid_sampler,
                                      batch_size=args.train_batch_size)

        best_valid_loss = float('inf')
        valid_losses = []
    # <<<<< end of validation declaration

    model = BertForSequenceClassification.from_pretrained(
        modelconfig.MODEL_ARCHIVE_MAP[args.bert_model],
        num_labels=len(label_list))
    model.cuda()
    # Prepare optimizer
    param_optimizer = [(k, v) for k, v in model.named_parameters()
                       if v.requires_grad]
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    t_total = num_train_steps
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=t_total)

    global_step = 0
    model.train()
    for _ in range(args.num_train_epochs):
        for step, batch in enumerate(train_dataloader):
            batch = tuple(t.cuda() for t in batch)
            input_ids, segment_ids, input_mask, label_ids = batch
            loss = model(input_ids, segment_ids, input_mask, label_ids)
            loss.backward()

            # BertAdam applies the warmup/decay schedule internally when
            # t_total is set, so the learning rate is not adjusted manually.
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1
            # >>>> perform validation at the end of each epoch
        if args.do_valid:
            model.eval()
            with torch.no_grad():
                losses = []
                valid_size = 0
                for step, batch in enumerate(valid_dataloader):
                    batch = tuple(
                        t.cuda()
                        for t in batch)  # multi-GPU does the scattering itself
                    input_ids, segment_ids, input_mask, label_ids = batch
                    loss = model(input_ids, segment_ids, input_mask, label_ids)
                    losses.append(loss.data.item() * input_ids.size(0))
                    valid_size += input_ids.size(0)
                valid_loss = sum(losses) / valid_size
                logger.info("validation loss: %f", valid_loss)
                valid_losses.append(valid_loss)
            if valid_loss < best_valid_loss:
                torch.save(model, os.path.join(args.output_dir, "model.pt"))
                best_valid_loss = valid_loss
            model.train()
    if args.do_valid:
        with open(os.path.join(args.output_dir, "valid.json"), "w") as fw:
            json.dump({"valid_losses": valid_losses}, fw)
    else:
        torch.save(model, os.path.join(args.output_dir, "model.pt"))
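A minimal invocation sketch (not part of the original script): the flags below mirror the attributes that train(args) reads above, while the default values, paths, and the bert_model key are placeholders that must match the local data layout and modelconfig.MODEL_ARCHIVE_MAP.

import argparse
import os

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--bert_model", default="bert-base-uncased", type=str)
    parser.add_argument("--data_dir", default="data/asc", type=str)
    parser.add_argument("--output_dir", default="run/asc", type=str)
    parser.add_argument("--max_seq_length", default=128, type=int)
    parser.add_argument("--train_batch_size", default=32, type=int)
    parser.add_argument("--num_train_epochs", default=4, type=int)
    parser.add_argument("--learning_rate", default=3e-5, type=float)
    parser.add_argument("--warmup_proportion", default=0.1, type=float)
    parser.add_argument("--do_valid", action="store_true")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    train(args)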
Example #2
def test(
    args
):  # Load a trained model that you have fine-tuned (evaluation runs on the GPU)
    processor = data_utils.AeProcessor()
    label_list = processor.get_labels()
    tokenizer = ABSATokenizer.from_pretrained(
        modelconfig.MODEL_ARCHIVE_MAP[args.bert_model])
    eval_examples = processor.get_test_examples(args.data_dir)
    eval_features = data_utils.convert_examples_to_features(
        eval_examples, label_list, args.max_seq_length, tokenizer, "ae")

    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %d", len(eval_examples))
    logger.info("  Batch size = %d", args.eval_batch_size)
    all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                 dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                   dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                  dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                 dtype=torch.long)
    eval_data = TensorDataset(all_input_ids, all_segment_ids, all_input_mask,
                              all_label_ids)
    # Run prediction for full data
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    model = torch.load(os.path.join(args.output_dir, "model.pt"))
    model.cuda()
    model.eval()

    full_logits = []
    full_label_ids = []
    for step, batch in enumerate(eval_dataloader):
        batch = tuple(t.cuda() for t in batch)
        input_ids, segment_ids, input_mask, label_ids = batch

        with torch.no_grad():
            logits = model(input_ids, segment_ids, input_mask)

        logits = logits.detach().cpu().numpy()
        label_ids = label_ids.cpu().numpy()

        full_logits.extend(logits.tolist())
        full_label_ids.extend(label_ids.tolist())

    output_eval_json = os.path.join(args.output_dir, "predictions.json")
    with open(output_eval_json, "w") as fw:
        assert len(full_logits) == len(eval_examples)
        # sort by original order for evaluation
        recs = {}
        for qx, ex in enumerate(eval_examples):
            recs[int(ex.guid.split("-")[1])] = {
                "sentence": ex.text_a,
                "idx_map": ex.idx_map,
                "logit": full_logits[qx][1:]
            }  # skip the [CLS] tag.
        full_logits = [recs[qx]["logit"] for qx in range(len(full_logits))]
        raw_X = [recs[qx]["sentence"] for qx in range(len(eval_examples))]
        idx_map = [recs[qx]["idx_map"] for qx in range(len(eval_examples))]
        json.dump({
            "logits": full_logits,
            "raw_X": raw_X,
            "idx_map": idx_map
        }, fw)
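A hedged sketch (not from the original repo) of reading predictions.json back for AE post-processing. It assumes each entry of "logits" is a per-token list of label scores with the [CLS] position already stripped, and simply takes the argmax per token; how idx_map is used to align sub-tokens with original words is left to the downstream evaluation script.

import json

import numpy as np

def load_ae_predictions(pred_path):
    with open(pred_path) as f:
        preds = json.load(f)
    # one predicted label index per remaining (sub-)token of each sentence
    pred_tags = [np.argmax(np.asarray(logit), axis=-1).tolist()
                 for logit in preds["logits"]]
    return preds["raw_X"], preds["idx_map"], pred_tags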
Example #3
def train(args):
    start = time.time()
    torch.cuda.empty_cache()
    epsilon = 2
    wdec = 1e-2

    processor = data_utils.E2EProcessor()
    label_list = processor.get_labels()
    # tokenizer = ABSATokenizer.from_pretrained(modelconfig.MODEL_ARCHIVE_MAP[args.albert_model])
    if args.albert_model == 'voidful/albert_chinese_base':
        tokenizer = ABSATokenizerB.from_pretrained(args.albert_model)
    else:
        tokenizer = ABSATokenizer.from_pretrained(args.albert_model)

    train_examples = processor.get_train_examples(args.data_dir)
    num_train_steps = int(
        len(train_examples) / args.train_batch_size) * args.num_train_epochs

    train_features = data_utils.convert_examples_to_features(
        train_examples, label_list, args.max_seq_length, tokenizer, "e2e")
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_examples))
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Num steps = %d", num_train_steps)

    all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                 dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                   dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                  dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in train_features],
                                 dtype=torch.long)

    train_data = TensorDataset(all_input_ids, all_segment_ids, all_input_mask,
                               all_label_ids)

    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    #>>>>> validation
    if args.do_valid:
        valid_examples = processor.get_dev_examples(args.data_dir)
        valid_features = data_utils.convert_examples_to_features(
            valid_examples, label_list, args.max_seq_length, tokenizer, "e2e")
        valid_all_input_ids = torch.tensor(
            [f.input_ids for f in valid_features], dtype=torch.long)
        valid_all_segment_ids = torch.tensor(
            [f.segment_ids for f in valid_features], dtype=torch.long)
        valid_all_input_mask = torch.tensor(
            [f.input_mask for f in valid_features], dtype=torch.long)
        valid_all_label_ids = torch.tensor(
            [f.label_id for f in valid_features], dtype=torch.long)
        valid_data = TensorDataset(valid_all_input_ids, valid_all_segment_ids,
                                   valid_all_input_mask, valid_all_label_ids)

        logger.info("***** Running validations *****")
        logger.info("  Num orig examples = %d", len(valid_examples))
        logger.info("  Num split examples = %d", len(valid_features))
        logger.info("  Batch size = %d", args.train_batch_size)

        valid_sampler = SequentialSampler(valid_data)
        valid_dataloader = DataLoader(valid_data,
                                      sampler=valid_sampler,
                                      batch_size=args.train_batch_size)

        best_valid_loss = float('inf')
        valid_losses = []
    #<<<<< end of validation declaration

    # model = AlbertForABSA.from_pretrained(modelconfig.MODEL_ARCHIVE_MAP[args.albert_model], num_labels = len(label_list), epsilon=epsilon)
    model = AlbertForABSA.from_pretrained(args.albert_model,
                                          num_labels=len(label_list),
                                          epsilon=epsilon)

    params_total = sum(p.numel() for p in model.parameters())
    params_trainable = sum(p.numel() for p in model.parameters()
                           if p.requires_grad)
    logger.info("***** Model Properties *****")
    logger.info("  Parameters (Total): {:.2e}".format(params_total))
    logger.info("  Parameters (Trainable): {:.2e}".format(params_trainable))

    model.to(device)

    # Prepare optimizer
    param_optimizer = [(k, v) for k, v in model.named_parameters()
                       if v.requires_grad]
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        wdec
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    t_total = num_train_steps
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(args.warmup_proportion * t_total),
        num_training_steps=t_total)

    global_step = 0
    model.train()
    for _ in range(args.num_train_epochs):
        for step, batch in enumerate(train_dataloader):
            batch = tuple(t.to(device) for t in batch)
            input_ids, segment_ids, input_mask, label_ids = batch

            _loss, adv_loss = model(input_ids, segment_ids, input_mask,
                                    label_ids)
            loss = _loss + adv_loss
            loss.backward()

            # Warmup and linear decay are handled by the scheduler
            # (scheduler.step() below), so the learning rate is not
            # adjusted manually here.
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1
            #>>>> perform validation at the end of each epoch
        if args.do_valid:
            model.eval()
            with torch.no_grad():
                losses = []
                valid_size = 0
                for step, batch in enumerate(valid_dataloader):
                    batch = tuple(
                        t.to(device)
                        for t in batch)  # multi-GPU does the scattering itself
                    input_ids, segment_ids, input_mask, label_ids = batch
                    loss = model(input_ids, segment_ids, input_mask, label_ids)
                    losses.append(loss.data.item() * input_ids.size(0))
                    valid_size += input_ids.size(0)
                valid_loss = sum(losses) / valid_size
                logger.info("validation loss: %f", valid_loss)
                valid_losses.append(valid_loss)
            if valid_loss < best_valid_loss:
                torch.save(model, os.path.join(args.output_dir, "model.pt"))
                best_valid_loss = valid_loss
            model.train()
    if args.do_valid:
        with open(os.path.join(args.output_dir, "valid.json"), "w") as fw:
            json.dump({"valid_losses": valid_losses}, fw)
    else:
        torch.save(model, os.path.join(args.output_dir, "model.pt"))
    mstats = torch.cuda.memory_stats()
    duration = time.time() - start
    logger.info("Training completed in {} minutes, {} seconds".format(
        duration // 60, ceil(duration % 60)))
    logger.info("***** GPU Memory Statistics *****")
    logger.info("  Allocated bytes (Peak):      {} MiB".format(
        mstats['allocated_bytes.all.peak'] / 1048576))
    logger.info("  Allocated bytes (Allocated): {} MiB".format(
        mstats['allocated_bytes.all.allocated'] / 1048576))
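A small illustration (assuming the standard transformers scheduler API) of the linear warmup/decay that get_linear_schedule_with_warmup applies above: the learning rate climbs during the first warmup_proportion * t_total steps and then decays linearly to zero. The dummy module and the concrete step counts are only for demonstration.

import torch
from transformers import get_linear_schedule_with_warmup

dummy = torch.nn.Linear(2, 2)
opt = torch.optim.AdamW(dummy.parameters(), lr=3e-5)
sched = get_linear_schedule_with_warmup(opt,
                                        num_warmup_steps=10,
                                        num_training_steps=100)
lrs = []
for _ in range(100):
    opt.step()  # in real training this follows loss.backward()
    sched.step()
    lrs.append(sched.get_last_lr()[0])
print(lrs[0], lrs[9], lrs[49], lrs[99])  # ramp-up, peak, mid-decay, ~0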
Example #4
def test(
    args
):  # Load a trained model that you have fine-tuned (evaluation runs on the GPU)
    start = time.time()
    torch.cuda.empty_cache()
    processor = data_utils.E2EProcessor()
    label_list = processor.get_labels()
    # tokenizer = ABSATokenizer.from_pretrained(modelconfig.MODEL_ARCHIVE_MAP[args.albert_model])
    if args.albert_model == 'voidful/albert_chinese_base':
        tokenizer = ABSATokenizerB.from_pretrained(args.albert_model)
    else:
        tokenizer = ABSATokenizer.from_pretrained(args.albert_model)

    eval_examples = processor.get_test_examples(args.data_dir)
    eval_features = data_utils.convert_examples_to_features(
        eval_examples, label_list, args.max_seq_length, tokenizer, "e2e")

    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %d", len(eval_examples))
    logger.info("  Batch size = %d", args.eval_batch_size)
    all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                 dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                   dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                  dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                 dtype=torch.long)
    eval_data = TensorDataset(all_input_ids, all_segment_ids, all_input_mask,
                              all_label_ids)
    # Run prediction for full data
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    model = torch.load(os.path.join(args.output_dir, "model.pt"))
    params_total = sum(p.numel() for p in model.parameters())
    params_trainable = sum(p.numel() for p in model.parameters()
                           if p.requires_grad)
    logger.info("***** Model Properties *****")
    logger.info("  Parameters (Total): {:.2e}".format(params_total))
    logger.info("  Parameters (Trainable): {:.2e}".format(params_trainable))
    model.to(device)
    model.eval()

    full_logits = []
    full_label_ids = []
    for step, batch in enumerate(eval_dataloader):
        batch = tuple(t.to(device) for t in batch)
        input_ids, segment_ids, input_mask, label_ids = batch

        with torch.no_grad():
            logits = model(input_ids, segment_ids, input_mask)

        logits = logits.detach().cpu().numpy()
        label_ids = label_ids.cpu().numpy()

        full_logits.extend(logits.tolist())
        full_label_ids.extend(label_ids.tolist())

    output_eval_json = os.path.join(args.output_dir, "predictions.json")
    with open(output_eval_json, "w") as fw:
        json.dump({
            "logits": full_logits,
            "label_ids": full_label_ids
        },
                  fw,
                  indent=4)
        # assert len(full_logits)==len(eval_examples)
        # #sort by original order for evaluation
        # recs={}
        # for qx, ex in enumerate(eval_examples):
        #     recs[int(ex.guid.split("-")[1]) ]={"sentence": ex.text_a, "idx_map": ex.idx_map, "logit": full_logits[qx][1:]} #skip the [CLS] tag.
        # full_logits=[recs[qx]["logit"] for qx in range(len(full_logits))]
        # raw_X=[recs[qx]["sentence"] for qx in range(len(eval_examples) ) ]
        # idx_map=[recs[qx]["idx_map"] for qx in range(len(eval_examples)) ]
        # json.dump({"logits": full_logits, "raw_X": raw_X, "idx_map": idx_map}, fw, indent=4)
    mstats = torch.cuda.memory_stats()
    duration = time.time() - start
    logger.info("Testing completed in {} minutes, {} seconds".format(
        duration // 60, ceil(duration % 60)))
    logger.info("***** GPU Memory Statistics *****")
    logger.info("  Allocated bytes (Peak):      {} MiB".format(
        mstats['allocated_bytes.all.peak'] / 1048576))
    logger.info("  Allocated bytes (Allocated): {} MiB".format(
        mstats['allocated_bytes.all.allocated'] / 1048576))
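A hedged scoring sketch for the dump written above, assuming "logits" has shape [num_examples, seq_len, num_labels] and "label_ids" has shape [num_examples, seq_len]. A real evaluation would also exclude padding positions, but the attention masks are not stored in predictions.json, so the number below is only a rough sanity check.

import json

import numpy as np

with open("predictions.json") as f:
    preds = json.load(f)

pred_ids = np.argmax(np.asarray(preds["logits"]), axis=-1)
gold_ids = np.asarray(preds["label_ids"])
# token-level accuracy over every position, padding included
print("token accuracy:", float((pred_ids == gold_ids).mean()))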
Example #5
def get(logger=None, args=None):
    data = {}
    taskcla = []

    # Others
    f_name = 'asc_random'

    with open(f_name, 'r') as f_random_seq:
        random_sep = f_random_seq.readlines()[args.idrandom].split()

    print('random_sep: ', random_sep)
    print('domains: ', domains)

    print('random_sep: ', len(random_sep))
    print('domains: ', len(domains))

    for t in range(args.ntasks):
        dataset = datasets[domains.index(random_sep[t])]

        data[t] = {}
        if 'Bing' in dataset:
            data[t]['name'] = dataset
            data[t]['ncla'] = 2
        elif 'XuSemEval' in dataset:
            data[t]['name'] = dataset
            data[t]['ncla'] = 3

        processor = data_utils.AscProcessor()
        label_list = processor.get_labels()
        tokenizer = ABSATokenizer.from_pretrained(args.bert_model)
        train_examples = processor.get_train_examples(dataset)
        num_train_steps = int(
            math.ceil(len(train_examples) /
                      args.train_batch_size)) * args.num_train_epochs
        # num_train_steps = int(len(train_examples) / args.train_batch_size) * args.num_train_epochs

        train_features = data_utils.convert_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer, "asc")
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_steps)

        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)
        all_tasks = torch.tensor([t for f in train_features], dtype=torch.long)

        train_data = TensorDataset(all_input_ids, all_segment_ids,
                                   all_input_mask, all_label_ids, all_tasks)

        data[t]['train'] = train_data
        data[t]['num_train_steps'] = num_train_steps

        valid_examples = processor.get_dev_examples(dataset)
        valid_features = data_utils.convert_examples_to_features(
            valid_examples, label_list, args.max_seq_length, tokenizer, "asc")
        valid_all_input_ids = torch.tensor(
            [f.input_ids for f in valid_features], dtype=torch.long)
        valid_all_segment_ids = torch.tensor(
            [f.segment_ids for f in valid_features], dtype=torch.long)
        valid_all_input_mask = torch.tensor(
            [f.input_mask for f in valid_features], dtype=torch.long)
        valid_all_label_ids = torch.tensor(
            [f.label_id for f in valid_features], dtype=torch.long)
        valid_all_tasks = torch.tensor([t for f in valid_features],
                                       dtype=torch.long)

        valid_data = TensorDataset(valid_all_input_ids, valid_all_segment_ids,
                                   valid_all_input_mask, valid_all_label_ids,
                                   valid_all_tasks)

        logger.info("***** Running validations *****")
        logger.info("  Num orig examples = %d", len(valid_examples))
        logger.info("  Num split examples = %d", len(valid_features))
        logger.info("  Batch size = %d", args.train_batch_size)

        data[t]['valid'] = valid_data

        processor = data_utils.AscProcessor()
        label_list = processor.get_labels()
        tokenizer = BertTokenizer.from_pretrained(args.bert_model)
        eval_examples = processor.get_test_examples(dataset)
        eval_features = data_utils.convert_examples_to_features(
            eval_examples, label_list, args.max_seq_length, tokenizer, "asc")

        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                     dtype=torch.long)
        all_tasks = torch.tensor([t for f in eval_features], dtype=torch.long)

        eval_data = TensorDataset(all_input_ids, all_segment_ids,
                                  all_input_mask, all_label_ids, all_tasks)
        # Run prediction for full data

        data[t]['test'] = eval_data

        taskcla.append((t, int(data[t]['ncla'])))

    # Others
    n = 0
    for t in data.keys():
        n += data[t]['ncla']
    data['ncla'] = n

    return data, taskcla
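A usage sketch of the returned structure, assuming the caller holds the same args namespace: each task id in taskcla indexes a dict with 'train', 'valid', and 'test' TensorDatasets plus 'num_train_steps', which can be wrapped in DataLoaders for sequential, task-by-task training. The helper name is hypothetical.

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

def build_loaders(data, taskcla, args):
    loaders = {}
    for t, ncla in taskcla:
        train_set = data[t]['train']
        valid_set = data[t]['valid']
        test_set = data[t]['test']
        loaders[t] = {
            'name': data[t]['name'],
            'ncla': ncla,
            'num_train_steps': data[t]['num_train_steps'],
            'train': DataLoader(train_set, sampler=RandomSampler(train_set),
                                batch_size=args.train_batch_size),
            'valid': DataLoader(valid_set, sampler=SequentialSampler(valid_set),
                                batch_size=args.train_batch_size),
            'test': DataLoader(test_set, sampler=SequentialSampler(test_set),
                               batch_size=args.eval_batch_size),
        }
    return loaders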
Example #6
def get(logger=None, args=None):

    #TODO: additionally generate one more mask for generation

    data = {}
    taskcla = []

    # Others
    f_name = 'asc_random'

    with open(f_name, 'r') as f_random_seq:
        random_sep = f_random_seq.readlines()[args.idrandom].split()

    print('random_sep: ', random_sep)
    print('domains: ', domains)

    print('random_sep: ', len(random_sep))
    print('domains: ', len(domains))

    for t in range(args.ntasks):
        asc_dataset = asc_datasets[domains.index(random_sep[t])]
        ae_dataset = ae_datasets[domains.index(random_sep[t])]

        data[t] = {}
        if 'Bing' in asc_dataset:
            data[t]['name'] = asc_dataset
            data[t]['ncla'] = 2
        elif 'XuSemEval' in asc_dataset:
            data[t]['name'] = asc_dataset
            data[t]['ncla'] = 3

        print('ae_dataset: ', ae_dataset)

        logger.info("***** Running training *****")

        #ASC for encoder ====================
        processor = data_utils.AscProcessor()
        label_list = processor.get_labels()
        tokenizer = ABSATokenizer.from_pretrained(args.bert_model)
        train_examples = processor.get_train_examples(asc_dataset)
        train_features = data_utils.convert_examples_to_features_gen(
            train_examples, label_list, args.max_seq_length, tokenizer, "asc")

        all_asc_input_ids = torch.tensor([f.input_ids for f in train_features],
                                         dtype=torch.long)
        all_asc_segment_ids = torch.tensor(
            [f.segment_ids for f in train_features], dtype=torch.long)
        all_asc_input_mask = torch.tensor(
            [f.input_mask for f in train_features], dtype=torch.long)
        all_asc_label_ids = torch.tensor([f.label_id for f in train_features],
                                         dtype=torch.long)
        all_tasks = torch.tensor([t for f in train_features], dtype=torch.long)

        #AE for decoder ====================
        processor = data_utils.AeProcessor()
        label_list = processor.get_labels()
        tokenizer = ABSATokenizer.from_pretrained(args.bert_model)
        train_examples = processor.get_train_examples(ae_dataset)
        train_features = data_utils.convert_examples_to_features_gen(
            train_examples, label_list, args.max_seq_length, tokenizer, "ae")

        all_ae_input_ids = torch.tensor([f.input_ids for f in train_features],
                                        dtype=torch.long)
        all_ae_segment_ids = torch.tensor(
            [f.segment_ids for f in train_features], dtype=torch.long)
        all_ae_input_mask = torch.tensor(
            [f.input_mask for f in train_features], dtype=torch.long)
        all_ae_label_ids = torch.tensor([f.label_id for f in train_features],
                                        dtype=torch.long)

        #SG (sentence generation) for decoder ====================
        processor = data_utils.SgProcessor()
        label_list = None
        tokenizer = ABSATokenizer.from_pretrained(args.bert_model)
        train_examples = processor.get_train_examples(asc_dataset)

        mask_source_words = args.mask_source_words
        max_pred = args.max_pred
        mask_prob = args.mask_prob
        skipgram_prb = args.skipgram_prb
        skipgram_size = args.skipgram_size
        mask_whole_word = args.mask_whole_word
        vocab_words = list(tokenizer.vocab.keys())
        indexer = tokenizer.convert_tokens_to_ids

        train_features = data_utils.convert_examples_to_features_gen(
            train_examples,
            label_list,
            args.max_seq_length * 2,
            tokenizer,
            "sg",
            mask_source_words=mask_source_words,
            max_pred=max_pred,
            mask_prob=mask_prob,
            skipgram_prb=skipgram_prb,
            skipgram_size=skipgram_size,
            mask_whole_word=mask_whole_word,
            vocab_words=vocab_words,
            indexer=indexer)  #seq2seq task

        all_sg_input_ids = torch.tensor([f.input_ids for f in train_features],
                                        dtype=torch.long)
        all_sg_segment_ids = torch.tensor(
            [f.segment_ids for f in train_features], dtype=torch.long)
        all_sg_input_mask = torch.tensor(
            [f.input_mask for f in train_features], dtype=torch.long)
        all_sg_masked_lm_labels = torch.tensor(
            [f.masked_lm_labels for f in train_features],
            dtype=torch.long).squeeze(1)
        all_sg_masked_pos = torch.tensor(
            [f.masked_pos for f in train_features],
            dtype=torch.long).squeeze(1)
        all_sg_masked_weights = torch.tensor(
            [f.masked_weights for f in train_features], dtype=torch.long)

        ae_length = all_ae_input_ids.size(0)
        while all_ae_input_ids.size(0) < all_sg_input_ids.size(0):
            rand_id = torch.randint(low=0, high=ae_length, size=(1, ))
            all_ae_input_ids = torch.cat(
                [all_ae_input_ids, all_ae_input_ids[rand_id]], 0)
            all_ae_segment_ids = torch.cat(
                [all_ae_segment_ids, all_ae_segment_ids[rand_id]], 0)
            all_ae_input_mask = torch.cat(
                [all_ae_input_mask, all_ae_input_mask[rand_id]], 0)
            all_ae_label_ids = torch.cat(
                [all_ae_label_ids, all_ae_label_ids[rand_id]], 0)

        # some sentences have sentiment conflicts, so AE can end up larger than ASC
        asc_length = all_asc_input_ids.size(0)
        while all_asc_input_ids.size(0) < all_ae_input_ids.size(0):
            rand_id = torch.randint(low=0, high=asc_length, size=(1, ))
            all_asc_input_ids = torch.cat(
                [all_asc_input_ids, all_asc_input_ids[rand_id]], 0)
            all_asc_segment_ids = torch.cat(
                [all_asc_segment_ids, all_asc_segment_ids[rand_id]], 0)
            all_asc_input_mask = torch.cat(
                [all_asc_input_mask, all_asc_input_mask[rand_id]], 0)
            all_asc_label_ids = torch.cat(
                [all_asc_label_ids, all_asc_label_ids[rand_id]], 0)
            all_sg_input_ids = torch.cat(
                [all_sg_input_ids, all_sg_input_ids[rand_id]], 0)
            all_sg_segment_ids = torch.cat(
                [all_sg_segment_ids, all_sg_segment_ids[rand_id]], 0)
            all_sg_input_mask = torch.cat(
                [all_sg_input_mask, all_sg_input_mask[rand_id]], 0)
            all_sg_masked_lm_labels = torch.cat(
                [all_sg_masked_lm_labels, all_sg_masked_lm_labels[rand_id]], 0)
            all_sg_masked_pos = torch.cat(
                [all_sg_masked_pos, all_sg_masked_pos[rand_id]], 0)
            all_sg_masked_weights = torch.cat(
                [all_sg_masked_weights, all_sg_masked_weights[rand_id]], 0)
            all_tasks = torch.cat([all_tasks, all_tasks[rand_id]], 0)

            # AE is smaller in size than the others because a sentence can have multiple terms

        num_train_steps = int(
            math.ceil(all_asc_input_ids.size(0) /
                      args.train_batch_size)) * args.num_train_epochs
        # num_train_steps = int(len(train_examples) / args.train_batch_size) * args.num_train_epochs

        logger.info("  Num asc examples = %d", all_asc_input_ids.size(0))
        logger.info("  Num sg examples = %d", all_sg_input_ids.size(0))
        logger.info("  Num ae examples = %d", all_ae_input_ids.size(0))

        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_steps)



        train_data = TensorDataset(
            all_asc_input_ids, all_asc_segment_ids, all_asc_input_mask,
            all_sg_input_ids, all_sg_segment_ids, all_sg_input_mask,
            all_sg_masked_lm_labels, all_sg_masked_pos, all_sg_masked_weights,
            all_ae_input_ids, all_ae_segment_ids, all_ae_input_mask,
            all_ae_label_ids, all_asc_label_ids, all_tasks)

        data[t]['train'] = train_data
        data[t]['num_train_steps'] = num_train_steps

        logger.info("***** Running validations *****")

        processor = data_utils.AscProcessor()
        label_list = processor.get_labels()
        tokenizer = BertTokenizer.from_pretrained(args.bert_model)
        dev_examples = processor.get_dev_examples(asc_dataset)
        dev_features = data_utils.convert_examples_to_features_gen(
            dev_examples, label_list, args.max_seq_length, tokenizer, "asc")

        all_asc_input_ids = torch.tensor([f.input_ids for f in dev_features],
                                         dtype=torch.long)
        all_asc_segment_ids = torch.tensor(
            [f.segment_ids for f in dev_features], dtype=torch.long)
        all_asc_input_mask = torch.tensor([f.input_mask for f in dev_features],
                                          dtype=torch.long)
        all_asc_label_ids = torch.tensor([f.label_id for f in dev_features],
                                         dtype=torch.long)
        all_tasks = torch.tensor([t for f in dev_features], dtype=torch.long)

        #AE for decoder ====================
        processor = data_utils.AeProcessor()
        label_list = processor.get_labels()
        tokenizer = ABSATokenizer.from_pretrained(args.bert_model)
        dev_examples = processor.get_dev_examples(ae_dataset)
        dev_features = data_utils.convert_examples_to_features_gen(
            dev_examples, label_list, args.max_seq_length, tokenizer, "ae")

        all_ae_input_ids = torch.tensor([f.input_ids for f in dev_features],
                                        dtype=torch.long)
        all_ae_segment_ids = torch.tensor(
            [f.segment_ids for f in dev_features], dtype=torch.long)
        all_ae_input_mask = torch.tensor([f.input_mask for f in dev_features],
                                         dtype=torch.long)
        all_ae_label_ids = torch.tensor([f.label_id for f in dev_features],
                                        dtype=torch.long)

        #SG (sentence generation) for decoder ====================
        processor = data_utils.SgProcessor()
        label_list = None
        tokenizer = ABSATokenizer.from_pretrained(args.bert_model)
        dev_examples = processor.get_dev_examples(asc_dataset)
        mask_source_words = args.mask_source_words
        max_pred = args.max_pred
        mask_prob = args.mask_prob
        skipgram_prb = args.skipgram_prb
        skipgram_size = args.skipgram_size
        mask_whole_word = args.mask_whole_word
        vocab_words = list(tokenizer.vocab.keys())
        indexer = tokenizer.convert_tokens_to_ids

        dev_features = data_utils.convert_examples_to_features_gen(
            dev_examples,
            label_list,
            args.max_seq_length * 2,
            tokenizer,
            "sg",
            mask_source_words=mask_source_words,
            max_pred=max_pred,
            mask_prob=mask_prob,
            skipgram_prb=skipgram_prb,
            skipgram_size=skipgram_size,
            mask_whole_word=mask_whole_word,
            vocab_words=vocab_words,
            indexer=indexer)  #seq2seq task

        all_sg_input_ids = torch.tensor([f.input_ids for f in dev_features],
                                        dtype=torch.long)
        all_sg_segment_ids = torch.tensor(
            [f.segment_ids for f in dev_features], dtype=torch.long)
        all_sg_input_mask = torch.tensor([f.input_mask for f in dev_features],
                                         dtype=torch.long)
        all_sg_masked_lm_labels = torch.tensor(
            [f.masked_lm_labels for f in dev_features],
            dtype=torch.long).squeeze(1)
        all_sg_masked_pos = torch.tensor([f.masked_pos for f in dev_features],
                                         dtype=torch.long).squeeze(1)
        all_sg_masked_weights = torch.tensor(
            [f.masked_weights for f in dev_features], dtype=torch.long)

        ae_length = all_ae_input_ids.size(0)
        while all_ae_input_ids.size(0) < all_sg_input_ids.size(0):
            rand_id = torch.randint(low=0, high=ae_length, size=(1, ))
            all_ae_input_ids = torch.cat(
                [all_ae_input_ids, all_ae_input_ids[rand_id]], 0)
            all_ae_segment_ids = torch.cat(
                [all_ae_segment_ids, all_ae_segment_ids[rand_id]], 0)
            all_ae_input_mask = torch.cat(
                [all_ae_input_mask, all_ae_input_mask[rand_id]], 0)
            all_ae_label_ids = torch.cat(
                [all_ae_label_ids, all_ae_label_ids[rand_id]], 0)

        # some sentences have sentiment conflicts, so AE can end up larger than ASC
        asc_length = all_asc_input_ids.size(0)
        while all_asc_input_ids.size(0) < all_ae_input_ids.size(0):
            rand_id = torch.randint(low=0, high=asc_length, size=(1, ))
            all_asc_input_ids = torch.cat(
                [all_asc_input_ids, all_asc_input_ids[rand_id]], 0)
            all_asc_segment_ids = torch.cat(
                [all_asc_segment_ids, all_asc_segment_ids[rand_id]], 0)
            all_asc_input_mask = torch.cat(
                [all_asc_input_mask, all_asc_input_mask[rand_id]], 0)
            all_asc_label_ids = torch.cat(
                [all_asc_label_ids, all_asc_label_ids[rand_id]], 0)
            all_sg_input_ids = torch.cat(
                [all_sg_input_ids, all_sg_input_ids[rand_id]], 0)
            all_sg_segment_ids = torch.cat(
                [all_sg_segment_ids, all_sg_segment_ids[rand_id]], 0)
            all_sg_input_mask = torch.cat(
                [all_sg_input_mask, all_sg_input_mask[rand_id]], 0)
            all_sg_masked_lm_labels = torch.cat(
                [all_sg_masked_lm_labels, all_sg_masked_lm_labels[rand_id]], 0)
            all_sg_masked_pos = torch.cat(
                [all_sg_masked_pos, all_sg_masked_pos[rand_id]], 0)
            all_sg_masked_weights = torch.cat(
                [all_sg_masked_weights, all_sg_masked_weights[rand_id]], 0)
            all_tasks = torch.cat([all_tasks, all_tasks[rand_id]], 0)

        logger.info("  Num asc examples = %d", all_asc_input_ids.size(0))
        logger.info("  Num sg examples = %d", all_sg_input_ids.size(0))
        logger.info("  Num ae examples = %d", all_ae_input_ids.size(0))


        valid_data = TensorDataset(
            all_asc_input_ids, all_asc_segment_ids, all_asc_input_mask,
            all_sg_input_ids, all_sg_segment_ids, all_sg_input_mask,
            all_sg_masked_lm_labels, all_sg_masked_pos, all_sg_masked_weights,
            all_ae_input_ids, all_ae_segment_ids, all_ae_input_mask,
            all_ae_label_ids, all_asc_label_ids, all_tasks)

        data[t]['valid'] = valid_data

        logger.info("***** Running evaluation *****")

        processor = data_utils.AscProcessor()
        label_list = processor.get_labels()
        tokenizer = BertTokenizer.from_pretrained(args.bert_model)
        eval_examples = processor.get_test_examples(asc_dataset)
        eval_features = data_utils.convert_examples_to_features_gen(
            eval_examples, label_list, args.max_seq_length, tokenizer, "asc")

        all_asc_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                         dtype=torch.long)
        all_asc_segment_ids = torch.tensor(
            [f.segment_ids for f in eval_features], dtype=torch.long)
        all_asc_input_mask = torch.tensor(
            [f.input_mask for f in eval_features], dtype=torch.long)
        all_asc_label_ids = torch.tensor([f.label_id for f in eval_features],
                                         dtype=torch.long)
        all_tasks = torch.tensor([t for f in eval_features], dtype=torch.long)

        #AE for decoder ====================
        processor = data_utils.AeProcessor()
        label_list = processor.get_labels()
        tokenizer = ABSATokenizer.from_pretrained(args.bert_model)
        eval_examples = processor.get_test_examples(ae_dataset)

        eval_features = data_utils.convert_examples_to_features_gen(
            eval_examples, label_list, args.max_seq_length, tokenizer, "ae")

        all_ae_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                        dtype=torch.long)
        all_ae_segment_ids = torch.tensor(
            [f.segment_ids for f in eval_features], dtype=torch.long)
        all_ae_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                         dtype=torch.long)
        all_ae_label_ids = torch.tensor([f.label_id for f in eval_features],
                                        dtype=torch.long)

        #SG (sentence generation) for decoder ====================
        processor = data_utils.SgProcessor()
        label_list = None
        tokenizer = ABSATokenizer.from_pretrained(args.bert_model)
        eval_examples = processor.get_test_examples(asc_dataset)

        mask_source_words = args.mask_source_words
        max_pred = args.max_pred
        mask_prob = args.mask_prob
        skipgram_prb = args.skipgram_prb
        skipgram_size = args.skipgram_size
        mask_whole_word = args.mask_whole_word
        vocab_words = list(tokenizer.vocab.keys())
        indexer = tokenizer.convert_tokens_to_ids

        eval_features = data_utils.convert_examples_to_features_gen(
            eval_examples,
            label_list,
            args.max_seq_length * 2,
            tokenizer,
            "sg",
            mask_source_words=mask_source_words,
            max_pred=max_pred,
            mask_prob=mask_prob,
            skipgram_prb=skipgram_prb,
            skipgram_size=skipgram_size,
            mask_whole_word=mask_whole_word,
            vocab_words=vocab_words,
            indexer=indexer)  #seq2seq task

        all_sg_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                        dtype=torch.long)
        all_sg_segment_ids = torch.tensor(
            [f.segment_ids for f in eval_features], dtype=torch.long)
        all_sg_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                         dtype=torch.long)
        all_sg_masked_lm_labels = torch.tensor(
            [f.masked_lm_labels for f in eval_features],
            dtype=torch.long).squeeze(1)
        all_sg_masked_pos = torch.tensor([f.masked_pos for f in eval_features],
                                         dtype=torch.long).squeeze(1)
        all_sg_masked_weights = torch.tensor(
            [f.masked_weights for f in eval_features], dtype=torch.long)

        ae_length = all_ae_input_ids.size(0)
        while all_ae_input_ids.size(0) < all_sg_input_ids.size(0):
            rand_id = torch.randint(low=0, high=ae_length, size=(1, ))
            all_ae_input_ids = torch.cat(
                [all_ae_input_ids, all_ae_input_ids[rand_id]], 0)
            all_ae_segment_ids = torch.cat(
                [all_ae_segment_ids, all_ae_segment_ids[rand_id]], 0)
            all_ae_input_mask = torch.cat(
                [all_ae_input_mask, all_ae_input_mask[rand_id]], 0)
            all_ae_label_ids = torch.cat(
                [all_ae_label_ids, all_ae_label_ids[rand_id]], 0)

        # some sentences have sentiment conflicts, so AE can end up larger than ASC
        asc_length = all_asc_input_ids.size(0)
        while all_asc_input_ids.size(0) < all_ae_input_ids.size(0):
            rand_id = torch.randint(low=0, high=asc_length, size=(1, ))
            all_asc_input_ids = torch.cat(
                [all_asc_input_ids, all_asc_input_ids[rand_id]], 0)
            all_asc_segment_ids = torch.cat(
                [all_asc_segment_ids, all_asc_segment_ids[rand_id]], 0)
            all_asc_input_mask = torch.cat(
                [all_asc_input_mask, all_asc_input_mask[rand_id]], 0)
            all_asc_label_ids = torch.cat(
                [all_asc_label_ids, all_asc_label_ids[rand_id]], 0)
            all_sg_input_ids = torch.cat(
                [all_sg_input_ids, all_sg_input_ids[rand_id]], 0)
            all_sg_segment_ids = torch.cat(
                [all_sg_segment_ids, all_sg_segment_ids[rand_id]], 0)
            all_sg_input_mask = torch.cat(
                [all_sg_input_mask, all_sg_input_mask[rand_id]], 0)
            all_sg_masked_lm_labels = torch.cat(
                [all_sg_masked_lm_labels, all_sg_masked_lm_labels[rand_id]], 0)
            all_sg_masked_pos = torch.cat(
                [all_sg_masked_pos, all_sg_masked_pos[rand_id]], 0)
            all_sg_masked_weights = torch.cat(
                [all_sg_masked_weights, all_sg_masked_weights[rand_id]], 0)
            all_tasks = torch.cat([all_tasks, all_tasks[rand_id]], 0)

        logger.info("  Num asc examples = %d", all_asc_input_ids.size(0))
        logger.info("  Num sg examples = %d", all_sg_input_ids.size(0))
        logger.info("  Num ae examples = %d", all_ae_input_ids.size(0))



        eval_data = TensorDataset(
            all_asc_input_ids, all_asc_segment_ids, all_asc_input_mask,
            all_sg_input_ids, all_sg_segment_ids, all_sg_input_mask,
            all_sg_masked_lm_labels, all_sg_masked_pos, all_sg_masked_weights,
            all_ae_input_ids, all_ae_segment_ids, all_ae_input_mask,
            all_ae_label_ids, all_asc_label_ids, all_tasks)

        # Run prediction for full data

        data[t]['test'] = eval_data

        taskcla.append((t, int(data[t]['ncla'])))

    # Others
    n = 0
    for t in data.keys():
        n += data[t]['ncla']
    data['ncla'] = n

    return data, taskcla
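The resampling loops above repeatedly duplicate single random rows until the smaller group of tensors matches the larger one. A compact equivalent, written as a hypothetical helper only to make the pattern explicit, pads a set of row-aligned tensors to a target length in one step:

import torch

def pad_by_resampling(tensors, target_len):
    """Duplicate randomly chosen rows (the same rows for every tensor in
    the group) until each tensor has target_len rows along dim 0."""
    cur_len = tensors[0].size(0)
    if cur_len >= target_len:
        return list(tensors)
    extra = torch.randint(low=0, high=cur_len, size=(target_len - cur_len,))
    return [torch.cat([x, x[extra]], dim=0) for x in tensors]

For example, the AE group could be padded to the SG length with pad_by_resampling([all_ae_input_ids, all_ae_segment_ids, all_ae_input_mask, all_ae_label_ids], all_sg_input_ids.size(0)).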