def __init__(self, hparams: Namespace):
        super().__init__()
        self.save_hyperparameters(hparams)
        cached_config_file = join(self.hparams.exp_name, 'cached_config.bin')
        if os.path.exists(cached_config_file):
            cached_config = torch.load(cached_config_file)
            encoder_path = join(self.hparams.exp_name, cached_config['encoder'])
            model_path = join(self.hparams.exp_name, cached_config['model'])
        else:
            model_path = None
            if self.hparams.fine_tuned_encoder is not None:
                encoder_path = join(self.hparams.fine_tuned_encoder_path, self.hparams.fine_tuned_encoder, 'encoder.pkl')
            else:
                encoder_path = None

        _, _, tokenizer_class = MODEL_CLASSES[self.hparams.model_type]
        self.tokenizer = tokenizer_class.from_pretrained(
            self.hparams.encoder_name_or_path, do_lower_case=self.hparams.do_lower_case)
        # Set Encoder and Model
        self.encoder, _ = load_encoder_model(self.hparams.encoder_name_or_path, self.hparams.model_type)
        self.model = HierarchicalGraphNetwork(config=self.hparams)
        if encoder_path is not None:
            self.encoder.load_state_dict(torch.load(encoder_path))
            logging.info('Initializing parameters from {}'.format(encoder_path))
        if model_path is not None:
            self.model.load_state_dict(torch.load(model_path))
        logging.info('Loading encoder and model completed')
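
A minimal sketch of the saving side that 'cached_config.bin' implies, assuming
it stores a dict mapping the keys 'encoder' and 'model' (used by the loading
code above) to the relative file names of the two state dicts; the helper name
and the file naming scheme are illustrative, not from the repository:

import torch
from os.path import join

def save_cached_config(exp_name, encoder, model, step):
    # Persist both state dicts next to the experiment's cached config.
    encoder_name = 'encoder_{}.pkl'.format(step)
    model_name = 'model_{}.pkl'.format(step)
    torch.save(encoder.state_dict(), join(exp_name, encoder_name))
    torch.save(model.state_dict(), join(exp_name, model_name))
    # Record which files __init__ should reload via cached_config['encoder'/'model'].
    torch.save({'encoder': encoder_name, 'model': model_name},
               join(exp_name, 'cached_config.bin'))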
Example #2
    def __init__(self, config):
        super(UnifiedHGNModel, self).__init__()
        self.config = config
        self.encoder, _ = load_encoder_model(self.config.encoder_name_or_path,
                                             self.config.model_type)
        self.model = HierarchicalGraphNetwork(config=self.config)
        if self.config.encoder_path is not None:
            self.initialize_model()
Example #3
class UnifiedHGNModel(nn.Module):
    def __init__(self, config):
        super(UnifiedHGNModel, self).__init__()
        self.config = config
        self.encoder, _ = load_encoder_model(self.config.encoder_name_or_path,
                                             self.config.model_type)
        self.model = HierarchicalGraphNetwork(config=self.config)
        if self.config.encoder_path is not None:
            self.initialize_model()

    def initialize_model(self):
        if self.config.encoder_path is not None:
            logging.info("Loading encoder from: {}".format(
                self.config.encoder_path))
            self.encoder.load_state_dict(torch.load(self.config.encoder_path))
            logging.info("Loading encoder completed")
        else:
            raise ValueError('encoder_path is None')
        if self.config.model_path is not None:
            logging.info("Loading model from: {}".format(
                self.config.model_path))
            self.model.load_state_dict(torch.load(self.config.model_path))
            logging.info("Loading model completed")
        else:
            raise ValueError('model_path is None')

    def forward(self, batch, return_yp=True, return_cls=True):
        ###############################################################################################################
        inputs = {
            'input_ids': batch['context_idxs'],
            'attention_mask': batch['context_mask'],
            'token_type_ids': batch['segment_idxs']
            if self.config.model_type in ['bert', 'xlnet'] else None
        }  # XLM doesn't use segment_ids
        ####++++++++++++++++++++++++++++++++++++++
        outputs = self.encoder(**inputs)
        batch['context_encoding'] = outputs[0]
        ####++++++++++++++++++++++++++++++++++++++
        batch['context_mask'] = batch['context_mask'].float().to(
            self.config.device)
        start, end, q_type, paras, sents, ents, y1, y2, cls_emb = self.model.forward(
            batch, return_yp=return_yp, return_cls=return_cls)
        return start, end, q_type, paras, sents, ents, y1, y2, cls_emb
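
A hypothetical usage sketch for UnifiedHGNModel (the source of `batch` and the
device handling are illustrative; `config` is assumed to carry the fields
referenced in __init__ and forward):

model = UnifiedHGNModel(config)
model.to(config.device)
model.eval()
with torch.no_grad():  # forward defaults to return_yp=True, return_cls=True
    start, end, q_type, paras, sents, ents, y1, y2, cls_emb = model(batch)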
Example #4

    def __init__(self, args: Namespace):
        super(lightningHGN, self).__init__()
        self.args = args
        cached_config_file = join(self.args.exp_name, 'cached_config.bin')
        if os.path.exists(cached_config_file):
            self.cached_config = torch.load(cached_config_file)
            encoder_path = join(self.args.exp_name, self.cached_config['encoder'])
            model_path = join(self.args.exp_name, self.cached_config['model'])
        else:
            encoder_path = None
            model_path = None
            self.cached_config = None

        _, _, tokenizer_class = MODEL_CLASSES[self.args.model_type]
        self.tokenizer = tokenizer_class.from_pretrained(
            self.args.encoder_name_or_path, do_lower_case=self.args.do_lower_case)
        # Set Encoder and Model
        self.encoder, _ = load_encoder_model(self.args.encoder_name_or_path, self.args.model_type)
        self.model = HierarchicalGraphNetwork(config=self.args)
        if encoder_path is not None:
            self.encoder.load_state_dict(torch.load(encoder_path))
        if model_path is not None:
            self.model.load_state_dict(torch.load(model_path))
Example #5
encoder_path = join(args.exp_name, args.encoder_name)  # e.g. 'encoder.pkl'
model_path = join(args.exp_name, args.model_name)  # e.g. 'model.pkl'
logger.info("Loading encoder from: {}".format(encoder_path))
logger.info("Loading model from: {}".format(model_path))

if torch.cuda.is_available():
    device_ids, _ = single_free_cuda()
    device = torch.device('cuda:{}'.format(device_ids[0]))
else:
    device = torch.device('cpu')

encoder, _ = load_encoder_model(args.encoder_name_or_path, args.model_type)
model = HierarchicalGraphNetwork(config=args)

if encoder_path is not None:
    state_dict = torch.load(encoder_path)
    print('loading parameters from {}'.format(encoder_path))
    for key in list(state_dict.keys()):
        if 'module.' in key:
            state_dict[key.replace('module.', '')] = state_dict[key]
            del state_dict[key]
    encoder.load_state_dict(state_dict)
if model_path is not None:
    state_dict = torch.load(model_path)
    print('loading parameters from {}'.format(model_path))
    for key in list(state_dict.keys()):
        if 'module.' in key:
            state_dict[key.replace('module.', '')] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict)
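
The 'module.' handling above strips the prefix that torch.nn.DataParallel adds
to every parameter name when a wrapped model is saved. A reusable sketch of the
same idea (the helper is illustrative, not part of the repository):

def strip_data_parallel_prefix(state_dict):
    # DataParallel checkpoints store keys like 'module.encoder...'; drop the
    # prefix so the weights load into an unwrapped module.
    return {(key[len('module.'):] if key.startswith('module.') else key): value
            for key, value in state_dict.items()}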
Example #6

class lightningHGN(pl.LightningModule):
    def __init__(self, args: Namespace):
        super(lightningHGN, self).__init__()
        self.args = args
        cached_config_file = join(self.args.exp_name, 'cached_config.bin')
        if os.path.exists(cached_config_file):
            self.cached_config = torch.load(cached_config_file)
            encoder_path = join(self.args.exp_name, self.cached_config['encoder'])
            model_path = join(self.args.exp_name, self.cached_config['model'])
        else:
            encoder_path = None
            model_path = None
            self.cached_config = None

        _, _, tokenizer_class = MODEL_CLASSES[self.args.model_type]
        self.tokenizer = tokenizer_class.from_pretrained(
            self.args.encoder_name_or_path, do_lower_case=self.args.do_lower_case)
        # Set Encoder and Model
        self.encoder, _ = load_encoder_model(self.args.encoder_name_or_path, self.args.model_type)
        self.model = HierarchicalGraphNetwork(config=self.args)
        if encoder_path is not None:
            self.encoder.load_state_dict(torch.load(encoder_path))
        if model_path is not None:
            self.model.load_state_dict(torch.load(model_path))

    def prepare_data(self):
        helper = DataHelper(gz=True, config=self.args)
        self.train_data = helper.train_loader
        self.dev_example_dict = helper.dev_example_dict
        self.dev_feature_dict = helper.dev_feature_dict
        self.dev_data = helper.dev_loader

    def setup(self, stage: str = 'fit'):
        if stage == 'fit':
            # Get dataloader by calling it - train_dataloader() is called after setup() by default
            train_loader = self.train_dataloader()
            # Calculate total steps
            if self.args.max_steps > 0:
                self.total_steps = self.args.max_steps
                self.args.num_train_epochs = self.args.max_steps // (
                            len(train_loader) // self.args.gradient_accumulation_steps) + 1
            else:
                self.total_steps = len(train_loader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs
            print('total steps = {}'.format(self.total_steps))

    def train_dataloader(self):
        return self.train_data

    def val_dataloader(self):
        return self.dev_data

    def forward(self, batch):
        inputs = {'input_ids':      batch['context_idxs'],
                  'attention_mask': batch['context_mask'],
                  'token_type_ids': batch['segment_idxs'] if self.args.model_type in ['bert', 'xlnet'] else None}  # XLM doesn't use segment_ids
        batch['context_encoding'] = self.encoder(**inputs)[0]
        batch['context_mask'] = batch['context_mask'].float()
        start, end, q_type, paras, sents, ents, yp1, yp2 = self.model(batch, return_yp=True)
        return start, end, q_type, paras, sents, ents, yp1, yp2

    def training_step(self, batch, batch_idx):
        start, end, q_type, paras, sents, ents, _, _ = self.forward(batch=batch)
        loss_list = compute_loss(self.args, batch, start, end, paras, sents, ents, q_type)
        del batch
        ##################################################################################
        loss, loss_span, loss_type, loss_sup, loss_ent, loss_para = loss_list
        dict_for_progress_bar = {'span_loss': loss_span, 'type_loss': loss_type,
                                 'sent_loss': loss_sup, 'ent_loss': loss_ent,
                                 'para_loss': loss_para}
        dict_for_log = dict_for_progress_bar.copy()
        dict_for_log['step'] = batch_idx + 1
        ##################################################################################
        output = {'loss': loss, 'log': dict_for_log, 'progress_bar': dict_for_progress_bar}
        return output

    def validation_step(self, batch, batch_idx):
        start, end, q_type, paras, sents, ents, yp1, yp2 = self.forward(batch=batch)
        loss_list = compute_loss(self.args, batch, start, end, paras, sents, ents, q_type)
        loss, loss_span, loss_type, loss_sup, loss_ent, loss_para = loss_list
        dict_for_log = {'span_loss': loss_span, 'type_loss': loss_type,
                                 'sent_loss': loss_sup, 'ent_loss': loss_ent,
                                 'para_loss': loss_para,
                        'step': batch_idx + 1}
        #######################################################################
        type_prob = F.softmax(q_type, dim=1).data.cpu().numpy()
        answer_dict_, answer_type_dict_, answer_type_prob_dict_ = convert_to_tokens(self.dev_example_dict,
                                                                                    self.dev_feature_dict,
                                                                                    batch['ids'],
                                                                                    yp1.data.cpu().numpy().tolist(),
                                                                                    yp2.data.cpu().numpy().tolist(),
                                                                                    type_prob)
        predict_support_np = torch.sigmoid(sents[:, :, 1]).data.cpu().numpy()
        valid_dict = {'answer': answer_dict_, 'ans_type': answer_type_dict_, 'ids': batch['ids'],
                      'ans_type_pro': answer_type_prob_dict_, 'supp_np': predict_support_np}
        #######################################################################
        del batch
        #######################################################################
        output = {'valid_loss': loss, 'log': dict_for_log, 'valid_dict_output': valid_dict}
        return output

    def validation_epoch_end(self, validation_step_outputs):
        avg_loss = torch.stack([x['valid_loss'] for x in validation_step_outputs]).mean()
        self.log('valid_loss', avg_loss, on_epoch=True, prog_bar=True, sync_dist=True)
        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        answer_dict = {}
        answer_type_dict = {}
        answer_type_prob_dict = {}

        thresholds = np.arange(0.1, 1.0, 0.025)
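        # Sweep this grid of supporting-fact probability thresholds; the one
        # with the best joint F1 is selected in choose_best_threshold below.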
        N_thresh = len(thresholds)
        total_sp_dict = [{} for _ in range(N_thresh)]
        total_record_num = 0
        #++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        valid_dict_outputs = [x['valid_dict_output'] for x in validation_step_outputs]
        for batch_idx, valid_dict in enumerate(valid_dict_outputs):
            answer_dict_, answer_type_dict_, answer_type_prob_dict_ = valid_dict['answer'], valid_dict['ans_type'], valid_dict['ans_type_pro']
            answer_type_dict.update(answer_type_dict_)
            answer_type_prob_dict.update(answer_type_prob_dict_)
            answer_dict.update(answer_dict_)

            predict_support_np = valid_dict['supp_np']
            batch_ids = valid_dict['ids']
            ###
            total_record_num = total_record_num + predict_support_np.shape[0]
            ###
            for i in range(predict_support_np.shape[0]):
                cur_sp_pred = [[] for _ in range(N_thresh)]
                cur_id = batch_ids[i]

                for j in range(predict_support_np.shape[1]):
                    if j >= len(self.dev_example_dict[cur_id].sent_names):
                        break
                    for thresh_i in range(N_thresh):
                        if predict_support_np[i, j] > thresholds[thresh_i]:
                            cur_sp_pred[thresh_i].append(self.dev_example_dict[cur_id].sent_names[j])

                for thresh_i in range(N_thresh):
                    if cur_id not in total_sp_dict[thresh_i]:
                        total_sp_dict[thresh_i][cur_id] = []
                    total_sp_dict[thresh_i][cur_id].extend(cur_sp_pred[thresh_i])

        def choose_best_threshold(ans_dict, pred_file):
            best_joint_f1 = 0
            best_metrics = None
            best_threshold = 0
            for thresh_i in range(N_thresh):
                prediction = {'answer': ans_dict,
                              'sp': total_sp_dict[thresh_i],
                              'type': answer_type_dict,
                              'type_prob': answer_type_prob_dict}
                tmp_file = os.path.join(os.path.dirname(pred_file), 'tmp.json')
                with open(tmp_file, 'w') as f:
                    json.dump(prediction, f)
                metrics = hotpot_eval(tmp_file, self.args.dev_gold_file)
                if metrics['joint_f1'] >= best_joint_f1:
                    best_joint_f1 = metrics['joint_f1']
                    best_threshold = thresholds[thresh_i]
                    best_metrics = metrics
                    shutil.move(tmp_file, pred_file)
            return best_metrics, best_threshold

        output_pred_file = os.path.join(self.args.exp_name, f'pred.epoch_{self.current_epoch + 1}.gpu_{self.trainer.root_gpu}.json')
        output_eval_file = os.path.join(self.args.exp_name, f'eval.epoch_{self.current_epoch + 1}.gpu_{self.trainer.root_gpu}.txt')
        best_metrics, best_threshold = choose_best_threshold(answer_dict, output_pred_file)
        logging.info('Leaderboard evaluation completed over {} records with threshold = {}'.format(total_record_num, best_threshold))
        log_metrics(mode='Evaluation epoch {} gpu {}'.format(self.current_epoch, self.trainer.root_gpu), metrics=best_metrics)
        logging.info('*' * 75)
        json.dump(best_metrics, open(output_eval_file, 'w'))
        return best_metrics, best_threshold

    def configure_optimizers(self):
        "Prepare optimizer and schedule (linear warmup and decay)"
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.named_parameters() if
                           (p.requires_grad) and (not any(nd in n for nd in no_decay))],
                "weight_decay": self.args.weight_decay,
            },
            {
                "params": [p for n, p in self.named_parameters() if
                           (p.requires_grad) and (any(nd in n for nd in no_decay))],
                "weight_decay": 0.0,
            }
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=self.total_steps
        )
        scheduler = {
            'scheduler': scheduler,
            'interval': 'step',
            'frequency': 1
        }
        return [optimizer], [scheduler]
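
A minimal sketch of how this LightningModule might be driven; the Trainer
arguments are illustrative assumptions, not taken from the repository:

import pytorch_lightning as pl

model = lightningHGN(args)
trainer = pl.Trainer(max_epochs=args.num_train_epochs,
                     gpus=1,
                     accumulate_grad_batches=args.gradient_accumulation_steps)
trainer.fit(model)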
Example #7

class lightningHGN(pl.LightningModule):
    def __init__(self, hparams: Namespace):
        super().__init__()
        self.save_hyperparameters(hparams)
        cached_config_file = join(self.hparams.exp_name, 'cached_config.bin')
        if os.path.exists(cached_config_file):
            cached_config = torch.load(cached_config_file)
            encoder_path = join(self.hparams.exp_name, cached_config['encoder'])
            model_path = join(self.hparams.exp_name, cached_config['model'])
        else:
            model_path = None
            if self.hparams.fine_tuned_encoder is not None:
                encoder_path = join(self.hparams.fine_tuned_encoder_path, self.hparams.fine_tuned_encoder, 'encoder.pkl')
            else:
                encoder_path = None

        _, _, tokenizer_class = MODEL_CLASSES[self.hparams.model_type]
        self.tokenizer = tokenizer_class.from_pretrained(
            self.hparams.encoder_name_or_path, do_lower_case=self.hparams.do_lower_case)
        # Set Encoder and Model
        self.encoder, _ = load_encoder_model(self.hparams.encoder_name_or_path, self.hparams.model_type)
        self.model = HierarchicalGraphNetwork(config=self.hparams)
        if encoder_path is not None:
            self.encoder.load_state_dict(torch.load(encoder_path))
            logging.info('Initializing parameters from {}'.format(encoder_path))
        if model_path is not None:
            self.model.load_state_dict(torch.load(model_path))
        logging.info('Loading encoder and model completed')
        ##########

    def prepare_data(self):
        helper = DataHelper(gz=True, config=self.hparams)
        self.train_data = helper.train_loader
        self.dev_example_dict = helper.dev_example_dict
        self.dev_feature_dict = helper.dev_feature_dict
        self.dev_data = helper.dev_loader

    def setup(self, stage: str = 'fit'):
        if stage == 'fit':
            # Get dataloader by calling it - train_dataloader() is called after setup() by default
            train_loader = self.train_dataloader()
            # Calculate total steps
            if self.hparams.max_steps > 0:
                self.total_steps = self.hparams.max_steps
                self.hparams.num_train_epochs = self.hparams.max_steps // (
                            len(train_loader) // self.hparams.gradient_accumulation_steps) + 1
            else:
                self.total_steps = len(train_loader) // self.hparams.gradient_accumulation_steps * self.hparams.num_train_epochs
            print('total steps = {}'.format(self.total_steps))

    def train_dataloader(self):
        dataloader = DataLoader(dataset=self.train_data,
                                batch_size=self.hparams.per_gpu_train_batch_size,
                                shuffle=True,
                                drop_last=True,
                                pin_memory=True,
                                num_workers=max(1, self.hparams.cpu_num // 2),
                                collate_fn=HotpotDataset.collate_fn)
        return dataloader

    def val_dataloader(self):
        dataloader = DataLoader(dataset=self.dev_data,
                                batch_size=self.hparams.eval_batch_size,
                                shuffle=False,
                                drop_last=False,
                                pin_memory=True,
                                num_workers=max(1, self.hparams.cpu_num // 2),
                                collate_fn=HotpotDataset.collate_fn)
        return dataloader

    def forward(self, batch):
        inputs = {'input_ids':      batch['context_idxs'],
                  'attention_mask': batch['context_mask'],
                  'token_type_ids': batch['segment_idxs'] if self.hparams.model_type in ['bert', 'xlnet', 'electra'] else None}  # XLM doesn't use segment_ids
        ####++++++++++++++++++++++++++++++++++++++
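        # Both branches read the token-level hidden states: outputs[0] and
        # outputs.last_hidden_state refer to the same tensor in transformers.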
        if self.hparams.model_type == 'electra':
            batch['context_encoding'] = self.encoder(**inputs).last_hidden_state
        else:
            batch['context_encoding'] = self.encoder(**inputs)[0]
        ####++++++++++++++++++++++++++++++++++++++
        batch['context_mask'] = batch['context_mask'].float().to(batch['context_encoding'].device)
        start, end, q_type, paras, sents, ents, yp1, yp2 = self.model(batch, return_yp=True)
        return start, end, q_type, paras, sents, ents, yp1, yp2

    def training_step(self, batch, batch_idx):
        start, end, q_type, paras, sents, ents, _, _ = self.forward(batch=batch)
        loss_list = compute_loss(self.hparams, batch, start, end, paras, sents, ents, q_type)
        ##################################################################################
        loss, loss_span, loss_type, loss_sup, loss_ent, loss_para = loss_list
        dict_for_progress_bar = {'span_loss': loss_span, 'type_loss': loss_type,
                                 'sent_loss': loss_sup, 'ent_loss': loss_ent,
                                 'para_loss': loss_para}
        dict_for_log = dict_for_progress_bar.copy()
        dict_for_log['step'] = batch_idx + 1
        ##################################################################################
        output = {'loss': loss, 'log': dict_for_log, 'progress_bar': dict_for_progress_bar}
        return output

    def validation_step(self, batch, batch_idx):
        start, end, q_type, paras, sents, ents, yp1, yp2 = self.forward(batch=batch)
        loss_list = compute_loss(self.hparams, batch, start, end, paras, sents, ents, q_type)
        loss, loss_span, loss_type, loss_sup, loss_ent, loss_para = loss_list
        dict_for_log = {'span_loss': loss_span, 'type_loss': loss_type,
                                 'sent_loss': loss_sup, 'ent_loss': loss_ent,
                                 'para_loss': loss_para,
                        'step': batch_idx + 1}
        #######################################################################
        type_prob = F.softmax(q_type, dim=1).data.cpu().numpy()
        answer_dict_, answer_type_dict_, answer_type_prob_dict_ = convert_to_tokens(self.dev_example_dict,
                                                                                    self.dev_feature_dict,
                                                                                    batch['ids'],
                                                                                    yp1.data.cpu().numpy().tolist(),
                                                                                    yp2.data.cpu().numpy().tolist(),
                                                                                    type_prob)
        predict_support_np = torch.sigmoid(sents[:, :, 1]).data.cpu().numpy()
        valid_dict = {'answer': answer_dict_, 'ans_type': answer_type_dict_, 'ids': batch['ids'],
                      'ans_type_pro': answer_type_prob_dict_, 'supp_np': predict_support_np}
        #######################################################################
        output = {'valid_loss': loss, 'log': dict_for_log, 'valid_dict_output': valid_dict}
        # output = {'valid_dict_output': valid_dict}
        return output

    def validation_epoch_end(self, validation_step_outputs):
        avg_loss = torch.stack([x['valid_loss'] for x in validation_step_outputs]).mean()
        # print(avg_loss, type(avg_loss), avg_loss.device)
        # self.log('valid_loss', avg_loss, on_epoch=True, prog_bar=True, sync_dist=True)
        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        answer_dict = {}
        answer_type_dict = {}
        answer_type_prob_dict = {}

        thresholds = np.arange(0.1, 1.0, 0.02)
        N_thresh = len(thresholds)
        total_sp_dict = [{} for _ in range(N_thresh)]
        total_record_num = 0
        #++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        valid_dict_outputs = [x['valid_dict_output'] for x in validation_step_outputs]
        for batch_idx, valid_dict in enumerate(valid_dict_outputs):
            answer_dict_, answer_type_dict_, answer_type_prob_dict_ = valid_dict['answer'], valid_dict['ans_type'], valid_dict['ans_type_pro']
            answer_type_dict.update(answer_type_dict_)
            answer_type_prob_dict.update(answer_type_prob_dict_)
            answer_dict.update(answer_dict_)

            predict_support_np = valid_dict['supp_np']
            batch_ids = valid_dict['ids']
            ###
            total_record_num = total_record_num + predict_support_np.shape[0]
            ###
            for i in range(predict_support_np.shape[0]):
                cur_sp_pred = [[] for _ in range(N_thresh)]
                cur_id = batch_ids[i]

                for j in range(predict_support_np.shape[1]):
                    if j >= len(self.dev_example_dict[cur_id].sent_names):
                        break
                    for thresh_i in range(N_thresh):
                        if predict_support_np[i, j] > thresholds[thresh_i]:
                            cur_sp_pred[thresh_i].append(self.dev_example_dict[cur_id].sent_names[j])

                for thresh_i in range(N_thresh):
                    if cur_id not in total_sp_dict[thresh_i]:
                        total_sp_dict[thresh_i][cur_id] = []
                    total_sp_dict[thresh_i][cur_id].extend(cur_sp_pred[thresh_i])

        def choose_best_threshold(ans_dict, pred_file):
            best_joint_f1 = 0
            best_metrics = None
            best_threshold = 0
            #################
            metric_dict = {}
            #################
            for thresh_i in range(N_thresh):
                prediction = {'answer': ans_dict,
                              'sp': total_sp_dict[thresh_i],
                              'type': answer_type_dict,
                              'type_prob': answer_type_prob_dict}
                tmp_file = os.path.join(os.path.dirname(pred_file), 'tmp_{}.json'.format(self.trainer.root_gpu))
                with open(tmp_file, 'w') as f:
                    json.dump(prediction, f)
                metrics = hotpot_eval(tmp_file, self.hparams.dev_gold_file)
                if metrics['joint_f1'] >= best_joint_f1:
                    best_joint_f1 = metrics['joint_f1']
                    best_threshold = thresholds[thresh_i]
                    best_metrics = metrics
                    shutil.move(tmp_file, pred_file)
                #######
                metric_dict[thresh_i] = (metrics['em'], metrics['f1'], metrics['sp_em'], metrics['sp_f1'], metrics['joint_em'], metrics['joint_f1'])
                #######
            return best_metrics, best_threshold, metric_dict

        output_pred_file = os.path.join(self.hparams.exp_name, f'pred.epoch_{self.current_epoch + 1}.gpu_{self.trainer.root_gpu}.json')
        output_eval_file = os.path.join(self.hparams.exp_name, f'eval.epoch_{self.current_epoch + 1}.gpu_{self.trainer.root_gpu}.txt')
        ####+++++
        best_metrics, best_threshold, metric_dict = choose_best_threshold(answer_dict, output_pred_file)
        ####++++++
        logging.info('Leaderboard evaluation completed over {} records with threshold = {:.4f}'.format(total_record_num, best_threshold))
        log_metrics(mode='Evaluation epoch {} gpu {}'.format(self.current_epoch, self.trainer.root_gpu), metrics=best_metrics)
        logging.info('*' * 75)
        ####++++++
        for key, value in metric_dict.items():
            str_value = ['{:.4f}'.format(_) for _ in value]
            logging.info('threshold {:.4f}: \t metrics: {}'.format(thresholds[key], str_value))
        ####++++++
        json.dump(best_metrics, open(output_eval_file, 'w'))
        #############################################################################
        # self.log('valid_loss', avg_loss, 'joint_f1', best_metrics['joint_f1'], on_epoch=True, prog_bar=True, sync_dist=True)
        joint_f1 = torch.Tensor([best_metrics['joint_f1']])[0].to(avg_loss.device)
        self.log('joint_f1', joint_f1, on_epoch=True, prog_bar=True, sync_dist=True)
        #############################################################################
        return best_metrics, best_threshold

    def configure_optimizers(self):
        # "Prepare optimizer and schedule (linear warmup and decay)"
        if self.hparams.optimizer == 'Adam':
            if self.hparams.learning_rate_schema == 'fixed':
                return self.fixed_learning_rate_optimizers()
            elif self.hparams.learning_rate_schema == 'layer_decay':
                return self.layer_wise_learning_rate_optimizer()
            else:
                raise ValueError('Wrong lr setting method = {}'.format(self.hparams.learning_rate_schema))
        else:
            return self.rec_adam_learning_optimizer()

    def fixed_learning_rate_optimizers(self):
        "Prepare optimizer and schedule (linear warmup and decay)"
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.named_parameters() if
                           (p.requires_grad) and (not any(nd in n for nd in no_decay))],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in self.named_parameters() if
                           (p.requires_grad) and (any(nd in n for nd in no_decay))],
                "weight_decay": 0.0,
            }
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
        # scheduler = get_linear_schedule_with_warmup(
        #     optimizer, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
        # )
        if self.hparams.lr_scheduler == 'linear':
            scheduler = get_linear_schedule_with_warmup(optimizer,
                                                        num_warmup_steps=self.hparams.warmup_steps,
                                                        num_training_steps=self.total_steps)
        elif self.hparams.lr_scheduler == 'cosine':
            scheduler = get_cosine_schedule_with_warmup(optimizer=optimizer,
                                                        num_warmup_steps=self.hparams.warmup_steps,
                                                        num_training_steps=self.total_steps)
        elif self.hparams.lr_scheduler == 'cosine_restart':
            scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(optimizer=optimizer,
                                                                           num_warmup_steps=self.hparams.warmup_steps,
                                                                           num_training_steps=self.total_steps)
        else:
            raise ValueError('{} is not supported'.format(self.hparams.lr_scheduler))
        scheduler = {
            'scheduler': scheduler,
            'interval': 'step',
            'frequency': 1
        }
        return [optimizer], [scheduler]

    def layer_wise_learning_rate_optimizer(self):
        "Prepare optimizer and schedule (linear warmup and decay)"
        encoder_layer_number_dict = {'roberta-large': 24, 'albert-xxlarge-v2': 1}
        assert self.hparams.encoder_name_or_path in encoder_layer_number_dict
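        # Grouping sketch: e.g. 24 layers in 24 groups yields one transformer
        # layer per group, with the embedding table folded into the first
        # group; any remainder layers form one extra trailing group.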

        def achieve_module_groups(encoder, number_of_layer, number_of_groups):
            layer_num_each_group = number_of_layer // number_of_groups
            number_of_divided_groups = number_of_groups + 1 if number_of_layer % number_of_groups > 0 else number_of_groups
            groups = []
            groups.append([encoder.embeddings, *encoder.encoder.layer[:layer_num_each_group]])
            for group_id in range(1, number_of_divided_groups):
                groups.append(
                    [*encoder.encoder.layer[(group_id * layer_num_each_group):((group_id + 1) * layer_num_each_group)]])
            return groups, number_of_divided_groups

        if self.hparams.encoder_name_or_path == 'roberta-large':
            encoder_layer_number = encoder_layer_number_dict[self.hparams.encoder_name_or_path]
            encoder_group_number = encoder_layer_number
            module_groups, encoder_group_number = achieve_module_groups(encoder=self.encoder,
                                                                        number_of_layer=encoder_layer_number,
                                                                        number_of_groups=encoder_group_number)
            module_groups.append([self.model])
            assert len(module_groups) == encoder_group_number + 1
        elif self.hparams.encoder_name_or_path == 'albert-xxlarge-v2':
            module_groups = []
            module_groups.append([self.encoder])
            module_groups.append([self.model])
            assert len(module_groups) == 2
        else:
            raise ValueError('Not supported {}'.format(self.hparams.encoder_name_or_path))

        def achieve_parameter_groups(module_group, weight_decay, lr):
            named_parameters = []
            no_decay = ["bias", "LayerNorm.weight"]
            for module in module_group:
                named_parameters += module.named_parameters()
            grouped_parameters = [
                {
                    "params": [p for n, p in named_parameters if
                               (p.requires_grad) and (not any(nd in n for nd in no_decay))],
                    "weight_decay": weight_decay, 'lr': lr
                },
                {
                    "params": [p for n, p in named_parameters if
                               (p.requires_grad) and (any(nd in n for nd in no_decay))],
                    "weight_decay": 0.0, 'lr': lr
                }
            ]
            return grouped_parameters

        optimizer_grouped_parameters = []
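        # Discriminative fine-tuning: each successive group gets a 10x larger
        # learning rate, so the task-specific model on top adapts faster than
        # the early encoder layers.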
        for idx, module_group in enumerate(module_groups):
            lr = self.hparams.learning_rate * (10.0 ** idx)
            logging.info('group {} lr = {}'.format(idx, lr))
            grouped_parameters = achieve_parameter_groups(module_group=module_group,
                                                          weight_decay=self.hparams.weight_decay,
                                                          lr=lr)
            optimizer_grouped_parameters += grouped_parameters

        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate,
                          eps=self.hparams.adam_epsilon)
        # scheduler = get_linear_schedule_with_warmup(
        #     optimizer, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
        # )
        if self.hparams.lr_scheduler == 'linear':
            scheduler = get_linear_schedule_with_warmup(optimizer,
                                                        num_warmup_steps=self.hparams.warmup_steps,
                                                        num_training_steps=self.total_steps)
        elif self.hparams.lr_scheduler == 'cosine':
            scheduler = get_cosine_schedule_with_warmup(optimizer=optimizer,
                                                        num_warmup_steps=self.hparams.warmup_steps,
                                                        num_training_steps=self.total_steps)
        elif self.hparams.lr_scheduler == 'cosine_restart':
            scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(optimizer=optimizer,
                                                                           num_warmup_steps=self.hparams.warmup_steps,
                                                                           num_training_steps=self.total_steps)
        else:
            raise ValueError('{} is not supported'.format(self.hparams.lr_scheduler))
        scheduler = {
            'scheduler': scheduler,
            'interval': 'step',
            'frequency': 1
        }
        return [optimizer], [scheduler]

    def rec_adam_learning_optimizer(self):
        no_decay = ["bias", "LayerNorm.weight"]
        new_model = self.model
        args = self.hparams
        pretrained_model = self.encoder
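        # Four RecAdam groups: (decay / no-decay) x (name matches the encoder
        # type / not). Matching parameters are annealed back toward the paired
        # pretrained weights; the rest use anneal_w = 0.0 (plain AdamW-style).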
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in new_model.named_parameters() if
                           not any(nd in n for nd in no_decay) and args.model_type in n],
                "weight_decay": args.weight_decay,
                "anneal_w": args.recadam_anneal_w,
                "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if
                                    not any(nd in p_n for nd in no_decay) and args.model_type in p_n]
            },
            {
                "params": [p for n, p in new_model.named_parameters() if
                           not any(nd in n for nd in no_decay) and args.model_type not in n],
                "weight_decay": args.weight_decay,
                "anneal_w": 0.0,
                "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if
                                    not any(nd in p_n for nd in no_decay) and args.model_type not in p_n]
            },
            {
                "params": [p for n, p in new_model.named_parameters() if
                           any(nd in n for nd in no_decay) and args.model_type in n],
                "weight_decay": 0.0,
                "anneal_w": args.recadam_anneal_w,
                "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if
                                    any(nd in p_n for nd in no_decay) and args.model_type in p_n]
            },
            {
                "params": [p for n, p in new_model.named_parameters() if
                           any(nd in n for nd in no_decay) and args.model_type not in n],
                "weight_decay": 0.0,
                "anneal_w": 0.0,
                "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if
                                    any(nd in p_n for nd in no_decay) and args.model_type not in p_n]
            }
        ]
        optimizer = RecAdam(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon,
                            anneal_fun=args.recadam_anneal_fun, anneal_k=args.recadam_anneal_k,
                            anneal_t0=args.recadam_anneal_t0, pretrain_cof=args.recadam_pretrain_cof)
        # scheduler = get_linear_schedule_with_warmup(
        #     optimizer, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
        # )
        if self.hparams.lr_scheduler == 'linear':
            scheduler = get_linear_schedule_with_warmup(optimizer,
                                                        num_warmup_steps=self.hparams.warmup_steps,
                                                        num_training_steps=self.total_steps)
        elif self.hparams.lr_scheduler == 'cosine':
            scheduler = get_cosine_schedule_with_warmup(optimizer=optimizer,
                                                        num_warmup_steps=self.hparams.warmup_steps,
                                                        num_training_steps=self.total_steps)
        elif self.hparams.lr_scheduler == 'cosine_restart':
            scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(optimizer=optimizer,
                                                                           num_warmup_steps=self.hparams.warmup_steps,
                                                                           num_training_steps=self.total_steps)
        else:
            raise ValueError('{} is not supported'.format(self.hparams.lr_scheduler))
        scheduler = {
            'scheduler': scheduler,
            'interval': 'step',
            'frequency': 1
        }
        return [optimizer], [scheduler]
