Example #1
def main():
    logger.info('Launching the MT-DNN training')
    opt = vars(args)
    # update data dir
    opt['data_dir'] = data_dir
    batch_size = args.batch_size
    train_data_list = []
    tasks = {}
    tasks_class = {}
    nclass_list = []
    decoder_opts = []
    dropout_list = []

    for dataset in args.train_datasets:
        prefix = dataset.split('_')[0]
        if prefix in tasks: continue
        assert prefix in DATA_META
        assert prefix in DATA_TYPE
        data_type = DATA_TYPE[prefix]
        nclass = DATA_META[prefix]
        task_id = len(tasks)
        if args.mtl_opt > 0:
            task_id = tasks_class[nclass] if nclass in tasks_class else len(
                tasks_class)

        task_type = TASK_TYPE[prefix]
        pw_task = False
        if prefix in opt['pw_tasks']:
            pw_task = True

        dopt = generate_decoder_opt(prefix, opt['answer_opt'])
        if task_id < len(decoder_opts):
            decoder_opts[task_id] = min(decoder_opts[task_id], dopt)
        else:
            decoder_opts.append(dopt)

        if prefix not in tasks:
            tasks[prefix] = len(tasks)
            if args.mtl_opt < 1: nclass_list.append(nclass)

        if (nclass not in tasks_class):
            tasks_class[nclass] = len(tasks_class)
            if args.mtl_opt > 0: nclass_list.append(nclass)

        dropout_p = args.dropout_p
        if tasks_config and prefix in tasks_config:
            dropout_p = tasks_config[prefix]
        dropout_list.append(dropout_p)

        train_path = os.path.join(data_dir, '{}_train.json'.format(dataset))
        logger.info('Loading {} as task {}'.format(train_path, task_id))
        train_data = BatchGen(BatchGen.load(train_path,
                                            True,
                                            pairwise=pw_task,
                                            maxlen=args.max_seq_len),
                              batch_size=batch_size,
                              dropout_w=args.dropout_w,
                              gpu=args.cuda,
                              task_id=task_id,
                              maxlen=args.max_seq_len,
                              pairwise=pw_task,
                              data_type=data_type,
                              task_type=task_type)
        train_data_list.append(train_data)

    opt['answer_opt'] = decoder_opts
    opt['tasks_dropout_p'] = dropout_list

    args.label_size = ','.join([str(l) for l in nclass_list])
    logger.info(args.label_size)
    dev_data_list = []
    test_data_list = []
    for dataset in args.test_datasets:
        prefix = dataset.split('_')[0]
        if args.mtl_opt > 0:
            task_id = tasks_class[DATA_META[prefix]]
        else:
            task_id = tasks[prefix]
        task_type = TASK_TYPE[prefix]

        pw_task = False
        if prefix in opt['pw_tasks']:
            pw_task = True

        assert prefix in DATA_TYPE
        data_type = DATA_TYPE[prefix]

        dev_path = os.path.join(data_dir, '{}_dev.json'.format(dataset))
        dev_data = None
        if os.path.exists(dev_path):
            dev_data = BatchGen(BatchGen.load(dev_path,
                                              False,
                                              pairwise=pw_task,
                                              maxlen=args.max_seq_len),
                                batch_size=args.batch_size_eval,
                                gpu=args.cuda,
                                is_train=False,
                                task_id=task_id,
                                maxlen=args.max_seq_len,
                                pairwise=pw_task,
                                data_type=data_type,
                                task_type=task_type)
        dev_data_list.append(dev_data)

        test_path = os.path.join(data_dir, '{}_test.json'.format(dataset))
        test_data = None
        if os.path.exists(test_path):
            test_data = BatchGen(BatchGen.load(test_path,
                                               False,
                                               pairwise=pw_task,
                                               maxlen=args.max_seq_len),
                                 batch_size=args.batch_size_eval,
                                 gpu=args.cuda,
                                 is_train=False,
                                 task_id=task_id,
                                 maxlen=args.max_seq_len,
                                 pairwise=pw_task,
                                 data_type=data_type,
                                 task_type=task_type)
        test_data_list.append(test_data)

    logger.info('#' * 20)
    logger.info(opt)
    logger.info('#' * 20)

    all_iters = [iter(item) for item in train_data_list]
    all_lens = [len(bg) for bg in train_data_list]
    num_all_batches = args.epochs * sum(all_lens)

    if len(train_data_list) > 1 and args.ratio > 0:
        num_all_batches = int(args.epochs * (len(train_data_list[0]) *
                                             (1 + args.ratio)))

    model_path = args.init_checkpoint
    state_dict = None

    if os.path.exists(model_path):
        state_dict = torch.load(model_path)
        config = state_dict['config']
        config['attention_probs_dropout_prob'] = args.bert_dropout_p
        config['hidden_dropout_prob'] = args.bert_dropout_p
        opt.update(config)
    else:
        logger.error('#' * 20)
        logger.error(
            'Could not find the init model!\n The parameters will be initialized randomly!'
        )
        logger.error('#' * 20)
        config = BertConfig(vocab_size_or_config_json_file=30522).to_dict()
        opt.update(config)

    model = MTDNNModel(opt,
                       state_dict=state_dict,
                       num_train_step=num_all_batches)
    ####model meta str
    headline = '############# Model Arch of MT-DNN #############'
    ###print network
    logger.info('\n{}\n{}\n'.format(headline, model.network))

    # dump config
    config_file = os.path.join(output_dir, 'config.json')
    with open(config_file, 'w', encoding='utf-8') as writer:
        writer.write('{}\n'.format(json.dumps(opt)))
        writer.write('\n{}\n{}\n'.format(headline, model.network))

    logger.info("Total number of params: {}".format(model.total_param))

    if args.freeze_layers > 0:
        model.network.freeze_layers(args.freeze_layers)

    if args.cuda:
        model.cuda()
    for epoch in range(0, args.epochs):
        logger.warning('At epoch {}'.format(epoch))
        for train_data in train_data_list:
            train_data.reset()
        start = datetime.now()
        all_indices = []
        if len(train_data_list) > 1 and args.ratio > 0:
            main_indices = [0] * len(train_data_list[0])
            extra_indices = []
            for i in range(1, len(train_data_list)):
                extra_indices += [i] * len(train_data_list[i])
            random_picks = int(
                min(len(train_data_list[0]) * args.ratio, len(extra_indices)))
            extra_indices = np.random.choice(extra_indices,
                                             random_picks,
                                             replace=False)
            if args.mix_opt > 0:
                extra_indices = extra_indices.tolist()
                random.shuffle(extra_indices)
                all_indices = extra_indices + main_indices
            else:
                all_indices = main_indices + extra_indices.tolist()

        else:
            for i in range(1, len(train_data_list)):
                all_indices += [i] * len(train_data_list[i])
            if args.mix_opt > 0:
                random.shuffle(all_indices)
            all_indices += [0] * len(train_data_list[0])
        if args.mix_opt < 1:
            random.shuffle(all_indices)

        for i in range(len(all_indices)):
            task_id = all_indices[i]
            batch_meta, batch_data = next(all_iters[task_id])
            model.update(batch_meta, batch_data)
            if (model.updates
                ) % args.log_per_updates == 0 or model.updates == 1:
                logger.info(
                    'Task [{0:2}] updates[{1:6}] train loss[{2:.5f}] remaining[{3}]'
                    .format(
                        task_id, model.updates, model.train_loss.avg,
                        str((datetime.now() - start) / (i + 1) *
                            (len(all_indices) - i - 1)).split('.')[0]))

        for idx, dataset in enumerate(args.test_datasets):
            prefix = dataset.split('_')[0]
            label_dict = GLOBAL_MAP.get(prefix, None)
            dev_data = dev_data_list[idx]
            if dev_data is not None:
                dev_metrics, dev_predictions, scores, golds, dev_ids = eval_model(
                    model, dev_data, dataset=prefix, use_cuda=args.cuda)
                for key, val in dev_metrics.items():
                    logger.warning(
                        "Task {0} -- epoch {1} -- Dev {2}: {3:.3f}".format(
                            dataset, epoch, key, val))
                score_file = os.path.join(
                    output_dir, '{}_dev_scores_{}.json'.format(dataset, epoch))
                results = {
                    'metrics': dev_metrics,
                    'predictions': dev_predictions,
                    'uids': dev_ids,
                    'scores': scores
                }
                dump(score_file, results)
                official_score_file = os.path.join(
                    output_dir, '{}_dev_scores_{}.tsv'.format(dataset, epoch))
                submit(official_score_file, results, label_dict)

            # test eval
            test_data = test_data_list[idx]
            if test_data is not None:
                # Passing with_label=True to eval_model reports evaluation metrics on the test data.
                # The original authors presumably disabled this because tuning against test metrics is bad practice,
                # but it is the most convenient way to obtain test scores. To avoid bias, hyperparameter decisions
                # are based on dev metrics only; test metrics are recorded just for the final model versions.
                test_metrics, test_predictions, scores, golds, test_ids = eval_model(
                    model,
                    test_data,
                    dataset=prefix,
                    use_cuda=args.cuda,
                    with_label=True)
                score_file = os.path.join(
                    output_dir,
                    '{}_test_scores_{}.json'.format(dataset, epoch))
                results = {
                    'metrics': test_metrics,
                    'predictions': test_predictions,
                    'uids': test_ids,
                    'scores': scores
                }
                dump(score_file, results)
                official_score_file = os.path.join(
                    output_dir, '{}_test_scores_{}.tsv'.format(dataset, epoch))
                submit(official_score_file, results, label_dict)
                logger.info('[new test scores saved.]')

        model_file = os.path.join(output_dir, 'model_{}.pt'.format(epoch))
        model.save(model_file)
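
The core of this training loop is the task-index schedule: every task contributes one index per batch, the combined list is shuffled (subject to mix_opt and ratio), and each entry tells the loop which task's iterator to advance next, so larger datasets are naturally visited more often. Below is a minimal, self-contained sketch of that scheduling idea; the plain Python lists and the helper names build_schedule and run_epoch are illustrative stand-ins, not part of the MT-DNN code.

import random

# Sketch only: lists of tuples stand in for BatchGen objects; the mix_opt/ratio
# handling from the example above is omitted.
def build_schedule(train_data_list, shuffle=True):
    """Return one task id per batch, in the order batches will be drawn."""
    all_indices = []
    for task_id, batches in enumerate(train_data_list):
        all_indices += [task_id] * len(batches)
    if shuffle:
        random.shuffle(all_indices)
    return all_indices

def run_epoch(train_data_list, update_fn):
    """Advance the iterator of whichever task the schedule picks next."""
    all_iters = [iter(batches) for batches in train_data_list]
    for task_id in build_schedule(train_data_list):
        update_fn(task_id, next(all_iters[task_id]))

# Toy usage: two "tasks" with three and two batches respectively.
run_epoch([[('a', 1), ('a', 2), ('a', 3)], [('b', 1), ('b', 2)]],
          lambda tid, batch: print('task', tid, 'batch', batch))
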
Example #2
def main():
    logger.info('Launching the MT-DNN training')
    opt = vars(args)
    # update data dir
    opt['data_dir'] = data_dir
    batch_size = args.batch_size
    train_data_list = []
    tasks = {}
    tasks_class = {}
    nclass_list = []
    decoder_opts = []
    dropout_list = []

    for dataset in args.train_datasets:
        prefix = dataset.split('_')[0]
        if prefix in tasks: continue
        assert prefix in DATA_META
        assert prefix in DATA_TYPE
        data_type = DATA_TYPE[prefix]
        nclass = DATA_META[prefix]
        task_id = len(tasks)
        if args.mtl_opt > 0:
            task_id = tasks_class[nclass] if nclass in tasks_class else len(
                tasks_class)

        task_type = TASK_TYPE[prefix]
        pw_task = False
        if prefix in opt['pw_tasks']:
            pw_task = True

        dopt = generate_decoder_opt(prefix, opt['answer_opt'])
        if task_id < len(decoder_opts):
            decoder_opts[task_id] = min(decoder_opts[task_id], dopt)
        else:
            decoder_opts.append(dopt)

        if prefix not in tasks:
            tasks[prefix] = len(tasks)
            if args.mtl_opt < 1: nclass_list.append(nclass)

        if (nclass not in tasks_class):
            tasks_class[nclass] = len(tasks_class)
            if args.mtl_opt > 0: nclass_list.append(nclass)

        dropout_p = args.dropout_p
        if tasks_config and prefix in tasks_config:
            dropout_p = tasks_config[prefix]
        dropout_list.append(dropout_p)

        train_data_ratio_string = ('{}p'.format(args.train_data_ratio)
                                   if args.train_data_ratio < 100 else '')

        train_path = os.path.join(
            data_dir, '{0}_train{1}.json'.format(dataset,
                                                 train_data_ratio_string))
        logger.info('Loading {} as task {}'.format(train_path, task_id))
        train_data = BatchGen(BatchGen.load(train_path,
                                            True,
                                            pairwise=pw_task,
                                            maxlen=args.max_seq_len),
                              batch_size=batch_size,
                              dropout_w=args.dropout_w,
                              gpu=args.cuda,
                              task_id=task_id,
                              maxlen=args.max_seq_len,
                              pairwise=pw_task,
                              data_type=data_type,
                              task_type=task_type)
        train_data_list.append(train_data)

    opt['answer_opt'] = decoder_opts
    opt['tasks_dropout_p'] = dropout_list

    args.label_size = ','.join([str(l) for l in nclass_list])
    logger.info(args.label_size)
    dev_data_list = []
    test_data_list = []
    for dataset in args.test_datasets:
        prefix = dataset.split('_')[0]
        task_id = tasks_class[
            DATA_META[prefix]] if args.mtl_opt > 0 else tasks[prefix]
        task_type = TASK_TYPE[prefix]

        pw_task = False
        if prefix in opt['pw_tasks']:
            pw_task = True

        assert prefix in DATA_TYPE
        data_type = DATA_TYPE[prefix]

        dev_path = os.path.join(data_dir, '{}_dev.json'.format(dataset))
        dev_data = None
        if os.path.exists(dev_path):
            dev_data = BatchGen(BatchGen.load(dev_path,
                                              False,
                                              pairwise=pw_task,
                                              maxlen=args.max_seq_len),
                                batch_size=args.batch_size_eval,
                                gpu=args.cuda,
                                is_train=False,
                                task_id=task_id,
                                maxlen=args.max_seq_len,
                                pairwise=pw_task,
                                data_type=data_type,
                                task_type=task_type)
        dev_data_list.append(dev_data)

        test_path = os.path.join(data_dir, '{}_test.json'.format(dataset))
        test_data = None
        if os.path.exists(test_path):
            test_data = BatchGen(BatchGen.load(test_path,
                                               False,
                                               pairwise=pw_task,
                                               maxlen=args.max_seq_len),
                                 batch_size=args.batch_size_eval,
                                 gpu=args.cuda,
                                 is_train=False,
                                 task_id=task_id,
                                 maxlen=args.max_seq_len,
                                 pairwise=pw_task,
                                 data_type=data_type,
                                 task_type=task_type)
        test_data_list.append(test_data)

    logger.info('#' * 20)
    logger.info(opt)
    logger.info('#' * 20)

    all_iters = [iter(item) for item in train_data_list]
    all_lens = [len(bg) for bg in train_data_list]
    num_all_batches = args.epochs * sum(all_lens)

    if len(train_data_list) > 1 and args.ratio > 0:
        num_all_batches = int(args.epochs * (len(train_data_list[0]) *
                                             (1 + args.ratio)))

    model_path = args.init_checkpoint
    state_dict = None

    if os.path.exists(model_path):
        state_dict = torch.load(model_path, map_location='cpu')
        config = state_dict['config']
        config['attention_probs_dropout_prob'] = args.bert_dropout_p
        config['hidden_dropout_prob'] = args.bert_dropout_p
        opt.update(config)
    else:
        logger.error('#' * 20)
        logger.error('Could not find the init model!\n Exit application!')
        logger.error('#' * 20)
        try:
            shutil.rmtree(output_dir)
        except Exception as e:
            print(e)
        exit(1)

    model = MTDNNModel(opt,
                       state_dict=state_dict,
                       num_train_step=num_all_batches)
    ####model meta str
    headline = '############# Model Arch of MT-DNN #############'
    ###print network
    logger.info('\n{}\n{}\n'.format(headline, model.network))

    # dump config
    config_file = os.path.join(output_dir, 'config.json')
    with open(config_file, 'w', encoding='utf-8') as writer:
        writer.write('{}\n'.format(json.dumps(opt)))
        writer.write('\n{}\n{}\n'.format(headline, model.network))

    logger.info("Total number of params: {}".format(model.total_param))

    if args.freeze_layers > 0:
        model.network.freeze_layers(args.freeze_layers)

    if args.cuda:
        model.cuda()

    best_F1_macro = -1.0
    for epoch in range(0, args.epochs):
        logger.warning('At epoch {}'.format(epoch))
        for train_data in train_data_list:
            train_data.reset()
        start = datetime.now()
        all_indices = []
        if len(train_data_list) > 1 and (args.ratio > 0
                                         or args.reduce_first_dataset_ratio > 0):
            main_indices = [0] * (
                int(args.reduce_first_dataset_ratio * len(train_data_list[0]))
                if args.reduce_first_dataset_ratio > 0
                else len(train_data_list[0]))
            extra_indices = []
            for i in range(1, len(train_data_list)):
                extra_indices += [i] * len(train_data_list[i])
            if args.ratio > 0:
                random_picks = int(
                    min(
                        len(train_data_list[0]) * args.ratio,
                        len(extra_indices)))
                extra_indices = np.random.choice(extra_indices,
                                                 random_picks,
                                                 replace=False).tolist()
            if args.mix_opt > 0:
                random.shuffle(extra_indices)
                all_indices = extra_indices + main_indices
            else:
                all_indices = main_indices + extra_indices
            logger.info(
                "Main batches loaded (first dataset in list): {}".format(
                    len(main_indices)))
            logger.info(
                "Extra batches loaded (all except first dataset in list): {}".
                format(len(extra_indices)))

        else:  # build the list of task indices to draw batches from; larger train sets contribute more indices and are therefore trained on more often
            for i in range(1, len(train_data_list)):
                all_indices += [i] * len(train_data_list[i])
            if args.mix_opt > 0:
                random.shuffle(all_indices)
            all_indices += [0] * len(train_data_list[0])
        if args.mix_opt < 1:
            random.shuffle(all_indices)

        for i in range(len(all_indices)):
            task_id = all_indices[i]
            batch_meta, batch_data = next(all_iters[task_id])
            model.update(batch_meta, batch_data)
            if (model.updates
                ) % args.log_per_updates == 0 or model.updates == 1:
                logger.info(
                    'Task [{0:2}] updates[{1:6}] train loss[{2:.5f}] remaining[{3}]'
                    .format(
                        task_id, model.updates, model.train_loss.avg,
                        str((datetime.now() - start) / (i + 1) *
                            (len(all_indices) - i - 1)).split('.')[0]))

        temp_dev_F1s = []
        dev_dump_list = []
        test_dump_list = []
        for idx, dataset in enumerate(args.test_datasets):
            prefix = dataset.split('_')[0]
            label_dict = GLOBAL_MAP.get(prefix, None)
            dev_data = dev_data_list[idx]
            if dev_data is not None:
                dev_metrics, dev_predictions, scores, golds, dev_ids, premises, hypotheses = eval_model(
                    model, dev_data, dataset=prefix, use_cuda=args.cuda)
                for key, val in dev_metrics.items():
                    if not isinstance(val, dict):
                        logger.warning(
                            "Task {0} -- epoch {1} -- Dev {2}: {3:.3f}".format(
                                dataset, epoch, key, val))
                score_file = os.path.join(
                    output_dir, '{}_dev_scores_{}.json'.format(dataset, epoch))
                results = {
                    'metrics': dev_metrics,
                    'predictions': dev_predictions,
                    'uids': dev_ids,
                    'scores': scores,
                    'golds': golds,
                    'premises': premises,
                    'hypotheses': hypotheses
                }
                dump(score_file, results)
                official_score_file = os.path.join(
                    output_dir, '{}_dev_scores_{}.tsv'.format(dataset, epoch))
                submit(official_score_file, results, label_dict)

                # for checkpoint
                temp_dev_F1s.append(dev_metrics['F1_macro'])
                dev_dump_list.append({
                    "output_dir": output_dir,
                    "dev_metrics": dev_metrics,
                    "dev_predictions": dev_predictions,
                    "golds": golds,
                    "opt": opt,
                    "dataset": dataset
                })

            # test eval
            test_data = test_data_list[idx]
            if test_data is not None:
                test_metrics, test_predictions, scores, golds, test_ids, premises, hypotheses = eval_model(
                    model,
                    test_data,
                    dataset=prefix,
                    use_cuda=args.cuda,
                    with_label=True)
                score_file = os.path.join(
                    output_dir,
                    '{}_test_scores_{}.json'.format(dataset, epoch))
                results = {
                    'metrics': test_metrics,
                    'predictions': test_predictions,
                    'uids': test_ids,
                    'scores': scores,
                    'golds': golds,
                    'premises': premises,
                    'hypotheses': hypotheses
                }
                dump(score_file, results)
                official_score_file = os.path.join(
                    output_dir, '{}_test_scores_{}.tsv'.format(dataset, epoch))
                submit(official_score_file, results, label_dict)
                logger.info('[new test scores saved.]')

                # for checkpoint
                test_dump_list.append({
                    "output_dir": output_dir,
                    "test_metrics": test_metrics,
                    "test_predictions": test_predictions,
                    "golds": golds,
                    "opt": opt,
                    "dataset": dataset
                })

        # save checkpoint
        if np.average(temp_dev_F1s) > best_F1_macro:
            print("Save new model! Current best F1 macro over all dev sets: " +
                  "{0:.2f}".format(best_F1_macro) + ". New: " +
                  "{0:.2f}".format(np.average(temp_dev_F1s)))
            best_F1_macro = np.average(temp_dev_F1s)

            # override current dump file
            for l in dev_dump_list:
                dump_result_files(l['dataset'])(l['output_dir'], epoch,
                                                l['dev_metrics'],
                                                str(l['dev_predictions']),
                                                str(l['golds']), "dev",
                                                l['opt'], l['dataset'])

            for l in test_dump_list:
                dump_result_files(l['dataset'])(l['output_dir'], epoch,
                                                l['test_metrics'],
                                                str(l['test_predictions']),
                                                str(l['golds']), "test",
                                                l['opt'], l['dataset'])

            # save model
            model_file = os.path.join(output_dir, 'model.pt')
            model.save(model_file)
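
Example #2 adds a checkpointing rule on top of the same loop: after every epoch it averages the macro-F1 over all dev sets and keeps the model only when that average improves. A small sketch of just that rule, under the assumption that each dev set produces a metrics dict containing an 'F1_macro' entry (the helper name should_save is hypothetical):

import numpy as np

def should_save(dev_metrics_per_task, best_f1_macro):
    """Save when the macro-F1 averaged over all dev sets beats the best so far."""
    current = float(np.average([m['F1_macro'] for m in dev_metrics_per_task]))
    return current > best_f1_macro, current

# Toy usage over two epochs of made-up dev metrics.
best = -1.0
for epoch_metrics in ([{'F1_macro': 0.61}, {'F1_macro': 0.58}],
                      [{'F1_macro': 0.63}, {'F1_macro': 0.60}]):
    save, score = should_save(epoch_metrics, best)
    if save:
        best = score
        print('new best average F1_macro: {:.3f}'.format(best))
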
Example #3
def main():
    logger.info('Launching the MT-DNN training')
    opt = vars(args)
    # update data dir
    opt['data_dir'] = data_dir
    batch_size = args.batch_size
    train_data_list = []
    tasks = {}
    tasks_class = {}
    nclass_list = []
    decoder_opts = []
    dropout_list = []

    for dataset in args.train_datasets:
        prefix = dataset.split('_')[0]
        if prefix in tasks: continue
        assert prefix in DATA_META
        assert prefix in DATA_TYPE
        data_type = DATA_TYPE[prefix]
        nclass = DATA_META[prefix]
        task_id = len(tasks)
        if args.mtl_opt > 0:
            task_id = tasks_class[nclass] if nclass in tasks_class else len(tasks_class)

        task_type = TASK_TYPE[prefix]
        pw_task = False

        dopt = generate_decoder_opt(prefix, opt['answer_opt'])
        if task_id < len(decoder_opts):
            decoder_opts[task_id] = min(decoder_opts[task_id], dopt)
        else:
            decoder_opts.append(dopt)

        if prefix not in tasks:
            tasks[prefix] = len(tasks)
            if args.mtl_opt < 1: nclass_list.append(nclass)

        if (nclass not in tasks_class):
            tasks_class[nclass] = len(tasks_class)
            if args.mtl_opt > 0: nclass_list.append(nclass)

        dropout_p = args.dropout_p
        if tasks_config and prefix in tasks_config:
            dropout_p = tasks_config[prefix]
        dropout_list.append(dropout_p)

        train_path = os.path.join(data_dir, '{}_train.json'.format(dataset))
        logger.info('Loading {} as task {}'.format(train_path, task_id))
        train_data = BatchGen(BatchGen.load(train_path, True, pairwise=pw_task, maxlen=args.max_seq_len, 
                                        opt=opt, dataset=dataset),
                                batch_size=batch_size,
                                dropout_w=args.dropout_w,
                                gpu=args.cuda,
                                task_id=task_id,
                                maxlen=args.max_seq_len,
                                pairwise=pw_task,
                                data_type=data_type,
                                task_type=task_type,
                                dataset_name=dataset)
        train_data.reset()
        train_data_list.append(train_data)

    opt['answer_opt'] = decoder_opts
    opt['tasks_dropout_p'] = dropout_list

    args.label_size = ','.join([str(l) for l in nclass_list])
    logger.info(args.label_size)
    dev_data_list = []
    test_data_list = []
    for dataset in args.test_datasets:
        prefix = dataset.split('_')[0]
        task_id = tasks_class[DATA_META[prefix]] if args.mtl_opt > 0 else tasks[prefix]
        task_type = TASK_TYPE[prefix]

        pw_task = False

        assert prefix in DATA_TYPE
        data_type = DATA_TYPE[prefix]

        if args.predict_split is not None:
            dev_path = os.path.join(data_dir, '{}_{}.json'.format(dataset, args.predict_split))
        else:
            dev_path = os.path.join(data_dir, '{}_dev.json'.format(dataset))
        dev_data = None
        if os.path.exists(dev_path):
            dev_data = BatchGen(BatchGen.load(dev_path, False, pairwise=pw_task, maxlen=args.max_seq_len,
                                            opt=opt, dataset=dataset),
                                  batch_size=args.batch_size_eval,
                                  gpu=args.cuda, is_train=False,
                                  task_id=task_id,
                                  maxlen=args.max_seq_len,
                                  pairwise=pw_task,
                                  data_type=data_type,
                                  task_type=task_type,
                                  dataset_name=dataset)
        dev_data_list.append(dev_data)

        test_path = os.path.join(data_dir, '{}_test.json'.format(dataset))
        test_data = None
        if os.path.exists(test_path):
            test_data = BatchGen(BatchGen.load(test_path, False, pairwise=pw_task,
                                               maxlen=args.max_seq_len, opt=opt, dataset=dataset),
                                  batch_size=args.batch_size_eval,
                                  gpu=args.cuda, is_train=False,
                                  task_id=task_id,
                                  maxlen=args.max_seq_len,
                                  pairwise=pw_task,
                                  data_type=data_type,
                                  task_type=task_type,
                                  dataset_name=dataset)
        test_data_list.append(test_data)

    logger.info('#' * 20)
    logger.info(opt)
    logger.info('#' * 20)

    all_iters = [iter(item) for item in train_data_list]
    all_lens = [len(bg) for bg in train_data_list]
    num_all_batches = args.epochs * sum(all_lens)

    if len(args.external_datasets) > 0 and args.external_include_ratio > 0:
        num_in_domain_batches = args.epochs * sum(all_lens[:-len(args.external_datasets)])
        num_all_batches = num_in_domain_batches * (1 + args.external_include_ratio)
    # pdb.set_trace()

    model_path = args.init_checkpoint
    state_dict = None

    if os.path.exists(model_path):
        state_dict = torch.load(model_path)
        if args.init_config is not None:  # load a huggingface-style checkpoint
            config = json.load(open(args.init_config))
            state_dict = {'config': config, 'state': state_dict}
        if args.finetune:
            # only keep the config and state entries
            del_keys = set(state_dict.keys()) - set(['config', 'state'])
            for key in del_keys:
                del state_dict[key]
            resume_configs = json.load(open('config/resume_configs.json'))
            del_keys = set(state_dict['config'].keys()) - set(resume_configs)
            for key in del_keys:
                del state_dict['config'][key]
            if args.resume_scoring is not None:
                for key in state_dict['state'].keys():
                    if 'scoring_list.0' in key:
                        # initialize scoring head 0 from head `resume_scoring`; the other
                        # heads are dropped while loading, since finetuning keeps a single task
                        state_dict['state'][key] = state_dict['state'][key.replace(
                            'scoring_list.0', 'scoring_list.{}'.format(args.resume_scoring))]
            elif not args.retain_scoring:
                del_keys = [k for k in state_dict['state'] if 'scoring_list' in k]
                for key in del_keys:
                    print('deleted previous weight:', key)
                    del state_dict['state'][key]
        config = state_dict['config']
        config['attention_probs_dropout_prob'] = args.bert_dropout_p
        config['hidden_dropout_prob'] = args.bert_dropout_p
        opt.update(config)
    else:
        logger.error('#' * 20)
        logger.error('Could not find the init model!\n The parameters will be initialized randomly!')
        logger.error('#' * 20)
        config = BertConfig(vocab_size_or_config_json_file=30522).to_dict()
        opt.update(config)

    model = MTDNNModel(opt, state_dict=state_dict, num_train_step=num_all_batches)
    ####model meta str
    headline = '############# Model Arch of MT-DNN #############'
    ###print network
    # logger.info('\n{}\n{}\n'.format(headline, model.network))

    # dump config
    config_file = os.path.join(output_dir, 'config.json')
    with open(config_file, 'w', encoding='utf-8') as writer:
        writer.write('{}\n'.format(json.dumps(opt)))
        writer.write('\n{}\n{}\n'.format(headline, model.network))

    logger.info("Total number of params: {}".format(model.total_param))

    if args.freeze_layers > 0:
        model.network.freeze_layers(args.freeze_layers)

    if args.cuda:
        model.cuda()
    best_epoch = -1
    best_performance = 0
    best_dataset_performance = {dataset: {'perf': 0, 'epoch': -1} for dataset in args.mtl_observe_datasets}
    for epoch in range(0, args.epochs):
        logger.warning('At epoch {}'.format(epoch))
        if epoch == 0 and args.freeze_bert_first:
            model.network.freeze_bert()
            logger.warning('BERT frozen.')
        if epoch == 1 and args.freeze_bert_first:
            model.network.unfreeze_bert()
            logger.warning('BERT unfrozen.')
        start = datetime.now()
        all_indices = []
        if len(args.external_datasets) > 0 and args.external_include_ratio > 0:
            main_indices = []
            extra_indices = []
            for data_idx, batcher in enumerate(train_data_list):
                if batcher.dataset_name not in args.external_datasets:
                    main_indices += [data_idx] * len(batcher)
                else:
                    extra_indices += [data_idx] * len(batcher)

            random_picks = int(min(len(main_indices) * args.external_include_ratio, len(extra_indices)))
            extra_indices = np.random.choice(extra_indices, random_picks, replace=False)
            if args.mix_opt > 0:
                extra_indices = extra_indices.tolist()
                random.shuffle(extra_indices)
                all_indices = extra_indices + main_indices
            else:
                all_indices = main_indices + extra_indices.tolist()
        else:
            for i in range(1, len(train_data_list)):
                all_indices += [i] * len(train_data_list[i])
            if args.mix_opt > 0:
                random.shuffle(all_indices)
            all_indices += [0] * len(train_data_list[0])
        if args.mix_opt < 1:
            random.shuffle(all_indices)
        if args.test_mode:
            all_indices = all_indices[:2]
        if args.predict_split is not None:
            all_indices = []
            dev_split = args.predict_split
        else:
            dev_split = 'dev'

        for i in range(len(all_indices)):
            task_id = all_indices[i]
            batch_meta, batch_data = next(all_iters[task_id])
            model.update(batch_meta, batch_data)
            if (model.updates) % args.log_per_updates == 0 or model.updates == 1:
                logger.info('Task [{0:2}] updates[{1:6}] train loss[{2:.5f}] remaining[{3}]'.format(task_id,
                    model.updates, model.train_loss.avg,
                    str((datetime.now() - start) / (i + 1) * (len(all_indices) - i - 1)).split('.')[0]))
        os.system('nvidia-smi')
        for train_data in train_data_list:
            train_data.reset()

        this_performance = {}

        for idx, dataset in enumerate(args.test_datasets):
            prefix = dataset.split('_')[0]
            dev_data = dev_data_list[idx]
            if dev_data is not None:
                dev_metrics, dev_predictions, scores, golds, dev_ids = eval_model(model, dev_data, dataset=prefix,
                                                                                  use_cuda=args.cuda)
                score_file = os.path.join(output_dir, '{}_{}_scores_{}.json'.format(dataset, dev_split, epoch))
                results = {'metrics': dev_metrics, 'predictions': dev_predictions, 'uids': dev_ids, 'scores': scores}
                dump(score_file, results)
                official_score_file = os.path.join(output_dir, '{}_{}_scores_{}.csv'.format(dataset, dev_split, epoch))
                submit(official_score_file, results, dataset_name=prefix, threshold=2.0 + args.mediqa_score_offset)
                if prefix in mediqa_name_list:
                    logger.warning('self test numbers:{}'.format(dev_metrics))
                    if '_' in dataset:
                        affix = dataset.split('_')[1]
                        ground_truth_path = os.path.join(args.data_root, 'mediqa/task3_qa/gt_{}_{}.csv'.format(dev_split, affix))
                    else:
                        ground_truth_path = os.path.join(args.data_root, 'mediqa/task3_qa/gt_{}.csv'.format(dev_split))
                    official_result = eval_mediqa_official(pred_path=official_score_file, ground_truth_path=ground_truth_path,
                                                           eval_qa_more=args.mediqa_eval_more)
                    logger.warning("MediQA dev eval result:{}".format(official_result))
                    if args.mediqa_eval_more:
                        dev_metrics = {'ACC': official_result['score'] * 100,
                                       'Spearman': official_result['score_secondary'] * 100,
                                       'F1': dev_metrics['F1'],
                                       'MRR': official_result['meta']['MRR'],
                                       'MAP': official_result['MAP'],
                                       'P@1': official_result['meta']['P@1']}
                    else:
                        dev_metrics = {'ACC': official_result['score'] * 100,
                                       'Spearman': official_result['score_secondary'] * 100}

                for key, val in dev_metrics.items():
                    logger.warning("Task {0} -- epoch {1} -- Dev {2}: {3:.3f}".format(dataset, epoch, key, val))
            if args.predict_split is not None:
                continue
            print('args.mtl_observe_datasets:', args.mtl_observe_datasets, dataset)
            if dataset in args.mtl_observe_datasets:
                this_performance[dataset] = np.mean([val for val in dev_metrics.values()])
            test_data = test_data_list[idx]
            if test_data is not None:
                test_metrics, test_predictions, scores, golds, test_ids = eval_model(model, test_data, dataset=prefix,
                                                                                     use_cuda=args.cuda, with_label=False)
                for key, val in test_metrics.items():
                    logger.warning("Task {0} -- epoch {1} -- Test {2}: {3:.3f}".format(dataset, epoch, key, val))
                score_file = os.path.join(output_dir, '{}_test_scores_{}.json'.format(dataset, epoch))
                results = {'metrics': test_metrics, 'predictions': test_predictions, 'uids': test_ids, 'scores': scores}
                dump(score_file, results)
                # if dataset in mediqa_name_list:
                official_score_file = os.path.join(output_dir, '{}_test_scores_{}.csv'.format(dataset, epoch))
                submit(official_score_file, results, dataset_name=prefix, threshold=2.0 + args.mediqa_score_offset)
                logger.info('[new test scores saved.]')
        print('this_performance:', this_performance)
        if args.predict_split is not None:
            break
        epoch_performance = sum([val for val in this_performance.values()])
        if epoch_performance > best_performance:
            print('changed:', epoch_performance, best_performance)
            best_performance = epoch_performance
            best_epoch = epoch

        for dataset in args.mtl_observe_datasets:
            if best_dataset_performance[dataset]['perf'] < this_performance[dataset]:
                best_dataset_performance[dataset] = {'perf': this_performance[dataset],
                                                     'epoch': epoch}


        print('current best:', best_performance, 'at epoch', best_epoch)
        if not args.not_save_model:
            model_name = 'model_last.pt' if args.save_last else 'model_{}.pt'.format(epoch)
            model_file = os.path.join(output_dir, model_name)
            if args.save_last and os.path.exists(model_file):
                model_temp = os.path.join(output_dir, 'model_secondlast.pt')
                copyfile(model_file, model_temp)
            model.save(model_file)
            if args.save_best and best_epoch == epoch:
                best_path = os.path.join(output_dir, 'best_model.pt')
                copyfile(model_file, best_path)
                for dataset in args.mtl_observe_datasets:
                    if best_dataset_performance[dataset]['epoch'] == epoch:
                        best_path = os.path.join(output_dir, 'best_model_{}.pt'.format(dataset))
                        copyfile(model_file, best_path)
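
Example #3 mixes in-domain and external datasets by capping the number of external batches at external_include_ratio times the number of in-domain batches, sub-sampling the external indices with np.random.choice, and shuffling the combined schedule. A rough sketch of that mixing step, using plain batch counts instead of BatchGen objects (mix_indices is a hypothetical helper, and mix_opt handling is left out):

import random
import numpy as np

def mix_indices(main_len, extra_lens, include_ratio):
    """Cap external batches at include_ratio * main batches, then shuffle."""
    main_indices = [0] * main_len
    extra_indices = []
    for task_id, n in enumerate(extra_lens, start=1):
        extra_indices += [task_id] * n
    picks = int(min(main_len * include_ratio, len(extra_indices)))
    extra_indices = np.random.choice(extra_indices, picks, replace=False).tolist()
    all_indices = main_indices + extra_indices
    random.shuffle(all_indices)
    return all_indices

# Toy usage: 4 in-domain batches, two external sets with 10 and 6 batches.
print(mix_indices(main_len=4, extra_lens=[10, 6], include_ratio=0.5))
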
Example #4
def main():
    logger.info('Launching the MT-DNN training')
    opt = vars(args)
    # update data dir
    opt['data_dir'] = data_dir
    batch_size = args.batch_size
    train_data_list = []
    tasks = {}
    tasks_class = {}
    nclass_list = []
    decoder_opts = []
    dropout_list = []

    for dataset in args.train_datasets:
        prefix = dataset.split('_')[0]
        if prefix in tasks: continue
        assert prefix in DATA_META
        assert prefix in DATA_TYPE
        nclass = DATA_META[prefix]
        task_id = len(tasks)
        if args.mtl_opt > 0:
            task_id = tasks_class[nclass] if nclass in tasks_class else len(tasks_class)

        dopt = generate_decoder_opt(prefix, opt['answer_opt'])
        if task_id < len(decoder_opts):
            decoder_opts[task_id] = min(decoder_opts[task_id], dopt)
        else:
            decoder_opts.append(dopt)

        if prefix not in tasks:
            tasks[prefix] = len(tasks)
            if args.mtl_opt < 1: nclass_list.append(nclass)

        if (nclass not in tasks_class):
            tasks_class[nclass] = len(tasks_class)
            if args.mtl_opt > 0: nclass_list.append(nclass)

        dropout_p = args.dropout_p
        if tasks_config and prefix in tasks_config:
            dropout_p = tasks_config[prefix]
        dropout_list.append(dropout_p)

    opt['answer_opt'] = decoder_opts
    opt['tasks_dropout_p'] = dropout_list

    args.label_size = ','.join([str(l) for l in nclass_list])
    logger.info(args.label_size)
    dev_data_list = []
    test_data_list = []
    stress_data_list = []
    for dataset in args.test_datasets:
        prefix = dataset.split('_')[0]
        task_id = tasks_class[DATA_META[prefix]] if args.mtl_opt > 0 else tasks[prefix]
        task_type = TASK_TYPE[prefix]

        pw_task = False
        if prefix in opt['pw_tasks']:
            pw_task = True

        assert prefix in DATA_TYPE
        data_type = DATA_TYPE[prefix]

        dev_path = os.path.join(data_dir, '{}_dev.json'.format(dataset))
        dev_data = None
        if os.path.exists(dev_path):
            dev_data = BatchGen(BatchGen.load(dev_path, False, pairwise=pw_task, maxlen=args.max_seq_len),
                                batch_size=args.batch_size_eval,
                                gpu=args.cuda, is_train=False,
                                task_id=task_id,
                                maxlen=args.max_seq_len,
                                pairwise=pw_task,
                                data_type=data_type,
                                task_type=task_type)
        dev_data_list.append(dev_data)

        test_path = os.path.join(data_dir, '{}_test.json'.format(dataset))
        test_data = None
        if os.path.exists(test_path):
            test_data = BatchGen(BatchGen.load(test_path, False, pairwise=pw_task, maxlen=args.max_seq_len),
                                 batch_size=args.batch_size_eval,
                                 gpu=args.cuda, is_train=False,
                                 task_id=task_id,
                                 maxlen=args.max_seq_len,
                                 pairwise=pw_task,
                                 data_type=data_type,
                                 task_type=task_type)
        test_data_list.append(test_data)

        stress_data = []
        if args.stress_tests != "NONE":
            for stress_test in args.stress_tests.split(','):
                stress_path = os.path.join(data_dir, '{}_test_{}.json'.format(dataset, stress_test))
                if os.path.exists(stress_path):
                    stress_data.append(BatchGen(BatchGen.load(stress_path, False, pairwise=pw_task, maxlen=args.max_seq_len),
                                         batch_size=args.batch_size_eval,
                                         gpu=args.cuda, is_train=False,
                                         task_id=task_id,
                                         maxlen=512,
                                         pairwise=pw_task,
                                         data_type=data_type,
                                         task_type=task_type))
            stress_data_list.append(stress_data)


    logger.info('#' * 20)
    logger.info(opt)
    logger.info('#' * 20)

    all_lens = [len(bg) for bg in train_data_list]
    num_all_batches = args.epochs * sum(all_lens)

    if len(train_data_list) > 1 and args.ratio > 0:
        num_all_batches = int(args.epochs * (len(train_data_list[0]) * (1 + args.ratio)))

    model_path = args.init_checkpoint
    state_dict = None

    if os.path.exists(model_path):
        state_dict = torch.load(model_path)
        config = state_dict['config']
        config['attention_probs_dropout_prob'] = args.bert_dropout_p
        config['hidden_dropout_prob'] = args.bert_dropout_p
        opt.update(config)
    else:
        logger.error('#' * 20)
        logger.error('Could not find the init model!\n Exit application!')
        logger.error('#' * 20)
        exit(1)


    model = MTDNNModel(opt, state_dict=state_dict, num_train_step=num_all_batches)
    ####model meta str
    headline = '############# Model Arch of MT-DNN #############'
    ###print network
    logger.info('\n{}\n{}\n'.format(headline, model.network))

    # dump config
    config_file = os.path.join(output_dir, 'config.json')
    with open(config_file, 'w', encoding='utf-8') as writer:
        writer.write('{}\n'.format(json.dumps(opt)))
        writer.write('\n{}\n{}\n'.format(headline, model.network))

    logger.info("Total number of params: {}".format(model.total_param))

    if args.freeze_layers > 0:
        model.network.freeze_layers(args.freeze_layers)

    if args.cuda:
        model.cuda()
    for epoch in range(0, 1):
        dev_dump_list = []
        test_dump_list = []
        stress_dump_list = []
        for idx, dataset in enumerate(args.test_datasets):
            prefix = dataset.split('_')[0]
            label_dict = GLOBAL_MAP.get(prefix, None)
            dev_data = dev_data_list[idx]
            if dev_data is not None:
                dev_metrics, dev_predictions, scores, golds, dev_ids, premises, hypotheses = eval_model(model, dev_data, dataset=prefix,
                                                                                 use_cuda=args.cuda)
                for key, val in dev_metrics.items():
                    if not isinstance(val, dict):
                        logger.warning("Task {0} -- epoch {1} -- Dev {2}: {3:.3f}".format(dataset, epoch, key, val))

                if args.dump_to_checkpoints == 1:
                    score_file = os.path.join(output_dir, '{}_dev_scores_{}_EVAL_ONLY.json'.format(dataset, epoch))
                    results = {'metrics': dev_metrics, 'predictions': dev_predictions, 'uids': dev_ids,
                               'scores': scores, 'golds': golds,
                               'premises': premises, 'hypotheses': hypotheses}
                    dump(score_file, results)
                    official_score_file = os.path.join(output_dir,
                                                       '{}_dev_scores_{}_EVAL_ONLY.tsv'.format(dataset, epoch))
                    submit(official_score_file, results, label_dict)

                # for checkpoint
                dev_dump_list.append({
                    "output_dir": output_dir,
                    "dev_metrics": dev_metrics,
                    "dev_predictions": dev_predictions,
                    "golds": golds,
                    "opt": opt,
                    "dataset": dataset
                })

            # test eval
            test_data = test_data_list[idx]
            if test_data is not None:
                test_metrics, test_predictions, scores, golds, test_ids, premises, hypotheses = eval_model(model, test_data, dataset=prefix,
                                                                                 use_cuda=args.cuda, with_label=True)

                if args.dump_to_checkpoints == 1:
                    score_file = os.path.join(output_dir, '{}_test_scores_{}_EVAL_ONLY.json'.format(dataset, epoch))
                    results = {'metrics': test_metrics, 'predictions': test_predictions, 'uids': test_ids, 'scores': scores, 'golds': golds,
                               'premises': premises, 'hypotheses': hypotheses}
                    dump(score_file, results)
                    official_score_file = os.path.join(output_dir, '{}_test_scores_{}_EVAL_ONLY.tsv'.format(dataset, epoch))
                    submit(official_score_file, results, label_dict)
                    logger.info('[new test scores saved.]')

                # for checkpoint
                test_dump_list.append({
                    "output_dir": output_dir,
                    "test_metrics": test_metrics,
                    "test_predictions": test_predictions,
                    "golds": golds,
                    "opt": opt,
                    "dataset": dataset
                })

            # stress test eval
            if args.stress_tests != "NONE":
                stress_data = stress_data_list[idx]
                for j, stress_test in enumerate(args.stress_tests.split(',')):
                    stress_metrics, stress_predictions, scores, golds, stress_ids, premises, hypotheses = \
                        eval_model(model, stress_data[j], dataset=prefix, use_cuda=args.cuda, with_label=True)

                    if args.dump_to_checkpoints == 1:
                        score_file = os.path.join(output_dir, '{}_test_{}_scores_{}_EVAL_ONLY.json'.format(dataset, stress_test, epoch))
                        results = {'metrics': stress_metrics, 'predictions': stress_predictions, 'uids': stress_ids, 'scores': scores, 'golds': golds,
                                   'premises': premises, 'hypotheses': hypotheses}
                        dump(score_file, results)
                        official_score_file = os.path.join(output_dir, '{}_test_{}_scores_{}_EVAL_ONLY.tsv'.format(dataset, stress_test, epoch))
                        submit(official_score_file, results, label_dict)
                        logger.info('[new stress test scores for "{}" saved.]'.format(stress_test))

                    # for checkpoint
                    stress_dump_list.append({
                        "output_dir": output_dir,
                        "test_metrics": stress_metrics,
                        "test_predictions": stress_predictions,
                        "golds": golds,
                        "opt": opt,
                        "dataset": dataset,
                        "stress_test": stress_test
                    })



        # save results
        print("Save new results!")

        for l in dev_dump_list:
            dump_result_files(l['dataset'])(l['output_dir'], -1, l['dev_metrics'], str(l['dev_predictions']),
                                            str(l['golds']), "dev", l['opt'], l['dataset'])
        for l in test_dump_list:
            dump_result_files(l['dataset'])(l['output_dir'], -1, l['test_metrics'], str(l['test_predictions']),
                                            str(l['golds']), "test", l['opt'], l['dataset'])

        if args.stress_tests != "NONE":
            for l in stress_dump_list:
                dump_result_files(l['dataset'])(l['output_dir'], -1, l['test_metrics'], str(l['test_predictions']),
                                                str(l['golds']), l['stress_test'], l['opt'], l['dataset'])
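
Example #4 is evaluation-only: besides the regular dev and test splits, it looks for optional stress-test files named '{dataset}_test_{stress_test}.json' and evaluates every one that exists on disk. The sketch below isolates just that file-discovery step; find_stress_sets and the dataset/split names in the usage line are hypothetical.

import os

def find_stress_sets(data_dir, dataset, stress_tests):
    """Return (name, path) pairs for the stress-test splits that exist."""
    if stress_tests == "NONE":
        return []
    found = []
    for stress_test in stress_tests.split(','):
        path = os.path.join(data_dir, '{}_test_{}.json'.format(dataset, stress_test))
        if os.path.exists(path):
            found.append((stress_test, path))
    return found

# Toy usage with made-up names; returns [] unless the files actually exist.
print(find_stress_sets('data', 'mnli', 'negation,word_overlap'))
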