Example #1
    def log_results(results, dataset_name, steps, logging=True, print=True):
        # Print a header
        header = "\n\n"
        header += BUSH_SEP + "\n"
        header += "***************************************************\n"
        header += f"***** EVALUATION | {dataset_name.upper()} SET | AFTER {steps} BATCHES *****\n"
        header += "***************************************************\n"
        header += BUSH_SEP + "\n"
        logger.info(header)

        for head_num, head in enumerate(results):
            logger.info("\n _________ {} _________".format(head['task_name']))
            for metric_name, metric_val in head.items():
                # log with ML framework (e.g. Mlflow)
                if logging:
                    if isinstance(metric_val, numbers.Number):
                        MlLogger.log_metrics(
                            metrics={
                                f"{dataset_name}_{metric_name}_{head['task_name']}": metric_val
                            },
                            step=steps,
                        )
                # print via standard python logger
                if print:
                    if metric_name == "report":
                        if isinstance(metric_val, str) and len(metric_val) > 8000:
                            metric_val = metric_val[:7500] + "\n ............................. \n" + metric_val[-500:]
                        logger.info("{}: \n {}".format(metric_name, metric_val))
                    else:
                        logger.info("{}: {}".format(metric_name, metric_val))
Example #2
File: train.py Project: skiran252/FARM
    def backward_propagate(self, loss, step):
        loss = self.adjust_loss(loss)
        if self.global_step % self.log_loss_every == 0 and self.local_rank in [-1, 0]:
            MlLogger.log_metrics(
                {"Train_loss_total": float(loss.detach().cpu().numpy())},
                step=self.global_step,
            )
            if self.log_learning_rate:
                MlLogger.log_metrics(
                    {"learning_rate": self.lr_schedule.get_last_lr()[0]},
                    step=self.global_step)
        if self.use_amp:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if step % self.grad_acc_steps == 0:
            if self.max_grad_norm is not None:
                if self.use_amp:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(self.optimizer), self.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.max_grad_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()
            if self.lr_schedule:
                self.lr_schedule.step()
        return loss
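A hedged sketch of how a training loop might drive backward_propagate together with gradient accumulation; the data_loader, model and trainer wiring here are illustrative assumptions, not FARM's actual Trainer.train() code:

# Illustrative driver loop (hypothetical names: data_loader, trainer).
for step, batch in enumerate(data_loader, start=1):
    logits = trainer.model.forward(**batch)
    per_sample_loss = trainer.model.logits_to_loss(logits, global_step=trainer.global_step, **batch)
    trainer.backward_propagate(per_sample_loss, step)  # logs loss, backprops, steps optimizer every grad_acc_steps
    trainer.global_step += 1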
Example #3
File: train.py Project: wwmmqq/FARM
    def backward_propagate(self, loss, step):
        loss = self.adjust_loss(loss)
        if self.global_step % 10 == 1:
            MlLogger.log_metrics(
                {"Train_loss_total": float(loss.detach().cpu().numpy())},
                step=self.global_step,
            )
        if self.use_amp:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if self.log_learning_rate:
            MlLogger.log_metrics(
                {"learning_rate": self.lr_schedule.get_lr()[0]},
                step=self.global_step)

        if step % self.grad_acc_steps == 0:
            # TODO We might wanna add gradient clipping here
            self.optimizer.step()
            self.optimizer.zero_grad()
            if self.lr_schedule:
                self.lr_schedule.step()
        return loss
Example #4
    def logits_to_loss(self, logits, global_step=None, **kwargs):
        """
        Get losses from all prediction heads & reduce to single loss *per sample*.

        :param logits: logits, can vary in shape and type, depending on task
        :type logits: object
        :param global_step: number of current training step
        :type global_step: int
        :param kwargs: placeholder for passing generic parameters.
                       Note: Contains the batch (as dict of tensors), when called from Trainer.train().
        :type kwargs: object
        :return loss: torch.tensor that is the per sample loss (len: batch_size)
        """
        all_losses = self.logits_to_loss_per_head(logits, **kwargs)
        # This aggregates the loss per sample across multiple prediction heads
        # Default is sum(), but you can configure any fn that takes [Tensor, Tensor ...] and returns [Tensor]

        # Log the loss per task
        for i, per_sample_loss in enumerate(all_losses):
            task_name = self.prediction_heads[i].task_name
            task_loss = per_sample_loss.mean()
            MlLogger.log_metrics(
                {
                    f"train_loss_{task_name}":
                    float(task_loss.detach().cpu().numpy())
                },
                step=global_step)

        loss = self.loss_aggregation_fn(all_losses,
                                        global_step=global_step,
                                        batch=kwargs)
        return loss
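The aggregation mentioned in the comment above (sum by default) can be replaced by any function with the same signature; a minimal sketch of a weighted variant, where the weights and the attachment to the model are illustrative assumptions:

def weighted_loss_aggregation(all_losses, global_step=None, batch=None):
    # all_losses: list of per-sample loss tensors, one entry per prediction head
    weights = [2.0] + [1.0] * (len(all_losses) - 1)  # illustrative: weight the first head higher
    return sum(w * head_loss for w, head_loss in zip(weights, all_losses))

# model.loss_aggregation_fn = weighted_loss_aggregation  # assumed attachment point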
Example #5
def log_results(results, dataset_name, steps, logging=True, print=True):
    logger.info(
        "\n***** Evaluation Results on {} data after {} steps *****".format(
            dataset_name, steps
        )
    )
    for head_num, head in enumerate(results):
        logger.info("\n _________ Prediction Head {} _________".format(head_num))
        for metric_name, metric_val in head.items():
            # log with ML framework (e.g. Mlflow)
            if logging:
                if isinstance(metric_val, numbers.Number):
                    MlLogger.log_metrics(
                        metrics={
                            f"{dataset_name}_{metric_name}_head{head_num}": metric_val
                        },
                        step=steps,
                    )
            # print via standard python logger
            if print:
                if metric_name == "report":
                    if isinstance(metric_val, str) and len(metric_val) > 8000:
                        metric_val = metric_val[:7500] + "\n ............................. \n" + metric_val[-500:]
                    logger.info("{}: \n {}".format(metric_name, metric_val))
                else:
                    logger.info("{}: {}".format(metric_name, metric_val))
Example #6
File: eval.py Project: zzzbit/COVID-QA
def eval_question_similarity(y_true,
                             y_pred,
                             lang,
                             model_name,
                             params,
                             user=None,
                             log_to_mlflow=True,
                             run_name="default"):
    # basic metrics
    mean_diff = np.mean(np.abs(y_true - y_pred))
    roc_auc = roc_auc_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred.round(0))
    metrics = {"roc_auc": roc_auc, "mean_abs_diff": mean_diff, "f1_score": f1}
    print(metrics)

    # log experiment results to MLFlow (visit https://public-mlflow.deepset.ai/)
    if log_to_mlflow:
        params["lang"] = lang
        params["model_name"] = model_name
        if user:
            params["user"] = user

        ml_logger = MLFlowLogger(
            tracking_uri="https://public-mlflow.deepset.ai/")
        ml_logger.init_experiment(experiment_name="COVID-question-sim",
                                  run_name=run_name)
        ml_logger.log_params(params)
        ml_logger.log_metrics(metrics, step=0)
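A hedged usage sketch with toy arrays and MLflow logging disabled; the scores in y_pred are assumed to be similarity probabilities in [0, 1]:

import numpy as np

y_true = np.array([1, 0, 1, 1, 0])
y_pred = np.array([0.9, 0.2, 0.7, 0.4, 0.1])
eval_question_similarity(y_true, y_pred,
                         lang="en",
                         model_name="toy-model",
                         params={"epochs": 1},
                         log_to_mlflow=False)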
Example #7
def log_results(results, dataset_name, steps, logging=True, print=True, save_path=None, num_fold=None):
    logger = get_logger(__name__)

    # Print a header
    header = "\n\n"
    header += BUSH_SEP + "\n"
    header += "***************************************************\n"
    if num_fold:
        header += f"***** EVALUATION | FOLD: {num_fold} | {dataset_name.upper()} SET | AFTER {steps} BATCHES *****\n"
    else:
        header += f"***** EVALUATION | {dataset_name.upper()} SET | AFTER {steps} BATCHES *****\n"
    header += "***************************************************\n"
    header += BUSH_SEP + "\n"
    logger.info(header)

    save_log = header

    for head_num, head in enumerate(results):
        logger.info("\n _________ {} _________".format(head['task_name']))
        for metric_name, metric_val in head.items():
            metric_log = None

            # log with ML framework (e.g. Mlflow)
            if logging:
                if metric_name not in ["preds", "probs", "labels"] and not metric_name.startswith("_"):
                    if isinstance(metric_val, numbers.Number):
                        MlLogger.log_metrics(
                            metrics={
                                f"{dataset_name}_{metric_name}_{head['task_name']}": metric_val
                            },
                            step=steps,
                        )

            # print via standard python logger
            if print:
                if metric_name == "report":
                    if isinstance(metric_val, str) and len(metric_val) > 8000:
                        metric_val = metric_val[:7500] + "\n ............................. \n" + metric_val[-500:]
                    metric_log = "{}: \n {}".format(metric_name, metric_val)
                    logger.info(metric_log)
                else:
                    if metric_name not in ["preds", "probs", "labels"] and not metric_name.startswith("_"):
                        metric_log = "{}: {};".format(metric_name, metric_val)
                        logger.info(metric_log)

            if save_path and metric_log:
                save_log += "\n" + metric_log

    if save_path:
        with open(save_path, "w", encoding="utf-8") as log_file:
            log_file.write(save_log)
Example #8
    def dataset_from_file(self, file, log_time=True):
        """
        Contains all the functionality to turn a data file into a PyTorch Dataset and a
        list of tensor names. This is used for training and evaluation.

        :param file: Name of the file containing the data.
        :type file: str
        :param log_time: If True, log the time spent in each preprocessing stage to MlLogger.
        :type log_time: bool
        :return: a PyTorch dataset and a list of tensor names.
        """
        if log_time:
            a = time.time()
            self._init_baskets_from_file(file)
            b = time.time()
            MlLogger.log_metrics(metrics={"t_from_file": (b - a) / 60}, step=0)
            self._init_samples_in_baskets()
            c = time.time()
            MlLogger.log_metrics(metrics={"t_init_samples": (c - b) / 60}, step=0)
            self._featurize_samples()
            d = time.time()
            MlLogger.log_metrics(metrics={"t_featurize_samples": (d - c) / 60}, step=0)
            self._log_samples(3)
        else:
            self._init_baskets_from_file(file)
            self._init_samples_in_baskets()
            self._featurize_samples()
            self._log_samples(3)
        dataset, tensor_names = self._create_dataset()
        return dataset, tensor_names
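The returned dataset is a regular PyTorch dataset, so it can be fed to a standard DataLoader; a minimal sketch, where the processor instance and file name are assumptions:

from torch.utils.data import DataLoader

dataset, tensor_names = processor.dataset_from_file("train.tsv")  # hypothetical processor and file
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch in loader:
    features = dict(zip(tensor_names, batch))  # map each stacked tensor back to its name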
Example #9
File: train.py Project: Wkryst/FARM
    def backward_propagate(self, loss, step):
        loss = self.adjust_loss(loss)
        if self.global_step % 10 == 1:
            MlLogger.log_metrics(
                {"Train_loss_total": float(loss.detach().cpu().numpy())},
                step=self.global_step,
            )
        if self.fp16:
            self.optimizer.backward(loss)
        else:
            loss.backward()

        if (step + 1) % self.grad_acc_steps == 0:
            if self.fp16:
                # modify learning rate with special warm up BERT uses
                # if args.fp16 is False, BertAdam is used that handles this automatically
                lr_this_step = self.learning_rate * self.warmup_linear.get_lr(
                    self.global_step, self.warmup_proportion)
                for param_group in self.optimizer.param_groups:
                    param_group["lr"] = lr_this_step
                # MlLogger.write_metrics({"learning_rate": lr_this_step}, step=self.global_step)
            self.optimizer.step()
            self.optimizer.zero_grad()
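warmup_linear.get_lr above returns a multiplier for the base learning rate. A simplified reconstruction of the linear warmup-then-decay shape it implements, expressed over training progress (the real schedule object converts the raw step into a progress fraction internally, so this is an approximation, not the library code):

def warmup_linear_multiplier(progress, warmup=0.1):
    # progress: fraction of total training steps completed, in [0, 1]
    if progress < warmup:
        return progress / warmup       # linear warmup
    return max(0.0, 1.0 - progress)    # linear decay afterwards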
Example #10
def question_answering_crossvalidation():
    ##########################
    ########## Logging
    ##########################
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)
    # reduce verbosity from transformers library
    logging.getLogger('transformers').setLevel(logging.WARNING)

    #ml_logger = MLFlowLogger(tracking_uri="https://public-mlflow.deepset.ai/")
    # for local logging instead:
    ml_logger = MLFlowLogger(tracking_uri="logs")
    #ml_logger.init_experiment(experiment_name="QA_X-Validation", run_name="Squad_Roberta_Base")

    ##########################
    ########## Settings
    ##########################
    save_per_fold_results = False  # unsupported for now
    set_all_seeds(seed=42)
    device, n_gpu = initialize_device_settings(use_cuda=True)

    lang_model = "deepset/roberta-base-squad2"
    do_lower_case = False

    n_epochs = 2
    batch_size = 80
    learning_rate = 3e-5

    data_dir = Path("../data/covidqa")
    filename = "COVID-QA.json"
    xval_folds = 5
    dev_split = 0
    evaluate_every = 0
    no_ans_boost = -100  # use large negative values to disable giving "no answer" option
    accuracy_at = 3  # accuracy at n is useful for answers inside long documents
    use_amp = None

    ##########################
    ########## k fold Cross validation
    ##########################

    # 1. Create a tokenizer
    tokenizer = Tokenizer.load(pretrained_model_name_or_path=lang_model,
                               do_lower_case=do_lower_case)

    # 2. Create a DataProcessor that handles all the conversion from raw text into a pytorch Dataset
    processor = SquadProcessor(
        tokenizer=tokenizer,
        max_seq_len=384,
        label_list=["start_token", "end_token"],
        metric="squad",
        train_filename=filename,
        dev_filename=None,
        dev_split=dev_split,
        test_filename=None,
        data_dir=data_dir,
        doc_stride=192,
    )

    # 3. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them and calculates a few descriptive statistics of our datasets
    data_silo = DataSilo(processor=processor, batch_size=batch_size)

    # Load one silo for each fold in our cross-validation
    silos = DataSiloForCrossVal.make(data_silo, n_splits=xval_folds)

    # the following steps should be run for each of the folds of the cross validation, so we put them
    # into a function
    def train_on_split(silo_to_use, n_fold):
        logger.info(
            f"############ Crossvalidation: Fold {n_fold} ############")

        # fine-tune pre-trained question-answering model
        model = AdaptiveModel.convert_from_transformers(
            lang_model, device=device, task_type="question_answering")
        model.connect_heads_with_processor(data_silo.processor.tasks,
                                           require_labels=True)
        # If positive, this will boost "No Answer" as prediction.
        # If negative, this will prevent the model from giving "No Answer" as prediction.
        model.prediction_heads[0].no_ans_boost = no_ans_boost
        # Number of predictions the model will make per Question.
        # The multiple predictions are used for evaluating top n recall.
        model.prediction_heads[0].n_best = accuracy_at

        # # or train question-answering models from scratch
        # # Create an AdaptiveModel
        # # a) which consists of a pretrained language model as a basis
        # language_model = LanguageModel.load(lang_model)
        # # b) and a prediction head on top that is suited for our task => Question-answering
        # prediction_head = QuestionAnsweringHead(no_ans_boost=no_ans_boost, n_best=accuracy_at)
        # model = AdaptiveModel(
        #    language_model=language_model,
        #    prediction_heads=[prediction_head],
        #    embeds_dropout_prob=0.1,
        #    lm_output_types=["per_token"],
        #    device=device,)

        # Create an optimizer
        model, optimizer, lr_schedule = initialize_optimizer(
            model=model,
            learning_rate=learning_rate,
            device=device,
            n_batches=len(silo_to_use.loaders["train"]),
            n_epochs=n_epochs,
            use_amp=use_amp)

        # Feed everything to the Trainer, which takes care of growing our model into a powerful plant and evaluates it from time to time

        trainer = Trainer(model=model,
                          optimizer=optimizer,
                          data_silo=silo_to_use,
                          epochs=n_epochs,
                          n_gpu=n_gpu,
                          lr_schedule=lr_schedule,
                          evaluate_every=evaluate_every,
                          device=device,
                          evaluator_test=False)

        # train it
        trainer.train()

        return trainer.model

    # for each fold, run the whole training, then evaluate the model on the test set of each fold
    # Remember all the results for overall metrics over all predictions of all folds and for averaging
    all_results = []
    all_preds = []
    all_labels = []
    all_f1 = []
    all_em = []
    all_topnaccuracy = []

    for num_fold, silo in enumerate(silos):
        model = train_on_split(silo, num_fold)

        # do eval on test set here (and not in Trainer),
        # so that we can easily store the actual preds and labels for a "global" eval across all folds.
        evaluator_test = Evaluator(data_loader=silo.get_data_loader("test"),
                                   tasks=silo.processor.tasks,
                                   device=device)
        result = evaluator_test.eval(model, return_preds_and_labels=True)
        evaluator_test.log_results(result,
                                   "Test",
                                   logging=False,
                                   steps=len(silo.get_data_loader("test")),
                                   num_fold=num_fold)

        all_results.append(result)
        all_preds.extend(result[0].get("preds"))
        all_labels.extend(result[0].get("labels"))
        all_f1.append(result[0]["f1"])
        all_em.append(result[0]["EM"])
        all_topnaccuracy.append(result[0]["top_n_accuracy"])

        # empty cache to avoid memory leaks and CUDA OOM across multiple folds
        model.cpu()
        torch.cuda.empty_cache()

    # Save the per-fold results to json for a separate, more detailed analysis
    # TODO currently not supported - adjust to QAPred and QACandidate objects
    # if save_per_fold_results:
    #     def convert_numpy_dtype(obj):
    #         if type(obj).__module__ == "numpy":
    #             return obj.item()
    #
    #         raise TypeError("Unknown type:", type(obj))
    #
    #     with open("qa_xval.results.json", "wt") as fp:
    #          json.dump(all_results, fp, default=convert_numpy_dtype)

    # calculate overall metrics across all folds
    xval_score = squad(preds=all_preds, labels=all_labels)

    logger.info(f"Single EM-Scores:   {all_em}")
    logger.info(f"Single F1-Scores:   {all_f1}")
    logger.info(
        f"Single top_{accuracy_at}_accuracy Scores:   {all_topnaccuracy}")
    logger.info(f"XVAL EM:   {xval_score['EM']}")
    logger.info(f"XVAL f1:   {xval_score['f1']}")
    logger.info(
        f"XVAL top_{accuracy_at}_accuracy:   {xval_score['top_n_accuracy']}")
    ml_logger.log_metrics({"XVAL EM": xval_score["EM"]}, 0)
    ml_logger.log_metrics({"XVAL f1": xval_score["f1"]}, 0)
    ml_logger.log_metrics(
        {f"XVAL top_{accuracy_at}_accuracy": xval_score["top_n_accuracy"]}, 0)
Example #11
def doc_classification_crossvalidation():
    # the code for this function is partially taken from:
    # https://github.com/deepset-ai/FARM/blob/master/examples/doc_classification_multilabel.py and
    # https://github.com/deepset-ai/FARM/blob/master/examples/doc_classification_crossvalidation.py

    # for local logging:
    ml_logger = MLFlowLogger(tracking_uri="")
    ml_logger.init_experiment(experiment_name="covid-document-classification",
                              run_name=RUNNAME)

    # model settings
    xval_folds = FOLDS
    set_all_seeds(seed=42)
    device, n_gpu = initialize_device_settings(use_cuda=True)
    if RUNLOCAL:
        device = "cpu"
    n_epochs = NEPOCHS
    batch_size = BATCHSIZE
    evaluate_every = EVALEVERY
    lang_model = MODELTYPE
    do_lower_case = False

    # 1. Create a tokenizer
    tokenizer = Tokenizer.load(
        pretrained_model_name_or_path=lang_model,
        do_lower_case=do_lower_case)

    metric = "f1_macro"

    # 2. Create a DataProcessor that handles all the conversion from raw text into a pytorch Dataset
    # The processor wants to know the possible labels ...
    label_list = LABELS
    processor = TextClassificationProcessor(tokenizer=tokenizer,
                                            max_seq_len=MAXLEN,
                                            data_dir=DATADIR,
                                            train_filename=TRAIN,
                                            test_filename=TEST,
                                            dev_split=0.1,
                                            label_list=label_list,
                                            metric=metric,
                                            label_column_name="Categories",
                                            # confusing parameter name: it should be called multiCLASS
                                            # not multiLABEL
                                            multilabel=True
                                            )

    # 3. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them and calculates a few descriptive statistics of our datasets
    data_silo = DataSilo(
        processor=processor,
        batch_size=batch_size)

    # Load one silo for each fold in our cross-validation
    silos = DataSiloForCrossVal.make(data_silo, n_splits=xval_folds)

    # the following steps should be run for each of the folds of the cross validation, so we put them
    # into a function
    def train_on_split(silo_to_use, n_fold, save_dir, dev):
        # Create an AdaptiveModel
        # a) which consists of a pretrained language model as a basis
        language_model = LanguageModel.load(lang_model)
        # b) and a prediction head on top that is suited for our task => Text classification
        prediction_head = MultiLabelTextClassificationHead(
            # there is still an error with class weights ...
            # class_weights=data_silo.calculate_class_weights(task_name="text_classification"),
            num_labels=len(label_list))

        model = AdaptiveModel(
            language_model=language_model,
            prediction_heads=[prediction_head],
            embeds_dropout_prob=0.2,
            lm_output_types=["per_sequence"],
            device=dev)

        # Create an optimizer
        model, optimizer, lr_schedule = initialize_optimizer(
            model=model,
            learning_rate=0.5e-5,
            device=dev,
            n_batches=len(silo_to_use.loaders["train"]),
            n_epochs=n_epochs)

        # Feed everything to the Trainer, which takes care of growing our model into a powerful plant and evaluates it from time to time
        # Also create an EarlyStopping instance and pass it on to the trainer
        save_dir = Path(str(save_dir) + f"-{n_fold}")
        # unfortunately, early stopping is still not working
        earlystopping = EarlyStopping(
            metric="f1_macro", mode="max",
            save_dir=save_dir,  # where to save the best model
            patience=5 # number of evaluations to wait for improvement before terminating the training
        )

        trainer = Trainer(model=model, optimizer=optimizer,
                          data_silo=silo_to_use, epochs=n_epochs,
                          n_gpu=n_gpu, lr_schedule=lr_schedule,
                          evaluate_every=evaluate_every,
                          device=dev, evaluator_test=False,
                          # early_stopping=earlystopping
                          )
        # train it
        trainer.train()
        trainer.model.save(save_dir)
        return trainer.model

    # for each fold, run the whole training, earlystopping to get a model, then evaluate the model
    # on the test set of each fold
    # Remember all the results for overall metrics over all predictions of all folds and for averaging
    allresults = []
    all_preds = []
    all_labels = []
    bestfold = None
    bestf1_macro = -1
    save_dir = Path("saved_models/covid-classification-v1")

    for num_fold, silo in enumerate(silos):
        model = train_on_split(silo, num_fold, save_dir, device)

        # do eval on test set here (and not in Trainer),
        #  so that we can easily store the actual preds and labels for a "global" eval across all folds.
        evaluator_test = Evaluator(
            data_loader=silo.get_data_loader("test"),
            tasks=silo.processor.tasks,
            device=device,
        )
        result = evaluator_test.eval(model, return_preds_and_labels=True)

        os.makedirs(os.path.dirname(BESTMODEL + "/classification_report.txt"), exist_ok=True)
        with open(BESTMODEL + "/classification_report.txt", "a+") as file:
            file.write("Evaluation on withheld split for numfold no. {} \n".format(num_fold))
            file.write(result[0]["report"])
            file.write("\n\n")

        evaluator_test.log_results(result, "Test", steps=len(silo.get_data_loader("test")), num_fold=num_fold)

        allresults.append(result)
        all_preds.extend(result[0].get("preds"))
        all_labels.extend(result[0].get("labels"))

        # keep track of best fold
        f1_macro = result[0]["f1_macro"]
        if f1_macro > bestf1_macro:
            bestf1_macro = f1_macro
            bestfold = num_fold

    # Save the per-fold results to json for a separate, more detailed analysis
    with open("../data/predictions/covid-classification-xval.results.json", "wt") as fp:
        json.dump(allresults, fp, cls=NumpyArrayEncoder)

    # calculate overall f1 score across all folds
    xval_f1_macro = f1_score(all_labels, all_preds, average="macro")
    ml_logger.log_metrics({"f1 macro across all folds": xval_f1_macro}, step=None)

    # test performance
    evaluator_origtest = Evaluator(
        data_loader=data_silo.get_data_loader("test"),
        tasks=data_silo.processor.tasks,
        device=device
    )
    # restore model from the best fold
    lm_name = model.language_model.name
    save_dir = Path(f"saved_models/covid-classification-v1-{bestfold}")
    model = AdaptiveModel.load(save_dir, device, lm_name=lm_name)
    model.connect_heads_with_processor(data_silo.processor.tasks, require_labels=True)

    result = evaluator_origtest.eval(model)
    ml_logger.log_metrics({"f1 macro on final test set": result[0]["f1_macro"]}, step=None)

    with open(BESTMODEL + "/classification_report.txt", "a+") as file:
        file.write("Final result of the best model \n")
        file.write(result[0]["report"])
        file.write("\n\n")

    ml_logger.log_artifacts(BESTMODEL + "/")

    # save model for later use
    processor.save(BESTMODEL)
    model.save(BESTMODEL)
    return model
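json.dump above uses a NumpyArrayEncoder that is not defined in the snippet; a minimal sketch of such an encoder, assuming it only needs to handle numpy arrays and numpy scalars:

import json
import numpy as np

class NumpyArrayEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):   # numpy scalar types (e.g. np.float64)
            return obj.item()
        return super().default(obj)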
Example #12
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "Adam does not support sparse gradients, please consider SparseAdam instead"
                    )

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["next_m"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["next_v"] = torch.zeros_like(p.data)

                state["step"] += 1
                next_m, next_v = state["next_m"], state["next_v"]
                beta1, beta2 = group["b1"], group["b2"]

                # Add grad clipping
                if group["max_grad_norm"] > 0:
                    clip_grad_norm_(p, group["max_grad_norm"])

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
                next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                update = next_m / (next_v.sqrt() + group["e"])

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                if group["weight_decay"] > 0.0:
                    update += group["weight_decay"] * p.data

                lr_scheduled = group["lr"]
                lr_scheduled *= group["schedule"].get_lr(state["step"])

                update_with_lr = lr_scheduled * update
                p.data.add_(-update_with_lr)

                # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
                # No bias correction
                # bias_correction1 = 1 - beta1 ** state['step']
                # bias_correction2 = 1 - beta2 ** state['step']
        # Custom logging functionality
        if self.log_learning_rate:
            MlLogger.log_metrics({"learning_rate": lr_scheduled},
                                 step=state["step"])
            logger.info(f'step:{state["step"]}, lr:{lr_scheduled}')
        return loss