Example #1
    def clone_model(self, model, id2label_map):
        """ clone only part of params """
        # deal with data parallel model
        slot_id2label, intent_id2label = id2label_map['slot'], id2label_map['intent']
        new_model: BertSLU
        old_model: BertSLU
        if (self.opt.local_rank != -1 or self.n_gpu > 1) and hasattr(model, 'module'):  # the model is wrapped in a parallel class here
            old_model = model.module
        else:
            old_model = model

        label2id_map = {
            key: {v: k
                  for k, v in values.items()}
            for key, values in id2label_map.items()
        }
        config = {'label2id': label2id_map, 'id2label': id2label_map}
        # get a new instance for different domain
        new_model = make_finetune_model(opt=self.opt, config=config)
        new_model = prepare_model(self.opt, new_model, self.device, self.n_gpu)
        if self.opt.local_rank != -1 or self.n_gpu > 1:
            sub_new_model = new_model.module
        else:
            sub_new_model = new_model
        ''' copy weights and stuff '''
        sub_new_model.load_state_dict(old_model.state_dict())

        return new_model
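
A minimal usage sketch (the label maps below are hypothetical; trainer and model stand for an existing trainer instance and a trained BertSLU, and the map keys must be 'slot' and 'intent' as the method expects):

# hypothetical target-domain label maps
target_id2label_map = {
    'slot': {0: 'O', 1: 'B-time', 2: 'I-time'},
    'intent': {0: 'set_alarm', 1: 'query_weather'},
}
new_model = trainer.clone_model(model, target_id2label_map)  # weights copied, label maps swapped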
Example #2
    def clone_model(self, model, id2label):
        """ clone only part of params """
        # deal with data parallel model
        new_model: FewShotSeqLabeler
        old_model: FewShotSeqLabeler
        if (self.opt.local_rank != -1 or self.n_gpu > 1) and hasattr(model, 'module'):  # the model is wrapped in a parallel class here
            old_model = model.module
        else:
            old_model = model
        emission_dict = old_model.emission_scorer.state_dict()
        old_num_tags = len(self.get_value_from_order_dict(emission_dict, 'label_reps'))

        config = {'num_tags': len(id2label), 'id2label': id2label}
        if 'num_anchors' in old_model.config:
            config['num_anchors'] = old_model.config['num_anchors']  # Use previous model's random anchors.
        # get a new instance for different domain
        new_model = make_model(opt=self.opt, config=config)
        new_model = prepare_model(self.opt, new_model, self.device, self.n_gpu)
        if self.opt.local_rank != -1 or self.n_gpu > 1:
            sub_new_model = new_model.module
        else:
            sub_new_model = new_model
        ''' copy weights and stuff '''
        if old_model.opt.task == 'sl' and old_model.transition_scorer:
            # copy one-by-one because the target transition and decoder will be left unassigned
            sub_new_model.context_embedder.load_state_dict(old_model.context_embedder.state_dict())
            sub_new_model.emission_scorer.load_state_dict(old_model.emission_scorer.state_dict())
            for param_name in ['backoff_trans_mat', 'backoff_start_trans_mat', 'backoff_end_trans_mat']:
                sub_new_model.transition_scorer.state_dict()[param_name].copy_(
                    old_model.transition_scorer.state_dict()[param_name].data)
        else:
            sub_new_model.load_state_dict(old_model.state_dict())

        return new_model
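
The selective copy above works because state_dict() (with the default keep_vars=False) returns detached tensors that still share storage with the module's parameters, so an in-place copy_ updates the live weights. A self-contained sketch of the same pattern on hypothetical modules:

import torch.nn as nn

src, dst = nn.Linear(4, 3), nn.Linear(4, 3)
# copy only the weight; dst's bias keeps its own initialization
dst.state_dict()['weight'].copy_(src.state_dict()['weight'].data)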
Example #3
def load_model(path):
    try:
        with open(path, 'rb') as reader:
            cpt = torch.load(reader, map_location='cpu')
            model = make_model(opt=cpt['opt'], config=cpt['config'])
            model = prepare_model(args=cpt['opt'], model=model, device=cpt['opt'].device, n_gpu=cpt['opt'].n_gpu)
            model.load_state_dict(cpt['state_dict'])
            return model
    except IOError as e:
        logger.info("Failed to load model from {} \n {}".format(path, e))
        return None
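
For load_model to work, the checkpoint must carry matching keys. A hedged sketch of the corresponding save side (key names taken from the loader above; save_model itself is not part of the shown code):

def save_model(model, opt, config, path):
    """ persist everything load_model expects """
    # unwrap DataParallel so weights are stored without the 'module.' prefix
    state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
    torch.save({'opt': opt, 'config': config, 'state_dict': state_dict}, path)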
Example #4
def load_model(path):
    try:
        with open(path, 'rb') as reader:
            cpt = torch.load(reader, map_location='cpu')
            old_num_tags = len(
                get_value_from_order_dict(cpt['state_dict'], 'label_reps'))
            model = make_model(opt=cpt['opt'],
                               num_tags=cpt['num_tags'],
                               trans_r=cpt['trans_r'],
                               random_num_tags=old_num_tags // 3)
            model = prepare_model(args=cpt['opt'],
                                  model=model,
                                  device=cpt['opt'].device,
                                  n_gpu=cpt['opt'].n_gpu)
            model.load_state_dict(cpt['state_dict'])
            return model
    except IOError as e:
        logger.info("Failed to load model from {} \n {}".format(path, e))
        return None
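
get_value_from_order_dict is used here only to count the stored label_reps; its definition is not shown. A plausible sketch under that assumption (a substring lookup over the state dict's keys, not necessarily the repo's actual code):

def get_value_from_order_dict(order_dict, key):
    """ return the first value whose key contains the given substring """
    for k, v in order_dict.items():
        if key in k:
            return v
    return []  # nothing matched, so len() of the result is 0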
Example #5
    def clone_model(self, model, id2label, mat_type='test'):
        """ clone only part of params """
        # deal with data parallel model
        new_model: FewShotSeqLabeler
        old_model: FewShotSeqLabeler
        if (self.opt.local_rank != -1 or self.n_gpu > 1) and hasattr(model, 'module'):  # the model is wrapped in a parallel class here
            old_model = model.module
        else:
            old_model = model
        emission_dict = old_model.emission_scorer.state_dict()
        old_num_tags = len(
            self.get_value_from_order_dict(emission_dict, 'label_reps'))
        # get a new instance for different domain
        new_model = make_model(opt=self.opt,
                               num_tags=len(id2label),
                               trans_r=self.opt.trans_r,
                               id2label=id2label,
                               random_num_tags=old_num_tags // 3,
                               mat_type=mat_type)
        new_model = prepare_model(self.opt, new_model, self.device, self.n_gpu)
        if self.opt.local_rank != -1 or self.n_gpu > 1:
            sub_new_model = new_model.module
        else:
            sub_new_model = new_model
        # copy weights and stuff
        # target transition and decoder will be left un-assigned
        sub_new_model.context_embedder.load_state_dict(
            old_model.context_embedder.state_dict())
        sub_new_model.emission_scorer.load_state_dict(
            old_model.emission_scorer.state_dict())

        if old_model.transition_scorer:
            for param_name in [
                    'backoff_trans_mat', 'backoff_start_trans_mat',
                    'backoff_end_trans_mat'
            ]:
                sub_new_model.transition_scorer.state_dict()[param_name].copy_(
                    old_model.transition_scorer.state_dict()[param_name].data)
        return new_model
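
All three clone_model variants repeat the same multi-GPU unwrap; a tiny helper could factor it out (a sketch, not part of the repo):

def unwrap_model(model):
    """ return the underlying module when model is wrapped by (Distributed)DataParallel """
    return model.module if hasattr(model, 'module') else model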
Example #6
def main():
    """ to start the experiment """
    ''' set option '''
    parser = argparse.ArgumentParser()
    parser = define_args(parser, basic_args, train_args, test_args, preprocess_args, model_args)
    opt = parser.parse_args()
    print('Args:\n', json.dumps(vars(opt), indent=2))
    opt = option_check(opt)

    ''' device & environment '''
    device, n_gpu = set_device_environment(opt)
    os.makedirs(opt.output_dir, exist_ok=True)
    logger.info("Environment: device {}, n_gpu {}".format(device, n_gpu))

    ''' data & feature '''
    data_loader = FewShotRawDataLoader(debugging=opt.do_debug)
    preprocessor = make_preprocessor(opt)
    if opt.do_train:
        train_features, train_label2id, train_id2label, train_trans_mat, \
            dev_features, dev_label2id, dev_id2label, dev_trans_mat = \
            get_training_data_and_feature(opt, data_loader, preprocessor)
        # TODO: move the train label mask out of opt.
        if opt.mask_transition:
            opt.train_label_mask = make_label_mask(opt, opt.train_path, train_label2id)
            opt.dev_label_mask = make_label_mask(opt, opt.dev_path, dev_label2id)
            opt.train_trans_mat = [torch.Tensor(item).to(device) for item in train_trans_mat]
            opt.dev_trans_mat = [torch.Tensor(item).to(device) for item in dev_trans_mat]
    else:
        train_features, train_label2id, train_id2label, dev_features, dev_label2id, dev_id2label = [None] * 6
        if opt.mask_transition:
            opt.train_label_mask = None
            opt.dev_label_mask = None
    if opt.do_predict:
        test_features, test_label2id, test_id2label, test_trans_mat = get_testing_data_feature(opt, data_loader, preprocessor)
        if opt.mask_transition:
            opt.test_label_mask = make_label_mask(opt, opt.test_path, test_label2id)
            opt.test_trans_mat = [torch.Tensor(item).to(device) for item in test_trans_mat]
    else:
        test_features, test_label2id, test_id2label = [None] * 3
        if opt.mask_transition:
            opt.test_label_mask = None

    ''' overfitting test '''
    if opt.do_overfit_test:
        test_features, test_label2id, test_id2label = train_features, train_label2id, train_id2label
        dev_features, dev_label2id, dev_id2label = train_features, train_label2id, train_id2label

    ''' select training & testing mode '''
    trainer_class = SchemaFewShotTrainer if opt.use_schema else FewShotTrainer
    tester_class = SchemaFewShotTester if opt.use_schema else FewShotTester

    ''' training '''
    best_model = None
    if opt.do_train:
        logger.info("***** Perform training *****")
        training_model = make_model(opt, num_tags=len(train_label2id), trans_r=1)  # trans_r is 1 for training
        training_model = prepare_model(opt, training_model, device, n_gpu)
        if opt.mask_transition:
            training_model.label_mask = opt.train_label_mask.to(device)
        if opt.upper_lr > 0:  # use a different learning rate for upper-structure parameters
            param_to_optimize, optimizer = prepare_few_shot_optimizer(opt, training_model, len(train_features))
        else:
            param_to_optimize, optimizer = prepare_optimizer(opt, training_model, len(train_features))
        tester = tester_class(opt, device, n_gpu)
        trainer = trainer_class(opt, optimizer, param_to_optimize, device, n_gpu, tester=tester)
        if opt.warmup_epoch > 0:
            training_model.no_embedder_grad = True

            if opt.upper_lr > 0:  # use a different learning rate for upper-structure parameters
                stage_1_param_to_optimize, stage_1_optimizer = prepare_few_shot_optimizer(opt, training_model, len(train_features))
            else:
                stage_1_param_to_optimize, stage_1_optimizer = prepare_optimizer(opt, training_model, len(train_features))

            stage_1_trainer = trainer_class(opt, stage_1_optimizer, stage_1_param_to_optimize, device, n_gpu, tester=None)
            trained_model, best_dev_score, test_score = stage_1_trainer.do_train(
                training_model, train_features, opt.warmup_epoch)
            training_model = trained_model
            training_model.no_embedder_grad = False
            print('========== Stage one training finished! ==========')
        trained_model, best_dev_score, test_score = trainer.do_train(
            training_model, train_features, opt.num_train_epochs,
            dev_features, dev_id2label, test_features, test_id2label, best_dev_score_now=0)

        # decide the best model
        if not opt.eval_when_train:  # select the best among checkpoints
            best_model, best_score, test_score_then = trainer.select_model_from_check_point(
                train_id2label, dev_features, dev_id2label, test_features, test_id2label, rm_cpt=opt.delete_checkpoint)
        else:  # best model is selected during training
            best_model = trained_model
        logger.info('dev:{}, test:{}'.format(best_dev_score, test_score))
        print('dev:{}, test:{}'.format(best_dev_score, test_score))

    ''' testing '''
    if opt.do_predict:
        logger.info("***** Perform testing *****")
        tester = tester_class(opt, device, n_gpu)
        if not best_model:
            if not opt.saved_model_path:
                raise ValueError("No model trained and no trained model file given!")
            if os.path.isdir(opt.saved_model_path):
                all_cpt_file = list(filter(lambda x: '.cpt.pl' in x, os.listdir(opt.saved_model_path)))
                all_cpt_file = sorted(all_cpt_file,
                                      key=lambda x: int(x.replace('model.step', '').replace('.cpt.pl', '')))
                max_score = 0
                for cpt_file in all_cpt_file:
                    cpt_model = load_model(os.path.join(opt.saved_model_path, cpt_file))
                    testing_model = tester.clone_model(cpt_model, test_id2label)
                    if opt.mask_transition:
                        testing_model.label_mask = opt.test_label_mask.to(device)
                    test_score = tester.do_test(testing_model, test_features, test_id2label, log_mark='test_pred')
                    if test_score > max_score:
                        max_score = test_score
                    logger.info('cpt_file:{} - test:{}'.format(cpt_file, test_score))
                print('max_score:{}'.format(max_score))
            else:
                if not os.path.exists(opt.saved_model_path):
                    logger.info('The model file does not exist')
                    raise ValueError('The model file does not exist')
                best_model = load_model(opt.saved_model_path)
        if not (opt.saved_model_path and os.path.isdir(opt.saved_model_path)):
            testing_model = tester.clone_model(best_model, test_id2label)  # copy reusable params
            if opt.mask_transition:
                testing_model.label_mask = opt.test_label_mask.to(device)
            test_score = tester.do_test(testing_model, test_features, test_id2label, log_mark='test_pred')
            logger.info('test:{}'.format(test_score))
            print('test:{}'.format(test_score))
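
The checkpoint loop above assumes files named model.step<N>.cpt.pl. The listing-and-ordering logic on its own (same naming assumption):

import os

def list_checkpoints(model_dir):
    """ return checkpoint filenames sorted by training step """
    cpt_files = [f for f in os.listdir(model_dir) if '.cpt.pl' in f]
    return sorted(cpt_files, key=lambda x: int(x.replace('model.step', '').replace('.cpt.pl', '')))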
Example #7
File: main.py Project: wutong8023/FewJoint
def main():
    """ to start the experiment """
    ''' set option '''
    parser = argparse.ArgumentParser()
    parser = define_args(parser, basic_args, train_args, test_args,
                         preprocess_args, model_args)
    opt = parser.parse_args()
    print('Args:\n', json.dumps(vars(opt), indent=2))
    opt = option_check(opt)
    ''' device & environment '''
    device, n_gpu = set_device_environment(opt)
    os.makedirs(opt.output_dir, exist_ok=True)
    logger.info("Environment: device {}, n_gpu {}".format(device, n_gpu))
    ''' data & feature '''
    data_loader = FewShotRawDataLoader(opt)
    preprocessor = make_preprocessor(opt)
    if opt.do_train:
        train_features, (train_slot_label2id, train_slot_id2label), (train_intent_label2id, train_intent_id2label), \
            dev_features, (dev_slot_label2id, dev_slot_id2label), (dev_intent_label2id, dev_intent_id2label) = \
            get_training_data_and_feature(opt, data_loader, preprocessor)

        if opt.mask_transition and opt.task in ['slot_filling', 'slu']:
            opt.train_label_mask = make_label_mask(opt, opt.train_path,
                                                   train_slot_label2id)
            opt.dev_label_mask = make_label_mask(opt, opt.dev_path,
                                                 dev_slot_label2id)
    else:
        train_features, train_slot_label2id, train_slot_id2label, train_intent_label2id, train_intent_id2label, \
            dev_features, dev_slot_label2id, dev_slot_id2label, dev_intent_label2id, dev_intent_id2label = [None] * 10
        if opt.mask_transition and opt.task in ['slot_filling', 'slu']:
            opt.train_label_mask = None
            opt.dev_label_mask = None

    if opt.do_predict:
        test_features, (test_slot_label2id, test_slot_id2label), (
            test_intent_label2id,
            test_intent_id2label) = get_testing_data_feature(
                opt, data_loader, preprocessor)
        if opt.mask_transition and opt.task in ['slot_filling', 'slu']:
            opt.test_label_mask = make_label_mask(opt, opt.test_path,
                                                  test_slot_label2id)
    else:
        test_features, test_slot_label2id, test_slot_id2label, test_intent_label2id, test_intent_id2label = [None] * 5
        if opt.mask_transition and opt.task in ['slot_filling', 'slu']:
            opt.test_label_mask = None
    ''' overfitting test '''
    if opt.do_overfit_test:
        test_features, (test_slot_label2id, test_slot_id2label), (test_intent_label2id, test_intent_id2label) = \
            train_features, (train_slot_label2id, train_slot_id2label), (train_intent_label2id, train_intent_id2label)
        dev_features, (dev_slot_label2id, dev_slot_id2label), (dev_intent_label2id, dev_intent_id2label) = \
            train_features, (train_slot_label2id, train_slot_id2label), (train_intent_label2id, train_intent_id2label)

    train_id2label_map = {
        'slot': train_slot_id2label,
        'intent': train_intent_id2label
    }
    dev_id2label_map = {
        'slot': dev_slot_id2label,
        'intent': dev_intent_id2label
    }
    test_id2label_map = {
        'slot': test_slot_id2label,
        'intent': test_intent_id2label
    }
    ''' select training & testing mode '''
    trainer_class = SchemaFewShotTrainer if opt.use_schema else FewShotTrainer
    tester_class = SchemaFewShotTester if opt.use_schema else FewShotTester
    ''' training '''
    best_model = None
    if opt.do_train:
        logger.info("***** Perform training *****")
        if opt.restore_cpt:  # restart training from a checkpoint.
            # restoring optimizer params is not supported yet.
            training_model = load_model(opt.saved_model_path)
            opt = training_model.opt
            opt.warmup_epoch = -1
        else:
            training_model = make_model(
                opt,
                config={
                    # 'num_tags': len(train_slot_label2id) if opt.task in ['slot_filling', 'slu'] else 0,
                    'num_tags': {
                        'slot': len(train_slot_label2id),
                        'intent': len(train_intent_label2id)
                    },
                    'id2label': train_id2label_map
                })

        # ================== fine-tune =====================
        if opt.finetune and not opt.restore_cpt:
            logger.info(
                "***** Start fine-tuning target domain with only support data *****"
            )
            # train & dev as fine-tune train
            # test as fine-tune dev
            ft_preprocessor = make_preprocessor(opt, finetune=opt.finetune)

            if opt.ft_dataset == 'origin':
                ft_train_features, (ft_train_slot_label2id, ft_train_slot_id2label), \
                    (ft_train_intent_label2id, ft_train_intent_id2label) = get_training_data_feature(
                        opt, data_loader, ft_preprocessor)
                ft_test_features, (ft_test_slot_label2id, ft_test_slot_id2label), \
                    (ft_test_intent_label2id, ft_test_intent_id2label) = get_testing_data_feature(
                    opt, data_loader, ft_preprocessor)
            elif opt.ft_dataset == 'shuffle':
                ft_features, (ft_train_slot_label2id, ft_train_slot_id2label), \
                    (ft_train_intent_label2id, ft_train_intent_id2label) = get_support_data_feature(
                        opt, data_loader, ft_preprocessor)
                (ft_test_slot_label2id, ft_test_slot_id2label), (ft_test_intent_label2id, ft_test_intent_id2label) = \
                    (ft_train_slot_label2id, ft_train_slot_id2label), (ft_train_intent_label2id, ft_train_intent_id2label)

                random.seed(opt.seed)
                random.shuffle(ft_features)

                train_num = int(len(ft_features) * opt.ft_td_rate)

                ft_train_features, ft_test_features = ft_features[:train_num], ft_features[train_num:]

            else:
                raise NotImplementedError

            ft_label2id = {
                'slot': ft_train_slot_label2id,
                'intent': ft_train_intent_label2id
            }
            ft_id2label = {
                'slot': ft_train_slot_id2label,
                'intent': ft_train_intent_id2label
            }

            ft_test_id2label = {
                'slot': ft_test_slot_id2label,
                'intent': ft_test_intent_id2label
            }

            finetune_model = make_finetune_model(opt,
                                                 config={
                                                     'label2id': ft_label2id,
                                                     'id2label': ft_id2label
                                                 })
            finetune_model = prepare_model(opt, finetune_model, device, n_gpu)
            ft_upper_structures = ['intent', 'slot']
            ft_param_to_optimize, ft_optimizer, ft_scheduler = prepare_optimizer(
                opt, finetune_model, len(ft_train_features),
                ft_upper_structures)

            ft_tester = FineTuneTester(opt, device, n_gpu)
            ft_trainer = FineTuneTrainer(opt,
                                         ft_optimizer,
                                         ft_scheduler,
                                         ft_param_to_optimize,
                                         device,
                                         n_gpu,
                                         tester=ft_tester)

            ft_trained_model, ft_best_dev_score, ft_test_score = ft_trainer.do_train(
                finetune_model, ft_train_features, opt.ft_num_train_epochs,
                ft_test_features, ft_test_id2label)

            print('========== Fine-tuning finished! ==========')

            # ============ init training_model's embedder params from the fine-tuned model ===========
            training_model = prepare_model(opt, training_model, device, n_gpu)
            if opt.mask_transition and opt.task in ['slot_filling', 'slu']:
                training_model.label_mask = opt.train_label_mask.to(device)
            # prepare a set of name substrings/marks used to apply a different learning rate to part of the params

            training_model.context_embedder.embedder.load_state_dict(
                ft_trained_model.context_embedder.embedder.state_dict())

            # remove checkpoint files saved during fine-tuning
            all_cpt_file = list(
                filter(lambda x: '.cpt.pl' in x, os.listdir(opt.output_dir)))
            for cpt_file in all_cpt_file:
                os.unlink(os.path.join(opt.output_dir, cpt_file))

        else:
            training_model = prepare_model(opt, training_model, device, n_gpu)
            if opt.mask_transition and opt.task in ['slot_filling', 'slu']:
                training_model.label_mask = opt.train_label_mask.to(device)
            # prepare a set of name substrings/marks used to apply a different learning rate to part of the params

        upper_structures = [
            'backoff', 'scale_rate', 'f_theta', 'phi', 'start_reps',
            'end_reps', 'biaffine', 'relation'
        ]
        metric_params = ['intent', 'slot', 'metric', 'slu_rnn_encoder']
        if not opt.no_up_metric_params:
            upper_structures.extend(metric_params)
        param_to_optimize, optimizer, scheduler = prepare_optimizer(
            opt, training_model, len(train_features), upper_structures)
        tester = tester_class(opt, device, n_gpu)
        trainer = trainer_class(opt,
                                optimizer,
                                scheduler,
                                param_to_optimize,
                                device,
                                n_gpu,
                                tester=tester)
        if opt.warmup_epoch > 0:
            training_model.no_embedder_grad = True
            stage_1_param_to_optimize, stage_1_optimizer, stage_1_scheduler = prepare_optimizer(
                opt, training_model, len(train_features), upper_structures)
            stage_1_trainer = trainer_class(opt,
                                            stage_1_optimizer,
                                            stage_1_scheduler,
                                            stage_1_param_to_optimize,
                                            device,
                                            n_gpu,
                                            tester=None)
            trained_model, best_dev_score, test_score = stage_1_trainer.do_train(
                training_model, train_features, opt.warmup_epoch)
            training_model = trained_model
            training_model.no_embedder_grad = False
            print('========== Warmup training finished! ==========')
        trained_model, best_dev_score, test_score = trainer.do_train(
            training_model,
            train_features,
            opt.num_train_epochs,
            dev_features,
            dev_id2label_map,
            test_features,
            test_id2label_map,
            best_dev_score_now=0)

        # decide the best model
        if not opt.eval_when_train:  # select the best among checkpoints
            best_model, best_score, test_score_then = trainer.select_model_from_check_point(
                train_id2label_map,
                dev_features,
                dev_id2label_map,
                test_features,
                test_id2label_map,
                rm_cpt=opt.delete_checkpoint)
        else:  # best model is selected during training
            best_model = trained_model
        logger.info('dev:{}, test:{}'.format(best_dev_score, test_score))
        print('dev:{}, test:{}'.format(best_dev_score, test_score))
    ''' testing '''
    if opt.do_predict:
        logger.info("***** Perform testing *****")
        print("***** Perform testing *****")
        tester = tester_class(opt, device, n_gpu)
        if not best_model:  # no trained model; load it from disk.
            if not opt.saved_model_path or not os.path.exists(opt.saved_model_path):
                raise ValueError("No model trained and no trained model file given (or it does not exist)")
            if os.path.isdir(
                    opt.saved_model_path):  # eval a list of checkpoints
                max_score = eval_check_points(opt, tester, test_features,
                                              test_id2label_map, device)
                print('best checkpoint score:{}'.format(max_score))
                exit(0)
            else:
                best_model = load_model(opt.saved_model_path)
        ''' test the best model '''
        testing_model = tester.clone_model(
            best_model, test_id2label_map)  # copy reusable params
        if opt.mask_transition and opt.task in ['slot_filling', 'slu']:
            testing_model.label_mask = opt.test_label_mask.to(device)
        test_score = tester.do_test(testing_model,
                                    test_features,
                                    test_id2label_map,
                                    log_mark='test_pred')
        logger.info('test:{}'.format(test_score))
        print('test:{}'.format(test_score))
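
eval_check_points is called above but not shown. Judging from the inline checkpoint loop in Example #6, it plausibly looks like this (a hedged reconstruction, not the repo's code):

def eval_check_points(opt, tester, test_features, test_id2label_map, device):
    """ evaluate every checkpoint under saved_model_path and return the best test score """
    all_cpt_file = sorted(
        [f for f in os.listdir(opt.saved_model_path) if '.cpt.pl' in f],
        key=lambda x: int(x.replace('model.step', '').replace('.cpt.pl', '')))
    max_score = 0
    for cpt_file in all_cpt_file:
        cpt_model = load_model(os.path.join(opt.saved_model_path, cpt_file))
        testing_model = tester.clone_model(cpt_model, test_id2label_map)
        if opt.mask_transition and opt.task in ['slot_filling', 'slu']:
            testing_model.label_mask = opt.test_label_mask.to(device)
        test_score = tester.do_test(testing_model, test_features, test_id2label_map, log_mark='test_pred')
        max_score = max(max_score, test_score)
        logger.info('cpt_file:{} - test:{}'.format(cpt_file, test_score))
    return max_score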
Example #8
def main():
    """ to start the experiment """
    ''' set option '''
    parser = argparse.ArgumentParser()
    parser = define_args(parser, basic_args, train_args, test_args,
                         preprocess_args, model_args)
    opt = parser.parse_args()
    print('Args:\n', json.dumps(vars(opt), indent=2))
    opt = option_check(opt)
    ''' device & environment '''
    device, n_gpu = set_device_environment(opt)
    os.makedirs(opt.output_dir, exist_ok=True)
    logger.info("Environment: device {}, n_gpu {}".format(device, n_gpu))
    ''' data & feature '''
    data_loader = FewShotRawDataLoader(opt)
    preprocessor = make_preprocessor(opt)
    if opt.do_train:
        train_features, train_label2id_map, train_id2label_map, dev_features, dev_label2id_map, dev_id2label_map = \
            get_training_data_and_feature(opt, data_loader, preprocessor)

        if opt.mask_transition and 'sl' in opt.task:
            opt.train_label_mask = make_label_mask(opt, opt.train_path,
                                                   train_label2id_map['sl'])
            opt.dev_label_mask = make_label_mask(opt, opt.dev_path,
                                                 dev_label2id_map['sl'])
    else:
        train_features, train_label2id_map, train_id2label_map, dev_features, dev_label2id_map, dev_id2label_map = \
            [None] * 6
        if opt.mask_transition and 'sl' in opt.task:
            opt.train_label_mask = None
            opt.dev_label_mask = None

    if opt.do_predict:
        test_features, test_label2id_map, test_id2label_map = get_testing_data_feature(
            opt, data_loader, preprocessor)
        if opt.mask_transition and 'sl' in opt.task:
            opt.test_label_mask = make_label_mask(opt, opt.test_path,
                                                  test_label2id_map['sl'])
    else:
        test_features, test_label2id_map, test_id2label_map = [None] * 3
        if opt.mask_transition and 'sl' in opt.task:
            opt.test_label_mask = None
    ''' overfitting test '''
    if opt.do_overfit_test:
        test_features, test_label2id_map, test_id2label_map = train_features, train_label2id_map, train_id2label_map
        dev_features, dev_label2id_map, dev_id2label_map = train_features, train_label2id_map, train_id2label_map
    ''' select training & testing mode '''
    trainer_class = SchemaFewShotTrainer if opt.use_schema else FewShotTrainer
    tester_class = SchemaFewShotTester if opt.use_schema else FewShotTester
    ''' training '''
    best_model = None
    if opt.do_train:
        logger.info("***** Perform training *****")
        if opt.restore_cpt:  # restart training from a checkpoint.
            # restoring optimizer params is not supported yet.
            training_model = load_model(opt.saved_model_path)
            opt = training_model.opt
            opt.warmup_epoch = -1
        else:
            training_model = make_model(opt,
                                        config={
                                            'num_tags':
                                            len(train_label2id_map['sl']) if
                                            'sl' in train_label2id_map else 0
                                        })
        training_model = prepare_model(opt, training_model, device, n_gpu)
        if opt.mask_transition and 'sl' in opt.task:
            training_model.label_mask = opt.train_label_mask.to(device)
        # prepare a set of name substrings/marks used to apply a different learning rate to part of the params
        upper_structures = [
            'backoff', 'scale_rate', 'f_theta', 'phi', 'start_reps',
            'end_reps', 'biaffine'
        ]
        param_to_optimize, optimizer, scheduler = prepare_optimizer(
            opt, training_model, len(train_features), upper_structures)
        tester = tester_class(opt, device, n_gpu)
        trainer = trainer_class(opt,
                                optimizer,
                                scheduler,
                                param_to_optimize,
                                device,
                                n_gpu,
                                tester=tester)
        if opt.warmup_epoch > 0:
            training_model.no_embedder_grad = True
            stage_1_param_to_optimize, stage_1_optimizer, stage_1_scheduler = prepare_optimizer(
                opt, training_model, len(train_features), upper_structures)
            stage_1_trainer = trainer_class(opt,
                                            stage_1_optimizer,
                                            stage_1_scheduler,
                                            stage_1_param_to_optimize,
                                            device,
                                            n_gpu,
                                            tester=None)
            trained_model, best_dev_score, test_score = stage_1_trainer.do_train(
                training_model, train_features, opt.warmup_epoch)
            training_model = trained_model
            training_model.no_embedder_grad = False
            print('========== Warmup training finished! ==========')
        trained_model, best_dev_score, test_score = trainer.do_train(
            training_model,
            train_features,
            opt.num_train_epochs,
            dev_features,
            dev_id2label_map,
            test_features,
            test_id2label_map,
            best_dev_score_now=0)

        # decide the best model
        if not opt.eval_when_train:  # select the best among checkpoints
            best_model, best_score, test_score_then = trainer.select_model_from_check_point(
                train_id2label_map,
                dev_features,
                dev_id2label_map,
                test_features,
                test_id2label_map,
                rm_cpt=opt.delete_checkpoint)
        else:  # best model is selected during training
            best_model = trained_model
        logger.info('dev:{}, test:{}'.format(best_dev_score, test_score))
        print('dev:{}, test:{}'.format(best_dev_score, test_score))
    ''' testing '''
    if opt.do_predict:
        logger.info("***** Perform testing *****")
        print("***** Perform testing *****")
        tester = tester_class(opt, device, n_gpu)
        if not best_model:  # no trained model; load it from disk.
            if not opt.saved_model_path or not os.path.exists(opt.saved_model_path):
                raise ValueError("No model trained and no trained model file given (or it does not exist)")
            if os.path.isdir(
                    opt.saved_model_path):  # eval a list of checkpoints
                max_score = eval_check_points(opt, tester, test_features,
                                              test_id2label_map, device)
                print('best checkpoint score:{}'.format(max_score))
                exit(0)
            else:
                best_model = load_model(opt.saved_model_path)
        ''' test the best model '''
        testing_model = tester.clone_model(
            best_model, test_id2label_map)  # copy reusable params
        if opt.mask_transition and 'sl' in opt.task:
            testing_model.label_mask = opt.test_label_mask.to(device)
        test_score = tester.do_test(testing_model,
                                    test_features,
                                    test_id2label_map,
                                    log_mark='test_pred')
        logger.info('test:{}'.format(test_score))
        print('test:{}'.format(test_score))
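
In all three main variants, warmup_epoch > 0 first trains only the upper structure by setting no_embedder_grad = True and unfreezes the embedder afterwards. The same two-stage idea in plain PyTorch (model.embedder is a hypothetical attribute, not the repo's flag mechanism):

def set_embedder_grad(model, requires_grad):
    """ freeze or unfreeze the embedder for two-stage training """
    for param in model.embedder.parameters():
        param.requires_grad = requires_grad

# stage 1: warm up the upper structure with the embedder frozen
# set_embedder_grad(training_model, False)
# stage 2: unfreeze and train everything end-to-end
# set_embedder_grad(training_model, True)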