Example No. 1
    def save_checkpoint(self, val_loss, model, model_dir):
        '''Saves model and tokenizer when the validation loss decreases.'''
        if self.verbose:
            self.logger.info(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        model.save_pretrained(model_dir)
        TOKENIZER.save_pretrained(model_dir)
        self.val_loss_min = val_loss
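
For context, a minimal sketch of the calling pattern (the maybe_save name and the saver argument are assumptions for illustration, not part of the original code): the checkpoint is only written when the validation loss improves on the best value seen so far.

def maybe_save(saver, val_loss, model, model_dir):
    # `saver` stands for any object exposing val_loss_min and the
    # save_checkpoint method shown above (e.g. an early-stopping helper).
    if val_loss < saver.val_loss_min:
        saver.save_checkpoint(val_loss, model, model_dir)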
Example No. 2
def test_one_to_many(task_load):
    score_dicts = []
    for ep in range(args.n_train_epochs[task_load]):
        model_dir = get_model_dir([task_load])
        model_path = os.path.join(model_dir, 'model-{}'.format(ep+1))
        config_path = os.path.join(model_dir, CONFIG_NAME)

        # Register the task-specific generation token before rebuilding the model.
        gen_token = get_gen_token(task_load)
        TOKENIZER.add_tokens([gen_token])
        SPECIAL_TOKENS[task_load] = gen_token
        SPECIAL_TOKEN_IDS[task_load] = TOKENIZER.convert_tokens_to_ids(gen_token)

        # Rebuild the model from its saved config and load this epoch's checkpoint.
        model_config = CONFIG_CLASS.from_json_file(config_path)
        model = MODEL_CLASS(model_config).cuda().eval()
        state_dict = torch.load(model_path, map_location='cuda:0')
        model.load_state_dict(state_dict)

        if not args.fp32:
            model = FP16_Module(model)

        model.ep = ep
        model.model_dir = model_dir
        logger.info("task: {}, epoch: {}".format(task_load, ep+1))
        # Evaluate this checkpoint on every task in the benchmark.
        score_dict = {k: None for k in args.tasks}
        with torch.no_grad():
            for task_eval in args.tasks:
                test_one_to_one(task_load, task_eval, model, score_dict)
        logger.info("score: {}".format(score_dict))
        score_dicts.append(score_dict)

    with open(os.path.join(model_dir, "metrics.json"),"w") as f:
        json.dump(score_dicts, f)
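
As a small, hypothetical post-processing step (not part of the original code), the metrics.json written above can be read back as a list with one score dict per epoch, keyed by task name; the model_dir path below is a placeholder standing in for the directory used in test_one_to_many.

import json
import os

model_dir = "models/example_task"  # placeholder path for illustration

with open(os.path.join(model_dir, "metrics.json"), "r") as f:
    score_dicts = json.load(f)

for ep, scores in enumerate(score_dicts, start=1):
    print("epoch {}: {}".format(ep, scores))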
Example No. 3
def read_extra_data(gen_path, train_extra_data):
    # Read previously generated pseudo-samples from CSV and append their token ids.
    with open(gen_path, "r") as lm_file:
        reader = csv.reader(lm_file, delimiter=',')
        next(reader)  # skip the header row
        for row in reader:
            row = TOKENIZER.encode(row[0].strip())
            train_extra_data.append(row)
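
The repository also writes these CSV files (write_extra_data is called in the later examples but its body is not shown here). Below is a minimal sketch consistent with the reader above, one header row followed by one text per row in the first column; the function name and header label are assumptions.

import csv

def write_extra_data_sketch(dump_path, texts):
    # Counterpart of read_extra_data: write a header row first so that
    # next(reader) skips it, then one generated text per row in column 0.
    with open(dump_path, "w", newline="") as lm_file:
        writer = csv.writer(lm_file, delimiter=',')
        writer.writerow(["gen"])  # header label is a guess
        for text in texts:
            writer.writerow([text])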
Example No. 4
def get_real_data(task, train_extra_data, accum=True, encode=True):
    task_idx = args.tasks.index(task)
    gen_size = DATA_ATTRS[task]["train"]["data_size"]
    if accum:
        # Spread the sampling budget evenly over all previous tasks.
        prev_tasks = args.tasks[:task_idx]
        gen_size = int(np.ceil(gen_size * args.gen_lm_sample_percentage))//len(prev_tasks)
    else:
        # Sample only from the immediately preceding task.
        prev_tasks = [args.tasks[task_idx-1]]
        gen_size = int(gen_size * args.gen_lm_sample_percentage)

    datum = []
    for prev_task in prev_tasks:
        with open(TASK_DICT[prev_task]["train"],"r") as f:
            data = data_expand(json.load(f)["data"])
        indices = np.random.choice(range(len(data)), gen_size)
        for i in indices:
            d = parse_single_real_data(data[i],prev_task)
            datum.append(d)
            if encode:
                train_extra_data.append(TOKENIZER.encode(d))
        
    # Dump the sampled real examples next to the last previous task's model directory.
    model_dir = get_model_dir([prev_task])
    dump_path = os.path.join(model_dir, "real.csv")
    write_extra_data(dump_path, datum)
    return dump_path
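
To make the sampling-size arithmetic concrete, a small worked example with made-up numbers: for data_size = 10000 and gen_lm_sample_percentage = 0.05, the accum=True branch with two previous tasks draws ceil(10000 * 0.05) // 2 = 250 examples per previous task, while the accum=False branch draws int(10000 * 0.05) = 500 from the single preceding task.

import numpy as np

data_size = 10000   # made-up DATA_ATTRS[task]["train"]["data_size"]
pct = 0.05          # made-up args.gen_lm_sample_percentage
n_prev_tasks = 2    # made-up number of previous tasks

per_prev_task = int(np.ceil(data_size * pct)) // n_prev_tasks  # accum=True branch -> 250
from_last_task = int(data_size * pct)                          # accum=False branch -> 500
print(per_prev_task, from_last_task)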
Example No. 5
    def parallel_tokenization(self, d):
        examples = []
        context = TOKENIZER.encode(d["context"])
        max_a_len = 0
        for i3, qa in enumerate(d["qas"]):
            question = TOKENIZER.encode(qa["question"])

            raw_answers = qa["answers"]
            if len(raw_answers) == 0:
                # Unanswerable questions get a single empty answer.
                assert qa["is_impossible"]
                raw_answers.append({"text": ""})

            # Concatenate alternative answers, separated by the pad token.
            answer = []
            for i, raw_answer in enumerate(raw_answers):
                answer.extend(TOKENIZER.encode(raw_answer["text"]))
                if i != len(raw_answers) - 1:
                    answer.append(self.pad_token)
            max_a_len = max(max_a_len, len(answer))

            examples.append(self.parse_example(
                self.gen_token, context, question, answer,
                qa.get("id", 0 if not args.test_training_set else d["pid"] + "_%d" % i3)))
        return examples, max_a_len
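
To illustrate what the answer-building loop produces, here is a toy run with a stand-in encoder (fake_encode, the tiny VOCAB, and pad id 0 are all assumptions replacing TOKENIZER.encode and self.pad_token): alternative answers end up concatenated, separated by the pad token.

VOCAB = {"blue": 7, "light": 3}
PAD = 0  # stand-in for self.pad_token

def fake_encode(text):
    # Toy substitute for TOKENIZER.encode: one id per whitespace token.
    return [VOCAB[tok] for tok in text.split()]

raw_answers = [{"text": "blue"}, {"text": "light blue"}]
answer = []
for i, raw_answer in enumerate(raw_answers):
    answer.extend(fake_encode(raw_answer["text"]))
    if i != len(raw_answers) - 1:
        answer.append(PAD)  # pad token separates alternative answers
print(answer)  # [7, 0, 3, 7]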
Example No. 6
def create_extra_data(task, prev_task, model, train_extra_data):
    if args.real_sample:
        logger.info(f"using real data as extra data")
        return get_real_data(task, train_extra_data)
    task_cnt = args.tasks.index(task)
    model_dir = get_model_dir([prev_task])
    gen_path = os.path.join(model_dir, "lm.csv")
    if os.path.exists(gen_path):
        logger.info(f"extra data exists in {gen_path}, read it!")
        return read_extra_data(gen_path, train_extra_data)
    gen_size = DATA_ATTRS[task]["train"]["data_size"]
    gen_size = int(np.ceil(gen_size * args.gen_lm_sample_percentage))
    gen_size -= (gen_size % task_cnt)  # make the budget divisible by the number of previous tasks

    if args.debug:
        gen_size = task_cnt

    model.eval()

    need_process = OrderedDict()
    # Seed generation with each previous task's special token so the model
    # produces pseudo-samples in that task's format.
    qa_results = []
    for task_name in args.tasks[:task_cnt]:
        qa_results.extend([
            torch.tensor([SPECIAL_TOKEN_IDS[task_name]])
            for _ in range(gen_size // task_cnt)
        ])
    # Empty per-layer key/value caches (GPT-2 "pasts"), one per pending sample.
    all_pasts = [[
        torch.empty(2,
                    MODEL_CONFIG.n_head,
                    0,
                    MODEL_CONFIG.n_embd // MODEL_CONFIG.n_head,
                    dtype=torch.float if args.fp32 else torch.half).cuda()
        for _ in range(gen_size)
    ] for __ in range(MODEL_CONFIG.n_layer)]
    max_tot_lens = [args.max_len for _ in range(gen_size)]

    # Generate in chunks: flush the pending queue whenever it grows past a
    # memory-dependent threshold, then once more for the remainder.
    for i in range(gen_size):
        need_process.update([[i, None]])
        if len(need_process) > int(args.memory_sizes[0] * 0.12):
            sample_sequence(model, need_process, qa_results, all_pasts,
                            max_tot_lens)
    sample_sequence(model, need_process, qa_results, all_pasts, max_tot_lens)

    model.train()

    qa_results = [res.tolist() for res in qa_results]
    train_extra_data.extend(qa_results)
    qa_results = [TOKENIZER.decode(res) for res in qa_results]

    write_extra_data(gen_path, qa_results)
Example No. 7
def test_one_to_one(task_load, task_eval, model, score_dict):

    logger.info("start to test { task: %s (load) %s (eval), seq train type: %s }" % (task_load, task_eval, args.seq_train_type))

    test_qadata = QADataset(TASK_DICT[task_eval]["test"], "test", SPECIAL_TOKEN_IDS[task_load]).sort()
    max_a_len = test_qadata.max_a_len
    test_dataloader = create_dataloader(test_qadata, "test")
    n_examples = len(test_qadata)
    logger.info("len of test dataset: {}".format(n_examples))

    need_process = OrderedDict()
    # Per-example placeholders: decoded results, per-layer past caches, and max lengths.
    qa_results = [0 for _ in range(n_examples)]
    all_pasts = [[0 for _ in range(n_examples)] for __ in range(MODEL_CONFIG.n_layer)]
    max_tot_lens = [0 for _ in range(n_examples)]

    cnt = 0
    for n_steps, (cqs, len_cqs, _, _, _, _, _) in enumerate(test_dataloader):
        # assume n_gpus == 1
        cqs = cqs[0]
        len_cqs = len_cqs[0]
        n_inputs = cqs.shape[0]
        all_outputs = model(input_ids=cqs.cuda())
        outputs = all_outputs[0]
        if args.model_name == "gpt2":
            pasts = all_outputs[1]
        # The first answer token is predicted at the position right after the context+question.
        next_logits = outputs[range(n_inputs), len_cqs-1, :] / args.temperature_qa
        next_tokens = logits_to_tokens(next_logits).cpu()

        for i in range(n_inputs):
            max_tot_lens[cnt] = max_a_len + test_qadata[cnt][1]
            qa_results[cnt] = cqs[i][:len_cqs[i]]
            if next_tokens[i] != SPECIAL_TOKEN_IDS["eos_token"]:
                qa_results[cnt] = torch.cat((cqs[i][:len_cqs[i]], next_tokens[i]))
                if len(qa_results[cnt]) not in [max_tot_lens[cnt], args.max_len]:
                    need_process.update([[cnt, None]])
                    if args.model_name == "gpt2":
                        for layer_id in range(MODEL_CONFIG.n_layer):
                            all_pasts[layer_id][cnt] = pasts[layer_id][:, i, ..., :len_cqs[i], :].type(torch.float32 if args.fp32 else torch.half)
            cnt += 1

        if len(need_process) > int(12 * args.memory_sizes[0] / cqs.shape[1]):  # dynamic threshold to avoid out of memory
            sample_sequence(model, need_process, qa_results, all_pasts, max_tot_lens)
    sample_sequence(model, need_process, qa_results, all_pasts, max_tot_lens)

    if task_eval in ['wikisql','woz.en','multinli.in.out']:
        # These datasets were sorted for batching; restore the original example order.
        ids = test_qadata.get_indices()
        test_qadata.sort_by_index()
        qa_results = [x[1] for x in sorted([(i, g) for i, g in zip(ids, qa_results)])]
    for i in range(len(test_qadata)):
        _, len_cq, _, _, Y, _, _, _ = test_qadata[i]
        if task_eval in ['wikisql','woz.en']:
            Y = test_qadata.answers[i]
        else:
            Y = list(filter(lambda x: x != -1, Y))[:-1]  # remove eos
            # Split on the pad token id to recover alternative gold answers, then decode each.
            Y = ' '.join([str(y) for y in Y]).split(str(SPECIAL_TOKEN_IDS["pad_token"]))
            Y = [TOKENIZER.decode(list(map(int, y.split()))) for y in Y]
        qa_results[i] = [TOKENIZER.decode(qa_results[i].tolist()[len_cq:]), Y]
    get_test_score(task_eval, qa_results, score_dict)

    model_dir = model.model_dir
    ep = model.ep
    results_path = os.path.join(model_dir,"qa_{}_{}.csv".format(task_eval,ep+1))
    if not args.debug:
        with open(results_path, "w",encoding="utf-8") as f:
            qa_writer = csv.writer(f,delimiter=',')
            qa_writer.writerow(["y","pred"])
            for pred, y in qa_results:
                if task_eval == 'wikisql': 
                    y = y["answer"]
                elif task_eval == 'woz.en': 
                    y = y[1]
                qa_writer.writerow([y,pred])

    return model, score_dict
Example No. 8
def train(task_ids, model):
    tasks = [args.tasks[task_id] for task_id in task_ids]

    logger.info("start to train { task: %s, seq train type: %s }" %
                (tasks, args.seq_train_type))
    model_dir = get_model_dir(tasks)
    make_dir(model_dir)

    #train_dataset = [(TASK_DICT[t]["train"] if not args.seq_distil else TASK_DICT[t]["train"].replace("train", "distil")) for t in tasks]
    train_dataset = [
        swap_name(TASK_DICT[t]["train"], args.seq_distil, args.ref1)
        for t in tasks
    ]
    train_extra_data = []
    if "lll" in args.seq_train_type and task_ids[0] > 0 and not args.skip_tasks:
        prev_task = args.tasks[task_ids[0] - 1]
        with torch.no_grad():
            create_extra_data(tasks[0], prev_task, model, train_extra_data)
    elif "gem" in args.seq_train_type and task_ids[0] > 0:
        get_real_data(tasks[0], train_extra_data, accum=False, encode=True)
        args.memory_data.append(train_extra_data)
        train_extra_data = []
    logger.info('extra training data size: {}'.format(len(train_extra_data)))

    if not model:
        # which_model_to_load = model_dir if os.path.isfile(os.path.join(model_dir, FINAL_SAVE_NAME)) else args.model_name
        model = MODEL_CLASS.from_pretrained(args.model_name).cuda()
        model.resize_token_embeddings(len(TOKENIZER))
        if not args.fp32:
            model = FP16_Module(model)

    gen_token = get_gen_token(tasks[0])
    TOKENIZER.add_tokens([gen_token])
    TOKENIZER.save_pretrained(model_dir)
    SPECIAL_TOKENS[tasks[0]] = gen_token
    SPECIAL_TOKEN_IDS[tasks[0]] = TOKENIZER.convert_tokens_to_ids(gen_token)
    logger.info('gen token = {} , gen token id = {}'.format(
        gen_token, SPECIAL_TOKEN_IDS[tasks[0]]))
    MODEL_CONFIG.vocab_size = len(TOKENIZER)
    MODEL_CONFIG.to_json_file(os.path.join(model_dir, CONFIG_NAME))
    global TOKENS_WEIGHT
    if len(TOKENIZER) != TOKENS_WEIGHT.shape[0]:
        # Extend the per-token loss weights to cover the newly added gen token.
        TOKENS_WEIGHT = torch.cat((TOKENS_WEIGHT, torch.ones([1]).cuda()))

    if args.skip_tasks and len(tasks) == 1:
        logger.info("*********** skip task: {} ***********".format(tasks[0]))
        if tasks[0] in args.skip_tasks:
            if len(args.skip_tasks) == 1:
                model_dir = get_model_dir(tasks)
                model_path = os.path.join(model_dir, FINAL_SAVE_NAME)
                config_path = os.path.join(model_dir, CONFIG_NAME)
                model_config = CONFIG_CLASS.from_json_file(config_path)
                model = MODEL_CLASS(model_config).cuda()
                state_dict = torch.load(model_path)
                model.load_state_dict(state_dict)
                if not args.fp32:
                    model = FP16_Module(model)
                if args.seq_train_type in REG_TYPE_KEYS:
                    logger.info("calulating reg_params ...")
                    train_qadata = QADataset(train_dataset, "train",
                                             SPECIAL_TOKEN_IDS[tasks[0]],
                                             train_extra_data)
                    max_train_batch_size = max(
                        len(train_qadata) // args.min_n_steps,
                        args.min_batch_size)
                    train_dataloader = create_dataloader(
                        train_qadata, "train", max_train_batch_size)
                    parallel_model = DataParallelModel(WrapModel(model),
                                                       args.device_ids)
                    regularizer = REG_TYPES[args.seq_train_type](
                        model, parallel_model, [train_dataloader], tasks[0])
                    regularizer.task_start_do()
                    regularizer.task_end_do()
                    torch.save(model.state_dict(),
                               os.path.join(model_dir, FINAL_SAVE_NAME))
                    logger.info("done reg_params!")
            args.skip_tasks.remove(tasks[0])
            return model

    model.resize_token_embeddings(
        len(TOKENIZER) if not args.multitask_specific else len(TOKENIZER) + 4)
    if args.multitask_specific:
        for i in range(4):
            TOKENS_WEIGHT = torch.cat((TOKENS_WEIGHT, torch.ones([1]).cuda()))
    if args.distil:
        teacher_model = MODEL_CLASS.from_pretrained(args.model_name).cuda()
        teacher_vocab_size = json.load(
            open("models/gpt2/lll/{task}_0.2/{task}/config.json".format(
                task=tasks[0])))['vocab_size']
        teacher_model.resize_token_embeddings(teacher_vocab_size)
        print("load teacher model from {}".format(
            "models/gpt2/lll/{task}_0.2/{task}/model-finish".format(
                task=tasks[0])))
        teacher_model.load_state_dict(
            torch.load("models/gpt2/lll/{task}_0.2/{task}/model-finish".format(
                task=tasks[0])))
        if not args.fp32:
            teacher_model = FP16_Module(teacher_model)
        teacher_model.eval()
        teacher_model = DataParallelModel(WrapModel(teacher_model),
                                          args.device_ids)

    if not args.fp32:  # again because resize_token_embeddings makes embedding layer fp32
        model = FP16_Module(model)

    parallel_model = DataParallelModel(WrapModel(model), args.device_ids)

    train_qadata = QADataset(train_dataset, "train",
                             SPECIAL_TOKEN_IDS[tasks[0]], train_extra_data)
    max_train_batch_size = max(
        len(train_qadata) // args.min_n_steps, args.min_batch_size)
    train_dataloader = create_dataloader(train_qadata, "train",
                                         max_train_batch_size)
    if not args.unbound and args.seq_train_type not in [
            "multitask", "multilm"
    ]:
        #n_train_epochs = TASK_DICT[tasks[0]]["n_train_epochs"]
        n_train_epochs = args.n_train_epochs[tasks[0]]
    else:
        n_train_epochs = args.n_train_epochs['_'.join(tasks)]
    n_train_optimization_steps = len(train_qadata) * n_train_epochs
    logger.info(
        'len of train dataset: {} , max train batch size {} , num of opt steps: {}'
        .format(len(train_qadata), max_train_batch_size,
                n_train_optimization_steps))

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        args.weight_decay
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    if "gem" in args.seq_train_type:
        model.task_id = task_ids[0]
        if not hasattr(model, "grad_dims"):
            model.grad_dims = []
            for param in model.parameters():
                model.grad_dims.append(param.data.numel())
        if not hasattr(model, "grads"):
            model.grads = torch.zeros(sum(model.grad_dims), len(args.tasks))
            model.grads = model.grads.cuda()

    if args.seq_train_type in REG_TYPE_KEYS:
        optimizer = Weight_Regularized_AdamW(optimizer_grouped_parameters,
                                             lr=args.learning_rate,
                                             eps=args.adam_epsilon)
    else:
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    if not args.fp32:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=None,
                                   dynamic_loss_scale=True,
                                   dynamic_loss_args={
                                       'scale_window': 100,
                                       'min_scale': 1,
                                       'delayed_shift': 2
                                   })

    scheduler = AnnealingLR(optimizer,
                            start_lr=args.learning_rate,
                            warmup_iter=int(args.n_warmup_ratio *
                                            len(train_qadata)),
                            num_iters=int(n_train_optimization_steps),
                            decay_style=args.decay_style)
    train_loss_fct = DataParallelCriterion(
        CrossEntropyLoss(ignore_index=FILL_VAL, weight=TOKENS_WEIGHT),
        args.device_ids)
    if args.distil:
        kd_loss_fct = DataParallelCriterion(
            nn.KLDivLoss(reduction="batchmean"), args.device_ids)

    if args.seq_train_type in REG_TYPE_KEYS:
        copy_train_dataloader = create_dataloader(train_qadata, "train",
                                                  max_train_batch_size)
        prev_task = args.tasks[task_ids[0] - 1]
        regularizer = REG_TYPES[args.seq_train_type](model, parallel_model,
                                                     [copy_train_dataloader],
                                                     tasks[0], prev_task)
        regularizer.task_start_do()

    tot_n_steps = 0
    train_once = TrainStep(model, optimizer, scheduler)
    if "gem" in args.seq_train_type and task_ids[0] != 0:
        gem_step = GEMStep(model, parallel_model, train_loss_fct, optimizer)
    model.train()
    for ep in range(n_train_epochs):
        cum_loss, cum_qa_loss, cum_lm_loss, cur_n_inputs = 0, 0, 0, 0
        for n_steps, (_, _, cqa, _, Y, gen_X, gen_Y,
                      is_extra) in enumerate(train_dataloader):

            n_inputs = sum(_cqa.shape[0] for _cqa in cqa)
            if args.multitask_specific:
                for i in range(len(is_extra)):
                    gen_X[i][:, 0] += is_extra[i]
                    is_extra[i] = is_extra[i] * 0

            # Move each data-parallel shard to its assigned GPU.
            for i in range(len(cqa)):
                cqa[i] = (cqa[i].to(args.device_ids[i]),)
                Y[i] = Y[i].to(args.device_ids[i])
                gen_X[i] = (gen_X[i].to(args.device_ids[i]),)
                gen_Y[i] = gen_Y[i].to(args.device_ids[i])
                is_extra[i] = is_extra[i].to(args.device_ids[i])

            if args.distil:
                losses = get_distil_losses(teacher_model,
                                           parallel_model,
                                           cqa,
                                           Y,
                                           gen_X,
                                           gen_Y,
                                           is_extra,
                                           kd_loss_fct,
                                           train_loss_fct,
                                           args.temperature_kd,
                                           pad_idx=FILL_VAL)
            else:
                losses = get_losses(parallel_model, cqa, Y, gen_X, gen_Y,
                                    train_loss_fct)
            loss = sum(losses)
            if "gem" in args.seq_train_type and task_ids[0] != 0:
                gem_step(task_ids[0])
            train_once(loss, n_inputs)

            qa_loss = losses[0].item() * n_inputs
            lm_loss = losses[1].item() * n_inputs
            cum_loss += (qa_loss + lm_loss)
            cum_qa_loss += qa_loss
            cum_lm_loss += lm_loss
            cur_n_inputs += n_inputs

            if (n_steps + 1) % args.logging_steps == 0:
                logger.info(
                    'progress {:.3f} , lr {:.1E} , loss {:.3f} , qa loss {:.3f} , lm loss {:.3f} , avg batch size {:.1f}'
                    .format(ep + cur_n_inputs / len(train_qadata),
                            scheduler.get_lr(), cum_loss / cur_n_inputs,
                            cum_qa_loss / cur_n_inputs,
                            cum_lm_loss / cur_n_inputs,
                            cur_n_inputs / (n_steps + 1)))

        torch.save(model.state_dict(),
                   os.path.join(model_dir, SAVE_NAME + str(ep + 1)))
        tot_n_steps += (n_steps + 1)
        logger.info(
            'epoch {}/{} done , tot steps {} , lr {:.1E} , loss {:.2f} , qa loss {:.2f} , lm loss {:.2f} , avg batch size {:.1f}'
            .format(ep + 1, n_train_epochs, tot_n_steps, scheduler.get_lr(),
                    cum_loss / cur_n_inputs, cum_qa_loss / cur_n_inputs,
                    cum_lm_loss / cur_n_inputs, cur_n_inputs / (n_steps + 1)))

    # task end do for reg
    if args.seq_train_type in REG_TYPE_KEYS:
        regularizer.task_end_do()
    torch.save(model.state_dict(), os.path.join(model_dir, FINAL_SAVE_NAME))

    return model