def main(config, results):

    model_config = tu.load_model_config(config)
    tokenizer = tu.load_tokenizer(config, model_config)

    ds = Dataset(config, tokenizer)
    label_names = ds.label_names
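    # the model config is reloaded here before the fold loop, presumably so
    # each fold starts from the original checkpoint configuration (compare
    # the vocabulary-mismatch note in a later example)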
    model_config = tu.load_model_config(config)

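    # cross-validation loop: each fold yields a train dataloader and a
    # (valid dataloader, valid dataframe) pair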
    for i, (train_dataloader, (valid_dataloader, valid_df)) in enumerate(
            ds.get_train_valid_dataloaders(include_valid_df=True)):
        print(f"------------------ BEGIN ITER {i} -----------------------")
        model = tu.load_model(config, model_config)
        model.to(config.device)
        util.set_seed(config)

        run_name = f"fold{i}"

        experiment = Experiment(config,
                                model,
                                tokenizer,
                                label_names=label_names,
                                run_name=run_name,
                                results=results)
        global_step, tr_loss = experiment.train(
            train_dataloader, valid_dataloader=valid_dataloader)

        results = experiment.results
        # experiment.evaluate('valid', valid_dataloader)
        print(f"================== DONE ITER {i} =======================\n\n")

    return results
def main():
    config = get_config()
    with config:
        config.logging_steps = 400
        config.train_epochs = 2
        config.lr = 4e-5
        # config.lr = 1e-4
        config.model_type = 'roberta'
        config.model_path = util.models_path('StackOBERTflow-comments-small-v1')
        # config.train_head_only = True

    ds = TDDataset(config, binary=True)

    tokenizer = tu.load_tokenizer(config)
    model_cls = tu.get_model_cls(config)

    train_dataloader = ds.get_complete_train_dataloader(tokenizer)
    model = tu.load_model(config)
    model.to(config.device)
    util.set_seed(config)

    experiment = Experiment(config, model, tokenizer)
    global_step, tr_loss = experiment.train(train_dataloader)

    experiment.save_model(util.models_path('satd_complete_binary'))
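    # note: this checkpoint is reloaded by the prediction script further down,
    # which sets config.model_path = util.models_path('satd_complete_binary')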
Example No. 3
def main(config, results):

    model_config = tu.load_model_config(config)
    tokenizer = tu.load_tokenizer(config, model_config)

    ds = Dataset(config, tokenizer)
    label_names = ds.label_names

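    # resolve the dataloader factory by name: config.dataset selects
    # ds.get_<dataset>_train_valid_test_dataloaders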
    train_dataloader, valid_dataloader, test_dataloader = getattr(
        ds, f"get_{config.dataset}_train_valid_test_dataloaders")()

    model = tu.load_model(config, model_config)
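    # the tokenizer may include added custom tokens, so grow the model's
    # embedding matrix to match the tokenizer vocabulary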
    model.resize_token_embeddings(len(tokenizer))
    model.to(config.device)
    util.set_seed(config)

    experiment = Experiment(config,
                            model,
                            tokenizer,
                            label_names=label_names,
                            results=results)
    global_step, tr_loss = experiment.train(train_dataloader,
                                            valid_dataloader=valid_dataloader,
                                            test_dataloader=test_dataloader)

    results = experiment.results
    return results
Example No. 4
def main(config, results):

    model_config = tu.load_model_config(config)
    tokenizer = tu.load_tokenizer(config, model_config)

    ds = Dataset(config, tokenizer)
    label_names = ds.label_names

    dataloaders = getattr(ds, f"get_{config.dataset}_train_valid_dataloaders")(include_valid_df=True)

    for i, (train_dataloader, (valid_dataloader, valid_df)) in enumerate(dataloaders):
        print(f"------------------ BEGIN ITER {i} -----------------------")
        # reload the original model config each fold to avoid a vocabulary
        # size mismatch caused by the custom tokens added to the tokenizer
        model_config = tu.load_model_config(config)
        model = tu.load_model(config, model_config)
        model.resize_token_embeddings(len(tokenizer))
        model.to(config.device)
        util.set_seed(config)

        run_name = f"fold{i}"

        experiment = Experiment(config, model, tokenizer, label_names=label_names, run_name=run_name, results=results)
        global_step, tr_loss = experiment.train(
            train_dataloader, valid_dataloader=valid_dataloader)

        results = experiment.results
        # experiment.evaluate('valid', valid_dataloader)
        print(f"================== DONE ITER {i} =======================\n\n")

    return results
def main(config, results):
    model_config = tu.load_model_config(config)
    tokenizer = tu.load_tokenizer(config, model_config)

    ds = Dataset(config, tokenizer)
    label_names = ds.label_names

    train_dataloader, valid_dataloader = ds.get_train_valid_dataloaders()

    model = tu.load_model(config, model_config)
    model.to(config.device)
    util.set_seed(config)

    experiment = Experiment(config,
                            model,
                            tokenizer,
                            label_names=label_names,
                            results=results)
    global_step, tr_loss = experiment.train(train_dataloader,
                                            valid_dataloader=valid_dataloader)
    results = experiment.results

    experiment.save_model(util.models_path('comment_code_shuffle'))

    return results
Example No. 6
    def __init__(self,
                 config,
                 model,
                 tokenizer,
                 total_samples=None,
                 label_names=None,
                 results=None,
                 run_name=None):
        self.config = config
        self.model = model
        self.tokenizer = tokenizer

        self.global_step = 0

        self.optimizer_state_dict = None
        self.scheduler_state_dict = None

        self.total_samples = total_samples

        self.label_names = label_names

        self.results = results
        self.run_name = run_name

        util.set_seed(config)
def main(config, results):
    pd.set_option('display.max_rows', None)

    model_config = tu.load_model_config(config)
    tokenizer = tu.load_tokenizer(config, model_config)

    ds = SentiDataset(config, tokenizer)
    test_dataloader = ds.get_test_dataloader()

    model = tu.load_model(config, model_config)
    model.to(config.device)
    util.set_seed(config)

    train_dataloader, valid_dataloader = ds.get_train_valid_dataloaders()

    test_dataloaders = {'test': ds.get_test_dataloader()}

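    # optional extra test sets, gated by config flags; a dict entry may be a
    # (dataloader, kwargs) tuple, e.g. remapping neutral predictions to
    # negative for the JIRA set, which experiment.train unpacks at eval time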
    if config.jira:
        test_dataloaders['JIRA'] = (
            ds.get_jira_dataloader(),
            dict(pred_label_ids_func=ds.neutral_to_negative))

    if config.app_reviews:
        test_dataloaders['AppReviews'] = ds.get_app_reviews_dataloader()

    if config.sentidata_so:
        test_dataloaders[
            'StackOverflow (SentiData)'] = ds.get_stack_overflow_dataloader()

    experiment = Experiment(config,
                            model,
                            tokenizer,
                            label_names=ds.label_names,
                            results=results)
    global_step, tr_loss = experiment.train(train_dataloader,
                                            valid_dataloader=valid_dataloader,
                                            test_dataloader=test_dataloaders)

    # interp_df = experiment.interpret(test_dataloader, ds.test_df, label_names=ds.label_names)
    # with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    #     print(interp_df)

    # if config.interp_out_file:
    #     interp_df.to_csv(config.interp_out_file, index=False)

    return experiment.results
def main(config, results):
    model_config = tu.load_model_config(config)
    tokenizer = tu.load_tokenizer(config, model_config)

    ds = Dataset(config, tokenizer)
    label_names = ds.label_names

    train_dataloader = ds.get_train_dataloader()
    fake_valid_dataloader = ds.get_fake_valid_dataloader()

    # with config:
    #     config.max_steps=100

    model = tu.load_model(config, model_config)
    model.to(config.device)
    util.set_seed(config)

    experiment = Experiment(config, model, tokenizer, label_names=label_names, results=results)
    # test evaluation happens after training, below
    global_step, tr_loss = experiment.train(
        train_dataloader, valid_dataloader=fake_valid_dataloader)

    valid_dataloader = ds.get_valid_dataloader()
    test_dataloader = ds.get_test_dataloader()
    experiment.evaluate('test_final', test_dataloader)
    experiment.evaluate('valid_final', valid_dataloader)
    experiment.save_model('test_model_complexity')

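    # reload the checkpoint we just saved and evaluate again, as a sanity
    # check that save/load round-trips the model weights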
    with config:
        config.model_path = 'test_model_complexity' 
    model = tu.load_model(config, model_config)
    model.to(config.device)
    logger.warning('==== evaluating reloaded model ====')
    experiment = Experiment(config, model, tokenizer, label_names=label_names, results=results)
    experiment.evaluate('test_final_reloaded', test_dataloader)
    experiment.evaluate('valid_final_reloaded', valid_dataloader)

    results = experiment.results

    return results
def main(config, results):
    model_config = tu.load_model_config(config)
    tokenizer = tu.load_tokenizer(config, model_config)

    if config.clap:
        ds = ClapDataset(config, tokenizer)
    else:
        ds = Dataset(config, tokenizer)

    label_names = ds.label_names
    model_config = tu.load_model_config(config)

    interp_out_file = Path(config.interp_out_file) if config.interp_out_file else None

    for i, (train_dataloader, (valid_dataloader, valid_df)) in enumerate(ds.get_train_valid_dataloaders(include_valid_df=True)):
        print(f"------------------ BEGIN ITER {i} -----------------------")
        model = tu.load_model(config, model_config)
        model.to(config.device)
        util.set_seed(config)

        run_name = f"fold{i}"

        experiment = Experiment(config, model, tokenizer, label_names=label_names, run_name=run_name, results=results)
        global_step, tr_loss = experiment.train(
            train_dataloader, valid_dataloader=valid_dataloader)


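        # per-fold interpretability: dump token attributions for the
        # validation split to a per-iteration CSV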
        if interp_out_file:
            interp_df = experiment.interpret(valid_dataloader, valid_df, label_names=label_names)
            with pd.option_context('display.max_rows', None, 'display.max_columns', None):
                print(interp_df)
            interp_df.to_csv(interp_out_file.with_name(
                f"{interp_out_file.stem}_iter{i}{interp_out_file.suffix}"),
                             index=False)

        results = experiment.results
        # experiment.evaluate('valid', valid_dataloader)
        print(f"================== DONE ITER {i} =======================\n\n")

    return results
def main(config, results):
    model_config = tu.load_model_config(config)
    tokenizer = tu.load_tokenizer(config, model_config)

    ds = Dataset(config, tokenizer)
    label_names = ds.label_names

    train_dataloader, valid_dataloader = ds.get_train_valid_dataloaders()
    test_dataloader = ds.get_test_dataloader()

    # with config:
    #     config.max_steps=100

    model = tu.load_model(config, model_config)
    model.to(config.device)
    util.set_seed(config)

    experiment = Experiment(config, model, tokenizer, label_names=label_names, results=results)
    global_step, tr_loss = experiment.train(train_dataloader, valid_dataloader=valid_dataloader, test_dataloader=test_dataloader)

    results = experiment.results

    return results
Example No. 11
def main():
    config = get_config(parse_args)
    util.set_seed(config)

    with config:
        config.logging_steps = 50
        config.train_epochs = 5

    # config.train_head_only = True

    print("model is now", config.model_path)
    ds = CadoDataset(config)
    label_names = ds.label_names

    model_config = tu.load_model_config(config)
    tokenizer = tu.load_tokenizer(config, model_config)

    f1s = []
    results = None

    test_dataloader = ds.get_test_dataloader(tokenizer)

    for i, (train_dataloader, (valid_dataloader, valid_df)) in enumerate(ds.get_all_train_valid_dataloaders(tokenizer, include_valid_df=True)):
        print(f"------------------ BEGIN ITER {i} -----------------------")
        model = tu.load_model(config, model_config)
        model.to(config.device)
        util.set_seed(config)

        run_name = f"{config.single_class if config.single_class else 'multi'}_{i}"

        experiment = Experiment(config, model, tokenizer, label_names=label_names, run_name=run_name, results=results)
        global_step, tr_loss = experiment.train(
            train_dataloader, valid_dataloader=valid_dataloader, test_dataloader=test_dataloader)

        results = experiment.results
        # experiment.evaluate('valid', valid_dataloader)
        print(f"================== DONE ITER {i} =======================\n\n")
def active_learn(config,
                 model_config,
                 tokenizer,
                 results,
                 label_names,
                 test_df,
                 full_pool_df,
                 backtrans_pool_dfs,
                 get_dataloader_func,
                 run_configs,
                 active_learning_iters=10,
                 dropout_iters=20,
                 balance=False):
    test_dataloader = get_dataloader_func(test_df, bs=config.eval_bs)

    for run_config in run_configs:
        method, dropout, backtrans_langs, cluster_size = run_config
        run_name = method.__name__
        if dropout:
            run_name += '_dropout'
        run_name = '_'.join([run_name, *backtrans_langs, f"c{cluster_size}"])

        util.set_seed(config)

        model = tu.load_model(config, model_config)
        model.to(config.device)

        # remove initial seed from pool
        train_df, pool_df = train_test_split(
            full_pool_df,
            train_size=config.active_learn_seed_size,
            random_state=config.seed)

        logger.info("RUN CONFIG: %s (pool size: %d)", run_name,
                    pool_df.shape[0])

        experiment = Experiment(config,
                                model,
                                tokenizer,
                                label_names=label_names,
                                run_name=run_name,
                                results=results)

        cur_iter = 0

        extra_log = {'iter': cur_iter, 'pool': pool_df.shape[0]}
        experiment.evaluate('test', test_dataloader, extra_log=extra_log)

        while pool_df.shape[0] > 0:
            train_dataloader = get_dataloader_func(train_df,
                                                   bs=config.train_bs,
                                                   balance=balance)

            # don't shuffle the pool: acquisition scores are matched back to
            # pool_df rows by position (see pool_df.iloc[topk_indices] below)
            dataloader_pool = get_dataloader_func(pool_df,
                                                  bs=config.eval_bs,
                                                  shuffle=False)

            logger.info(
                "=================== Remaining %d (%s) ================",
                pool_df.shape[0], run_config)
            logger.info(
                "Evaluating: training set size: %d | pool set size: %d",
                train_df.shape[0], pool_df.shape[0])

            global_step, tr_loss = experiment.train(train_dataloader)

            extra_log = {'iter': cur_iter, 'pool': pool_df.shape[0]}

            _, _, preds = experiment.evaluate('pool',
                                              dataloader_pool,
                                              extra_log=extra_log)
            experiment.evaluate('test', test_dataloader, extra_log=extra_log)

            # convert logits to a tensor once here; in the original, the
            # MC-dropout path left preds as a numpy array, so the k-means
            # step's topk_preds.numpy() below would fail
            preds = torch.from_numpy(preds)

            if method != af.random_conf:
                if dropout:
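                    # Monte-Carlo dropout: repeat stochastic forward passes
                    # with dropout active and average the softmax outputs for
                    # a smoother uncertainty estimate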
                    for i in range(dropout_iters):
                        torch.manual_seed(i)

                        _, _, preds_i = experiment.evaluate('pool_dropout',
                                                            dataloader_pool,
                                                            mc_dropout=True,
                                                            skip_cb=True)
                        preds_i = torch.from_numpy(preds_i)
                        probs_i = F.softmax(preds_i, dim=1)

                        if i == 0:
                            probs = probs_i
                        else:
                            probs.add_(probs_i)
                    probs.div_(dropout_iters)
                else:
                    probs = F.softmax(preds, dim=1)
            else:
                # random acquisition only needs the shape of preds
                probs = preds

            scores = method(probs)
            _, topk_indices = torch.topk(
                scores,
                min(cluster_size * config.active_learn_step_size,
                    scores.shape[0]))

            if cluster_size > 1:
                topk_preds = preds[topk_indices]
                n_clusters = min(config.active_learn_step_size,
                                 scores.shape[0])
                kmeans = KMeans(n_clusters=n_clusters).fit(topk_preds.numpy())
                _, unique_indices = np.unique(kmeans.labels_,
                                              return_index=True)
                topk_indices = topk_indices[unique_indices]
                # assert(topk_indices.shape[0] == n_clusters)
                logger.debug("top_k: %s", topk_indices.shape)

            logger.debug("%s %s", scores.shape, pool_df.shape)

            assert (scores.shape[0] == pool_df.shape[0])

            uncertain_rows = pool_df.iloc[topk_indices]
            # DataFrame.append was removed in pandas 2.0; use pd.concat
            train_df = pd.concat([train_df, uncertain_rows],
                                 ignore_index=True)

            for backtrans_lang in backtrans_langs:
                backtrans_pool_df = backtrans_pool_dfs[backtrans_lang]
                backtrans_uncertain_rows = backtrans_pool_df[
                    backtrans_pool_df.id.isin(uncertain_rows.id)]
                train_df = pd.concat([train_df, backtrans_uncertain_rows],
                                     ignore_index=True)

            pool_df = pool_df.drop(pool_df.index[topk_indices])
            cur_iter += 1

        logger.debug(
            "Pool exhausted, stopping active learning loop (%d remaining)",
            pool_df.shape[0])

        results = experiment.results
    return results
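
# A minimal sketch of a least-confidence acquisition function matching the
# `method(probs)` signature used above and the `least_conf` call in the
# prediction script below. The actual implementations live in the `af` module
# (not shown here), so treat this as an illustrative assumption, not the
# repo's code. `torch` is assumed imported, as elsewhere in this file.
def least_conf(probs):
    # probs: (n_samples, n_classes) softmax outputs. Higher score = lower
    # confidence = more informative sample to label next.
    return 1.0 - probs.max(dim=1).values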
def main():
    config = get_config()
    with config:
        config.logging_steps = 400
        config.train_epochs = 2
        config.lr = 4e-5
        # config.lr = 1e-4
        config.model_type = 'roberta'
        config.model_path = util.models_path('satd_complete_binary')
        # config.train_head_only = True

    tokenizer = tu.load_tokenizer(config)
    model_cls = tu.get_model_cls(config)

    df = pd.read_csv(util.data_path('satd', 'unclassified.csv'))
    # df = pd.read_csv(util.data_path('satd', 'dataset.csv'))
    df.dropna(inplace=True)
    # df.rename(columns={'classification': 'orig_classification'}, inplace=True)

    print(df.dtypes)

    print(df.head())

    df['preprocessed'] = df.commenttext.map(TDDataset.preprocess)
    df.dropna(inplace=True)
    # df = df.head(100)
    preprocessed = df.preprocessed.values
    dummy_labels = np.zeros(preprocessed.shape[0])
    dataloader = tu.get_dataloader(config,
                                   tokenizer,
                                   preprocessed,
                                   dummy_labels,
                                   bs=128,
                                   shuffle=False)

    model = tu.load_model(config)
    model.to(config.device)
    util.set_seed(config)

    experiment = Experiment(config, model, tokenizer)

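    # predict returns raw logits as a numpy array (inferred from the
    # torch.from_numpy / softmax conversion below)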
    preds = experiment.predict(dataloader)
    preds = torch.from_numpy(preds)
    probs = F.softmax(preds, dim=1)
    uncertainty = least_conf(probs).numpy()
    labels = preds.argmax(dim=1).numpy()

    df['uncertainty'] = uncertainty
    df['probs0'] = probs[:, 0].numpy()
    df['probs1'] = probs[:, 1].numpy()
    df['classification'] = labels
    df.drop('preprocessed', axis='columns', inplace=True)

    label_name_map = dict(enumerate(TDDataset.BINARY_LABEL_NAMES))
    print(label_name_map)

    # convert_label = {'DEFECT': 1, 'DESIGN': 1,
    #                  'IMPLEMENTATION': 1, 'TEST': 1,
    #                  'WITHOUT_CLASSIFICATION': 0, 'DOCUMENTATION': 1}
    # df['correct'] = (df.orig_classification.map(convert_label) == df.classification)
    # print(df.correct.value_counts(normalize=True))

    df.classification = df.classification.map(label_name_map)
    df.to_csv(util.data_path('satd', 'unclassified_evaled.csv'), index=False)

    tech_debt_df = df[df.classification == 'TECHNICAL_DEBT']
    print(tech_debt_df.shape)
    tech_debt_df.to_csv(util.data_path('satd', 'unclassified_pos.csv'),
                        index=False)
Example No. 14
def main(config, results):

    logger.warning('Unclassified threshold: %s', config.self_train_thresh)

    ds = TDDataset(config,
                   binary=True,
                   self_train_thresh=config.self_train_thresh,
                   keyword_masking_frac=config.keyword_masking_frac)

    model_config = tu.load_model_config(config)
    tokenizer = tu.load_tokenizer(config, model_config)
    label_names = ds.label_names

    #project_name = 'emf-2.4.1'
    project_name = config.single_project

    if project_name:
        iter_obj = [(project_name, *ds.get_train_valid_dataloaders(
            tokenizer, project_name, include_valid_df=True))]
    else:
        iter_obj = ds.get_fold_dataloaders(tokenizer, include_valid_df=True)

    interp_out_file = Path(
        config.interp_out_file) if config.interp_out_file else None

    # for train_dataloader, valid_dataloader in [ds.get_train_valid_dataloaders(tokenizer, project_name)]:
    for project_name, train_dataloader, (valid_dataloader,
                                         valid_df) in iter_obj:
        print(
            f"------------------ BEGIN PROJECT {project_name} -----------------------"
        )

        model = tu.load_model(config, model_config)
        model.to(config.device)
        util.set_seed(config)

        experiment = Experiment(config,
                                model,
                                tokenizer,
                                label_names=label_names,
                                run_name=project_name,
                                results=results)
        global_step, tr_loss = experiment.train(
            train_dataloader, valid_dataloader=valid_dataloader)

        if interp_out_file:
            interp_df = experiment.interpret(valid_dataloader, valid_df)
            with pd.option_context('display.max_rows', None,
                                   'display.max_columns', None):
                print(interp_df)
            interp_df.to_csv(interp_out_file.with_name(
                f"{project_name}_{interp_out_file.name}"),
                             index=False)

        results = experiment.results

        # experiment.evaluate('valid', valid_dataloader)

        print(
            f"================== DONE PROJECT {project_name} =======================\n\n"
        )
    return results
Example No. 15
    def train(self,
              train_dataloader,
              valid_dataloader=None,
              test_dataloader=None,
              should_continue=False):
        """ Train the model """
        tb_writer = SummaryWriter()

        train_epochs = self.config.train_epochs

        if self.config.max_steps > 0:
            train_steps = self.config.max_steps
            train_epochs = self.config.max_steps // (
                len(train_dataloader) // self.config.grad_acc_steps) + 1
        else:
            train_steps = len(
                train_dataloader) // self.config.grad_acc_steps * train_epochs

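        # when training continues across multiple train() calls, size the LR
        # schedule by the overall sample budget rather than this call's
        # dataloader length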
        if self.total_samples and should_continue:
            steps_total = (self.total_samples // self.config.train_bs //
                           self.config.grad_acc_steps * train_epochs)
        else:
            steps_total = train_steps

        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                self.config.weight_decay,
            },
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0
            },
        ]

        self.optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=self.config.lr,
            eps=self.config.adam_eps,
        )

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.config.warmup_steps,
            num_training_steps=steps_total)

        # self.scheduler = get_constant_schedule(self.optimizer)

        if should_continue and self.global_step > 0:
            logger.info("loading saved optimizer and scheduler states")
            assert (self.optimizer_state_dict)
            assert (self.scheduler_state_dict)
            self.optimizer.load_state_dict(self.optimizer_state_dict)
            self.scheduler.load_state_dict(self.scheduler_state_dict)
        else:
            logger.info("Using fresh optimizer and scheduler")

        if self.config.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            self.model, self.optimizer = amp.initialize(
                self.model,
                self.optimizer,
                opt_level=self.config.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.config.n_gpu > 1 and not isinstance(self.model,
                                                    torch.nn.DataParallel):
            self.model = torch.nn.DataParallel(self.model)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d (%d)", len(train_dataloader.dataset),
                    len(train_dataloader))
        logger.info("  Num Epochs = %d", train_epochs)
        logger.info("  Batch size = %d", self.config.train_bs)
        logger.info("  Learning rate = %e", self.config.lr)
        logger.info("  Loss label weights = %s",
                    self.config.loss_label_weights)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            self.config.train_bs * self.config.grad_acc_steps)
        logger.info("  Gradient Accumulation steps = %d",
                    self.config.grad_acc_steps)
        logger.info("  Total optimization steps = %d", train_steps)

        if not should_continue:
            self.global_step = 0

        epochs_trained = 0
        steps_trained_in_current_epoch = 0

        # # Check if continuing training from a checkpoint
        # if os.path.exists(self.config.model_path):
        #     if self.config.should_continue:
        #         step_str = self.config.model_path.split("-")[-1].split("/")[0]

        #         if step_str:
        #             # set self.global_step to gobal_step of last saved checkpoint from model path
        #             self.global_step = int(step_str)
        #             epochs_trained = self.global_step // (len(train_dataloader) //
        #                                                   self.config.grad_acc_steps)
        #             steps_trained_in_current_epoch = self.global_step % (
        #                 len(train_dataloader) // self.config.grad_acc_steps)

        #             logger.info(
        #                 "  Continuing training from checkpoint, will skip to saved self.global_step")
        #             logger.info(
        #                 "  Continuing training from epoch %d", epochs_trained)
        #             logger.info(
        #                 "  Continuing training from global step %d", self.global_step)
        #             logger.info("  Will skip the first %d steps in the first epoch",
        #                         steps_trained_in_current_epoch)

        train_loss = 0.0
        self.model.zero_grad()
        train_iterator = trange(
            epochs_trained,
            int(train_epochs),
            desc="Epoch",
        )
        util.set_seed(self.config)  # Added here for reproducibility

        self.model.train()

        if self.config.train_head_only:
            # note: despite the flag name, only the embedding layer is frozen
            # here; the encoder layers and classification head still train
            for param in self.model.roberta.embeddings.parameters():
                param.requires_grad = False
            logger.info("train_head_only: freezing embedding parameters")

        for _ in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for step, batch in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                self.model.train()

                inputs = self.__inputs_from_batch(batch)
                outputs = self.model(**inputs)

                # model outputs are always tuple in transformers (see doc)
                loss = outputs[0]

                if self.config.n_gpu > 1:
                    # average per-GPU losses under DataParallel
                    loss = loss.mean()
                if self.config.grad_acc_steps > 1:
                    loss = loss / self.config.grad_acc_steps

                if self.config.fp16:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                batch_loss = loss.item()
                train_loss += batch_loss

                if (step + 1) % self.config.grad_acc_steps == 0:
                    if self.config.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self.optimizer),
                            self.config.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.config.max_grad_norm)

                    self.optimizer.step()
                    self.scheduler.step()  # Update learning rate schedule
                    self.model.zero_grad()
                    self.global_step += 1

                    if self.config.logging_steps > 0 and self.global_step % self.config.logging_steps == 0:
                        logs = {}
                        if valid_dataloader:
                            result_valid, *_ = self.evaluate(
                                'valid',
                                valid_dataloader,
                                backtrans=(test_dataloader is None))
                            logs.update({
                                f"valid_{k}": v
                                for k, v in result_valid.items()
                            })

                        if test_dataloader:
                            test_dataloader = test_dataloader if isinstance(
                                test_dataloader, dict) else {
                                    'test': test_dataloader
                                }
                            for eval_name, dataloader_or_tuple in test_dataloader.items():
                                if isinstance(dataloader_or_tuple, tuple):
                                    dataloader, kwargs = dataloader_or_tuple
                                else:
                                    dataloader = dataloader_or_tuple
                                    kwargs = {}

                                result_test, *_ = self.evaluate(
                                    eval_name, dataloader, **kwargs)
                                logs.update({
                                    f"{eval_name}_{k}": v
                                    for k, v in result_test.items()
                                })

                        learning_rate_scalar = self.scheduler.get_last_lr()[0]
                        logger.info("Learning rate: %f (at step %d)",
                                    learning_rate_scalar, step)
                        logs["learning_rate"] = learning_rate_scalar
                        logs["train_loss"] = train_loss

                        self.after_logging(logs)

                        logger.info("Batch loss: %f", batch_loss)

                        # for key, value in logs.items():
                        #     tb_writer.add_scalar(key, value, self.global_step)

                    if self.config.save_steps > 0 and self.global_step % self.config.save_steps == 0:
                        # Save model checkpoint
                        self.save_checkpoint()

                if self.config.max_steps > 0 and self.global_step > self.config.max_steps:
                    epoch_iterator.close()
                    break
            if self.config.max_steps > 0 and self.global_step > self.config.max_steps:
                train_iterator.close()
                break

        if self.config.train_head_only:
            # undo the freeze applied before training so the returned model
            # is fully trainable again (matching the commented-out intent in
            # the original)
            logger.info("Unfreezing embedding parameters")
            for param in self.model.roberta.embeddings.parameters():
                param.requires_grad = True

        tb_writer.close()
        self.optimizer_state_dict = self.optimizer.state_dict()
        self.scheduler_state_dict = self.scheduler.state_dict()

        avg_train_loss = train_loss / self.global_step

        logger.info("Learning rate now: %s", self.scheduler.get_last_lr())
        logger.info("***** Done training *****")
        return self.global_step, avg_train_loss