Example #1
    writer = SummaryWriter(checkpoint_dir)

    data_loader_train, data_loader_val = load_data(parameters["batch_size"],
                                                   parameters["npoints"],
                                                   parameters["npatchs"])

    optimizer = optim.Adam(model.parameters(),
                           lr=parameters["lr"],
                           amsgrad=True)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, parameters["lr_scheduler"]["step"],
        parameters["lr_scheduler"]["gamma"])

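    # add_hparams cannot log the nested lr_scheduler dict, so it is dropped
    # here and the remaining values are stringified before logging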
    del parameters["lr_scheduler"]
    writer.add_hparams({k: str(v) for k, v in parameters.items()}, {})

    criterion = torch.nn.NLLLoss(reduction='mean').to(device)
    best_loss = 99

    for epoch in range(1, parameters["nb_epochs"]):
        epoch_loss, epoch_acc = train(model, criterion, optimizer,
                                      data_loader_train, writer, epoch)
        scheduler.step()

        epoch_loss_train, epoch_acc_train, list_feat = eval(
            model, criterion, data_loader_train)
        epoch_loss_valid, epoch_acc_valid, list_feat = eval(
            model, criterion, data_loader_val)

        writer.add_scalar('Loss/training', epoch_loss, epoch)
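A minimal, self-contained sketch of the add_hparams pattern used above (illustrative values; assumes torch with TensorBoard support is installed):

from torch.utils.tensorboard import SummaryWriter

params = {"lr": 1e-3, "batch_size": 32, "npoints": 1024}   # illustrative values
writer = SummaryWriter("runs/hparams_demo")                # hypothetical log dir
# stringifying keeps every value in a type that add_hparams accepts
writer.add_hparams({k: str(v) for k, v in params.items()}, {})
writer.close()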
Example #2
class TensorBoardCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that sends the logs to `TensorBoard
    <https://www.tensorflow.org/tensorboard>`__.

    Args:
        tb_writer (:obj:`SummaryWriter`, `optional`):
            The writer to use. Will instantiate one if not set.
    """

    def __init__(self, tb_writer=None):
        assert (
            _has_tensorboard
        ), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
        self.tb_writer = tb_writer

    def _init_summary_writer(self, args, log_dir=None):
        log_dir = log_dir or args.logging_dir
        self.tb_writer = SummaryWriter(log_dir=log_dir)

    def on_train_begin(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return

        log_dir = None

        if state.is_hyper_param_search:
            trial_name = state.trial_name
            if trial_name is not None:
                log_dir = os.path.join(args.logging_dir, trial_name)

        self._init_summary_writer(args, log_dir)

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", args.to_json_string())
            if "model" in kwargs:
                model = kwargs["model"]
                if hasattr(model, "config") and model.config is not None:
                    model_config_json = model.config.to_json_string()
                    self.tb_writer.add_text("model_config", model_config_json)
            # Version of TensorBoard coming from tensorboardX does not have this method.
            if hasattr(self.tb_writer, "add_hparams"):
                self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})

    def on_log(self, args, state, control, logs=None, **kwargs):
        if state.is_world_process_zero:
            if self.tb_writer is None:
                self._init_summary_writer(args)

        if self.tb_writer:
            logs = rewrite_logs(logs)
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, state.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        '"%s" of type %s for key "%s" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )
            self.tb_writer.flush()

    def on_train_end(self, args, state, control, **kwargs):
        if self.tb_writer:
            self.tb_writer.close()
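A short construction sketch for the callback above (hedged; "runs/demo" is a hypothetical log dir, and the Trainer that would receive the callback is omitted):

from torch.utils.tensorboard import SummaryWriter

tb_writer = SummaryWriter(log_dir="runs/demo")        # hypothetical log dir
callback = TensorBoardCallback(tb_writer=tb_writer)   # explicit writer passed in
lazy_callback = TensorBoardCallback()                 # writer built from args.logging_dir later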
Example #3
def training(hparams: dict):
    """trains a model 

    Args:
        hparams (dict): a dictionary containing the hyperparameters
    """

    batch_size = hparams["batch_size"]
    num_epochs = hparams["num_epochs"]

    device = torch.device("cuda")
    # model = Bl_model(10, steps=hparams["steps"]).to(device)

    model = Bl_resnet(
        10,
        steps=hparams["steps"],
        threshold=torch.ones(8, device=device) * int(hparams["threshold"]),
        recurrence=hparams["recurrence"],
        residual=hparams["residual"],
    ).to(device)

    dataloaders = CIFAR10(batch_size, s=hparams["occlusion_size"])

    t0 = time.time()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=hparams["lr_start"])

    lr_scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        hparams["lr_start"],
        epochs=num_epochs,
        steps_per_epoch=(dataloaders.sizes["train"] // batch_size),
    )

    starttime = f"started_{datetime.now().strftime('%d-%m-%Y %H:%M:%S')}"
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"{num_params/1000000:.4}M")
    output_subdir = (
        Path("cifar10")
        / f"{model.mname}_{num_params/1000000:.2}M"
        / f"{num_epochs}_ep"
        / f"BS_{batch_size}_occlusion_{hparams['occlusion_size']}"
        / f"lr_{hparams['lr_start']}_threshold-{hparams['threshold']}"
        / starttime
    )

    writer_path = Path("Runs") / output_subdir
    writer = SummaryWriter(str(writer_path.resolve()))
    writer.add_graph(model, torch.zeros(1, 3, 32, 32).cuda(), verbose=False)
    writer.flush()
    # recurrence is a list, which add_hparams cannot log directly, so encode
    # it as a compact string for TensorBoard:
    hparams["recurrence"] = "".join([str(int(i)) for i in hparams["recurrence"]])
    tb_profile_trace_handler = torch.profiler.tensorboard_trace_handler(
        str(writer_path.resolve())
    )

    for epoch in range(num_epochs):
        current_lr = optimizer.param_groups[0]["lr"]
        print("-" * 10)
        print(f"Epoch {epoch+1}/{num_epochs}, LR: {current_lr}")
        print("-" * 10)

        for phase in ["train", "test"]:
            if phase == "train":
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            tot = dataloaders.sizes[phase] // batch_size
            with torch.profiler.profile(
                schedule=torch.profiler.schedule(wait=2, warmup=2, active=3, repeat=1),
                on_trace_ready=tb_profile_trace_handler,
            ) as profiler:
                for batch, (inputs, labels) in tqdm(
                    enumerate(dataloaders[phase]), total=tot
                ):
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    optimizer.zero_grad(set_to_none=True)

                    with torch.set_grad_enabled(phase == "train"):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs[-1], 1)
                        loss = sum([criterion(o, labels) for o in outputs])
                        if phase == "train":
                            loss.backward()
                            optimizer.step()
                            lr_scheduler.step()
                        if epoch == 0:
                            profiler.step()

                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / dataloaders.sizes[phase]
            epoch_acc = running_corrects.double() / dataloaders.sizes[phase]

            # log epoch metrics to TensorBoard
            writer.add_scalar("Loss/" + phase, epoch_loss, epoch)
            writer.add_scalar("Accuracy/" + phase, epoch_acc, epoch)

            if phase == "train":
                writer.add_scalar("LR/lr", current_lr, epoch)

                for name, param in model.named_parameters():
                    if "weight" in name and param.requires_grad:
                        if "bn" not in name:
                            if "lateral" not in name:
                                writer.add_histogram(name, param, epoch)
                                writer.add_histogram(name + ".grad", param.grad, epoch)

            print(f"{phase} Loss: {epoch_loss:.4} Acc: {epoch_acc:.5}")

        writer.flush()

    t1 = time.time()
    print(f"Total training time {t1-t0:.4} seconds")

    print("saving model")

    ckpt_path = Path("ckpt") / output_subdir

    os.makedirs(ckpt_path)
    torch.save(model.state_dict(), ckpt_path / "model.pt")

    print("... validating ...")
    for occ_mode in ["cutout", "noise"]:
        model.eval()
        tot = dataloaders.sizes["val"] // batch_size * 2
        # CIFAR-10 images have 32*32 = 1024 pixels, so each step of 10
        # occluded pixels is roughly 1 %
        for p, occ in enumerate([10 * i for i in range(0, 20)]):
            # reset running metrics for this occlusion level
            run_val_loss = 0.0
            run_val_acc = 0.0
            dataloaders.update_val_loader(occ, mode=occ_mode)
            for batch, (inputs, labels) in tqdm(
                enumerate(dataloaders.val_loader), total=tot
            ):
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                _, preds = torch.max(outputs[-1], 1)
                loss = sum([criterion(o, labels) for o in outputs])

                run_val_loss += loss.item() * inputs.size(0)
                run_val_acc += torch.sum(preds == labels.data)

            val_loss = run_val_loss / dataloaders.sizes["val"]
            val_acc = run_val_acc.double() / dataloaders.sizes["val"]
            print(f"Validation Loss: {val_loss:.4} Acc: {val_acc:.5}")

            metric_dict = {
                "hparam/loss": val_loss,
                "hparam/accuracy": val_acc,
            }

            writer.add_scalar(f"Accuracy/val_{occ_mode}", val_acc, p)

            writer.add_hparams(hparams, metric_dict, run_name="ht")
            writer.flush()
            if val_acc < 0.11:
                print(f"finished {occ_mode}----------------------")
                break

    writer.close()
    return 0
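The training function above reads a fixed set of keys from hparams; a hypothetical call could look like the following (keys taken from the function body, values purely illustrative):

hparams = {
    "batch_size": 128,
    "num_epochs": 30,
    "steps": 4,
    "threshold": 1,
    "recurrence": [1, 1, 0, 0],
    "residual": True,
    "occlusion_size": 0,
    "lr_start": 1e-3,
}
training(hparams)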
Example #4
def trainer(model: BaseModel,
            total_reads: int,
            length: int,
            train_ratio: float,
            id_list: Optional[List[str]],
            distribution: List[float],
            class_list: Optional[List] = None,
            metadata: Optional[str] = None,
            batch_size: Optional[int] = 1,
            data_directory: Optional[str] = None,
            random_seed: Optional[int] = None,
            external_validation_ids: Optional[List[str]] = None,
            n_external_validation_reads: Optional[int] = None,
            external_validation_distribution: Optional[List[float]] = None,
            external_validation_classes: Optional[List] = None,
            start_time: Optional[float] = None,
            append_time: Optional[bool] = True,
            additional_hparams: Optional[dict] = None,
            model_kwargs: Optional[dict] = None,
            train_kwargs: Optional[dict] = None,
            summary_kwargs: Optional[dict] = None) -> BaseModel:
    """

    Parameters
    ----------
    model
        A model class to use for training
    total_reads
        The total number of reads to simulate for training
    length
        The length of reads to simulate for training
    train_ratio
        The portion of simulated reads that should be used for training
    id_list
        The list of genome ids to use for training
    distribution
        The relative amount of each id in id_list to use
    class_list
        The class of each id in id_list
    metadata
        Path to metadata containing a '# assembly_accession' and 'ftp_path'
        column (NCBI assembly summary format)
    batch_size
        Size of batches for neural network training
    data_directory
        Directory that genome data is saved in, or should be saved in for
        downloaded data
    random_seed
        Seed for the random number generator
    external_validation_ids
        Genome ids to use for validation; should be disjoint from the
        training id_list
    n_external_validation_reads
        How many reads to sample from the external validation ids
    external_validation_distribution
        How to distribute the reads among the external validation ids
    external_validation_classes
        The class of each id in external_validation_ids
    start_time
        The time at which training started, as returned by time.time().
        Can be passed in for convenience, and will be initialized if not
        provided
    append_time
        Whether the start time should be appended to the log_dir;
        cannot be False if log_dir is None
    additional_hparams
        Additional hyperparameters to be saved with the model
    model_kwargs
        kwargs to be passed to the `model`
    train_kwargs
        kwargs to be passed to the training function of the `model`
    summary_kwargs
        kwargs to be passed to the summary function of the `model`

    Returns
    -------
    trained_model
        The model that has been trained by your trainer

    """

    if additional_hparams is None:
        additional_hparams = dict()
    if model_kwargs is None:
        model_kwargs = dict()
    if train_kwargs is None:
        train_kwargs = dict()
    if summary_kwargs is None:
        summary_kwargs = dict()
    if start_time is None:
        start_time = time.time()

    log_dir = train_kwargs.pop('log_dir', None)
    if not append_time and not log_dir:
        # this error is to avoid multiple models accidentally getting
        #  written to the same directory. Take care to use different
        #  log_dirs if using append_time=False
        raise ValueError("'append_time' cannot be False when 'log_dir' "
                         "is None")
    log_dir = model.get_log_dir(log_dir, append_time=append_time)

    # If id_list is not None, use the specified id's
    if id_list is not None:
        file_list = data_downloader(id_list,
                                    output_directory=data_directory,
                                    metadata=metadata)
    # if id_list _is_ None, just use whatever is in the directory
    elif os.path.exists(data_directory) and \
            len(os.listdir(data_directory)) > 0:
        file_list = [
            os.path.join(data_directory, file_)
            for file_ in os.listdir(data_directory)
        ]
    else:
        raise FileExistsError('Data directory must exist and contain '
                              'data if `id_list` is not supplied.')

    reads, ids = simulate_from_genomes(distribution, total_reads, length,
                                       file_list, data_directory, random_seed)

    id_depths = [round(val * total_reads) for val in distribution]
    if class_list is None:
        class_list = id_list
    list_of_classes = [[class_] * depth
                       for class_, depth in zip(class_list, id_depths)]
    class_list = [item for sublist in list_of_classes for item in sublist]

    external_validation = False
    external_classes = None
    if external_validation_ids is not None and \
            n_external_validation_reads is not None and \
            external_validation_distribution is not None and \
            external_validation_classes is not None:
        data_downloader(external_validation_ids,
                        output_directory=data_directory,
                        metadata=metadata)
        external_reads, external_ids = simulate_from_genomes(
            external_validation_distribution, n_external_validation_reads,
            length, external_validation_ids, data_directory, random_seed + 5)

        external_validation = True

        ext_depths = [
            round(val * n_external_validation_reads)
            for val in external_validation_distribution
        ]
        list_of_ext_classes = [
            [class_] * depth
            for class_, depth in zip(external_validation_classes, ext_depths)
        ]

        external_classes = [
            item for sublist in list_of_ext_classes for item in sublist
        ]

    elif external_validation_ids is not None or \
            n_external_validation_reads is not None or \
            external_validation_distribution is not None or \
            external_validation_classes is not None:
        raise ValueError('If any external validation parameters are '
                         'specified, all must be specified.')

    # split into train and validation
    train_indices, val_indices = train_val_split(total_reads,
                                                 train_ratio,
                                                 random_seed=random_seed)

    # create Dataset and DataLoader for train and validation
    train_dataloader = prepare_dataloader(reads,
                                          class_list,
                                          ids,
                                          batch_size=batch_size,
                                          indices=train_indices)
    val_dataloader = prepare_dataloader(reads,
                                        class_list,
                                        ids,
                                        batch_size=batch_size,
                                        indices=val_indices)
    if external_validation:
        external_dataloader = prepare_dataloader(external_reads,
                                                 external_classes,
                                                 external_ids,
                                                 batch_size=batch_size)
    else:
        external_dataloader = None

    training_model = model(n_classes=len(set(class_list)),
                           seed=random_seed,
                           **model_kwargs)

    # conv1d complains if the inputs are doubles while the weights are floats,
    # so the model is cast to double; this could slow down training
    training_model.double()

    writer = SummaryWriter(log_dir=log_dir)

    if train_kwargs['gpu']:
        training_model.cuda()

    # TODO prefix/suffix option for naming
    training_model.fit(train_dataloader,
                       log_dir=log_dir,
                       val_dataset=val_dataloader,
                       external_dataset=external_dataloader,
                       seed=random_seed,
                       summary_kwargs=summary_kwargs,
                       log_dir_append_time=False,
                       writer=writer,
                       start_time=start_time,
                       **train_kwargs)

    end_time = time.time()
    mod = training_model

    hparam_dict = dict()
    hparam_train_kwargs = ['learning_rate', 'epochs', 'gpu']
    hparam_dict.update(
        {kwarg: train_kwargs[kwarg]
         for kwarg in hparam_train_kwargs})
    hparam_dict.update(model_kwargs)
    hparam_dict.update({'model_type': model.__name__, 'random-seed': mod.seed})
    hparam_dict.update(additional_hparams)

    metric_dict = {
        'best-val-accuracy': mod.best_val_accuracy,
        'best-val-loss': mod.best_val_loss,
        'best-val-epoch': mod.best_val_epoch,
        'best-val-train-accuracy': mod.best_val_train_accuracy,
        'best-val-train-loss': mod.best_val_train_loss,
        'best-val-time': mod.best_val_time,
        'train-dataset-length': len(train_dataloader.dataset),
        'val-dataset-length': len(val_dataloader.dataset),
        'total-time': end_time - start_time,
    }
    writer.add_hparams(hparam_dict, metric_dict)

    writer.close()

    return training_model
Example #5
        args.actor_critic_lr) + "_g_" + str(args.gamma) + "_mem_" + str(
            args.max_memory) + "_batch_" + str(args.batch_size)
    dstr = datetime.datetime.now().strftime("_dt-%Y-%m-%d-%H-%M-%S")
    writer = SummaryWriter(log_dir="./model" + hyper_params + dstr)
    arg_dict = vars(args)
    writer.add_text('Model Parameters: ', str(arg_dict), 0)

    # config
    state = agent.get_state()  # env.observation_space.shape[0]
    n_actions = 3  # env.action_space.shape[0]
    actorcritic = PPOActorCritic(state.shape[0], n_actions,
                                 activation=Mish).to(agent.device)
    if args.model_path or args.test:
        actorcritic.load_state_dict(torch.load(args.model_path))
    trainer = PPOTrainer(actorcritic,
                         gamma=args.gamma,
                         batch_size=64,
                         device=agent.device,
                         actor_critic_lr=args.actor_critic_lr)

    if args.test:
        test(game, args)
    else:
        agent.run()

    writer.add_hparams(hparam_dict=vars(args),
                       metric_dict={
                           'mean_reward': agent.mean_reward,
                           'high_score': agent.record,
                           'mean_score': agent.mean_score
                       })
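Note that add_hparams expects scalar-friendly values (int, float, str, bool, or tensors), so passing vars(args) verbatim assumes every argparse field already has one of those types. A defensive variant might stringify anything else (sketch, reusing the names from the snippet above):

safe_hparams = {k: v if isinstance(v, (int, float, str, bool)) else str(v)
                for k, v in vars(args).items()}
writer.add_hparams(hparam_dict=safe_hparams,
                   metric_dict={'mean_reward': agent.mean_reward,
                                'high_score': agent.record,
                                'mean_score': agent.mean_score})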
Example #6
def main(args):
    ##Switches
    isTraining = args.isTraining
    saveModel = args.saveModel
    loadModel = args.loadModel
    debug = args.debug
    saveResults = args.saveResults
    savePeriod = args.savePeriod

    modelSavingPath = args.modelSavingPath
    modelLoadingPath = args.modelLoadingPath

    env = gym.make(args.env)
    runFor = "_Training" if isTraining else "_Evaluation"
    nameToWriter = "__DDPG__" + args.env + runFor
    writer = SummaryWriter(comment=nameToWriter)

    env.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    actionDim = env.action_space.shape[0]
    actionH = env.action_space.high
    actionL = env.action_space.low
    stateDim = env.observation_space.shape[0]
    hiddenDim = args.hiddenDim

    buffer = UniformReplayBuffer(args.maxCapacity, env.observation_space.shape,
                                 env.action_space.shape, np.float32, np.long)
    ddpgAgent = DDPGAgent(buffer, stateDim, actionDim, hiddenDim, actionH,
                          args)

    if not isTraining or loadModel:
        LoadModel(ddpgAgent, modelLoadingPath)

    noise = OUNoise(actionDim, args)
    rewardToStop = args.maxReward

    totalReward = []
    stepCounter = 0
    meanReward = 0
    totalPassedTime = 0
    isSuccesfull = False

    if saveResults:
        writer.add_hparams(
            {
                "lrPolicy": args.lrPolicy,
                "lrCritic": args.lrCritic,
                "batchSize": args.batchSize,
                "tau": args.tau,
                "hiddenDim": args.hiddenDim,
                "noiseTheta": args.theta,
                "noiseSigma": args.fixedSigma,
                "envName": args.env
            }, {"lrPolicy": args.lrPolicy})

    for e in range(args.numberOfEpisode):

        state = env.reset()
        episodeReward = 0
        episodeSteps = 0
        episodeTime = time.time()
        isDone = False

        for s in range(args.maxStepCount):
            action = ddpgAgent.GetAction(state)
            if isTraining:
                action = np.clip(action + noise.noise(), actionL, actionH)
            nextState, reward, done, _ = env.step(action)
            buffer.push_transition(
                Transition(state, action, reward, nextState, done))
            episodeReward += reward

            if stepCounter > 2 * args.batchSize and isTraining:
                ddpgAgent.Update(stepCounter)

            if not isTraining:
                env.render()
                time.sleep(0.01)

            if done:
                isDone = True
                break

            state = nextState
            episodeSteps += 1
            stepCounter += 1

        episodeTime = time.time() - episodeTime
        totalPassedTime += episodeTime
        totalReward.append(episodeReward)
        meanReward = float(np.mean(totalReward[-100:]))

        if saveResults and e % savePeriod == 0:

            ##Store data w.r.t. stepCounter
            writer.add_scalar("STEPS/episodeReward" + runFor, episodeReward,
                              stepCounter)
            writer.add_scalar("STEPS/meanReward" + runFor, meanReward,
                              stepCounter)
            writer.add_scalar("STEPS/episodes" + runFor, e, stepCounter)
            ##Store data w.r.t. episodes
            writer.add_scalar("EPISODES/episodeSteps" + runFor, episodeSteps,
                              e)
            writer.add_scalar("EPISODES/episodeTime" + runFor, episodeTime, e)
            writer.add_scalar("EPISODES/episodeReward" + runFor, episodeReward,
                              e)
            writer.add_scalar("EPISODES/meanReward" + runFor, meanReward, e)

        if debug:
            print(
                "Eps:{} Steps:{} Mean Reward: {} Episode Reward: {} Episode Time: {}  IsDone: {}"
                .format(e, stepCounter, meanReward, episodeReward, episodeTime,
                        isDone))

        if meanReward >= rewardToStop and isTraining:
            isSuccesfull = True
            print("SUCCESS!!!")
            print(
                "Total Eps:{} Total Steps:{} Mean Reward: {} Episode Reward: {} Total Time: {}"
                .format(e, stepCounter, meanReward, episodeReward,
                        totalPassedTime))
            if saveModel:
                SaveModel(ddpgAgent, e, stepCounter, episodeReward,
                          "__DDPG__" + args.env, modelSavingPath)
            break

    ##If the max episode count was reached without the expected mean reward
    if not isSuccesfull:
        print("FAILURE!!!")
        print("Total Eps:{} Total Steps:{} Mean Reward: {} Total Time: {}".
              format(len(totalReward), stepCounter, meanReward,
                     totalPassedTime))
Example #7
            logits[x.nonzero(as_tuple=True)] = .0

            # Fetching all predictions and ground_truth labels
            all_logits.append(logits.detach().cpu().numpy())
            all_y.append(y.detach().cpu().numpy())

        preds = np.concatenate(all_logits)
        true = np.concatenate(all_y)

        full_metrics = dict()
        full_raw_metrics = dict()
        for trait in DEMO_TRAITS:
            user_groups = user_groups_all_traits[trait]

            _, metrics, metrics_raw = eval_proced(preds=preds,
                                                  true=true,
                                                  tag='test',
                                                  user_groups=user_groups,
                                                  tids_path=tids_path,
                                                  entropy_norm=True)
            full_metrics.update(metrics)
            full_raw_metrics.update(metrics_raw)

        # Logging hyperparams and metrics
        summ.add_hparams({**best_config, 'fold_n': fold_n}, full_metrics)
        summ.flush()

        # Saving results and predictions
        pickle_dump(full_metrics, os.path.join(log_te_str, 'full_metrics.pkl'))
        pickle_dump(full_raw_metrics, os.path.join(log_te_str, 'full_raw_metrics.pkl'))
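The per-fold pattern above (merging the tuned config with the fold index before logging) keeps cross-validation runs comparable in TensorBoard's HPARAMS tab. A minimal, self-contained sketch of the same idea with made-up values:

from torch.utils.tensorboard import SummaryWriter

summ = SummaryWriter("runs/cv_demo")           # hypothetical log dir
best_config = {"lr": 1e-3, "wd": 1e-5}         # illustrative tuned config
for fold_n in range(5):
    metrics = {"test/ndcg10": 0.1 * fold_n}    # placeholder metric values
    summ.add_hparams({**best_config, "fold_n": fold_n}, metrics)
summ.flush()
summ.close()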
Example #8
        "layer_size": LAYER_SIZE,
        "nStep": nstep,
        "gamma": GAMMA,
        "tau": TAU,
        "learningRate": LR,
        "epsilon": EPS,
        "updateEvery": UPDATE_EVERY,
        "nUpdate": NUPDATES
    }

    np.random.seed(seed)
    env = ContinuousCartPoleEnv()

    now = datetime.now()
    writer = SummaryWriter('logdir/' + now.strftime("%Y%m%d-%H%M%S") + "/")
    writer.add_hparams(paramDict, {})

    env.seed(seed)
    action_size = env.action_space.shape[0]
    state_size = env.observation_space.shape[0]

    agent = DQN_Agent(state_size=state_size,
                      action_size=action_size,
                      layer_size=LAYER_SIZE,
                      BATCH_SIZE=BATCH_SIZE,
                      BUFFER_SIZE=BUFFER_SIZE,
                      PER=per,
                      LR=LR,
                      EPS=EPS,
                      Nstep=nstep,
                      TAU=TAU,
class SamplingMultitaskTrainer:
    def __init__(self,
                 dataset=None,
                 model_name=None,
                 model_params=None,
                 trainer_params=None,
                 restore=None,
                 device=None,
                 pretrained_embeddings_path=None,
                 tokenizer_path=None):

        self.graph_model = model_name(dataset.g, **model_params).to(device)
        self.model_params = model_params
        self.trainer_params = trainer_params
        self.device = device
        self.epoch = 0
        self.batch = 0
        self.dtype = torch.float32
        self.create_node_embedder(
            dataset,
            tokenizer_path,
            n_dims=model_params["h_dim"],
            pretrained_path=pretrained_embeddings_path,
            n_buckets=trainer_params["embedding_table_size"])

        self.summary_writer = SummaryWriter(self.model_base_path)

        self.ee_node_name = ElementEmbedderWithBpeSubwords(
            elements=dataset.load_node_names(),
            nodes=dataset.nodes,
            emb_size=self.elem_emb_size,
            tokenizer_path=tokenizer_path).to(self.device)

        self.ee_var_use = ElementEmbedderWithBpeSubwords(
            elements=dataset.load_var_use(),
            nodes=dataset.nodes,
            emb_size=self.elem_emb_size,
            tokenizer_path=tokenizer_path).to(self.device)

        self.ee_api_call = ElementEmbedderBase(
            elements=dataset.load_api_call(),
            nodes=dataset.nodes,
            compact_dst=False,
            dst_to_global=True)

        self.lp_node_name = LinkPredictor(self.ee_node_name.emb_size +
                                          self.graph_model.emb_size).to(
                                              self.device)
        self.lp_var_use = LinkPredictor(self.ee_var_use.emb_size +
                                        self.graph_model.emb_size).to(
                                            self.device)
        self.lp_api_call = LinkPredictor(self.graph_model.emb_size +
                                         self.graph_model.emb_size).to(
                                             self.device)

        if restore:
            self.restore_from_checkpoint(self.model_base_path)

        self.optimizer = self._create_optimizer()

        self.lr_scheduler = ExponentialLR(self.optimizer, gamma=1.0)
        self.best_score = BestScoreTracker()

        self._create_loaders(*self._get_training_targets())

    def create_node_embedder(self,
                             dataset,
                             tokenizer_path,
                             n_dims=None,
                             pretrained_path=None,
                             n_buckets=500000):
        from SourceCodeTools.nlp.embed.fasttext import load_w2v_map

        if pretrained_path is not None:
            pretrained = load_w2v_map(pretrained_path)
        else:
            pretrained = None

        if pretrained_path is None and n_dims is None:
            raise ValueError(
                f"Specify embedding dimensionality or provide pretrained embeddings"
            )
        elif pretrained_path is not None and n_dims is not None:
            assert n_dims == pretrained.n_dims, f"Requested embedding size and pretrained embedding " \
                                                f"size should match: {n_dims} != {pretrained.n_dims}"
        elif pretrained_path is not None and n_dims is None:
            n_dims = pretrained.n_dims

        if pretrained is not None:
            logging.info(f"Loading pretrained embeddings...")
        logging.info(f"Input embedding size is {n_dims}")

        self.node_embedder = NodeEmbedder(
            nodes=dataset.nodes,
            emb_size=n_dims,
            # tokenizer_path=tokenizer_path,
            dtype=self.dtype,
            pretrained=dataset.buckets_from_pretrained_embeddings(
                pretrained_path, n_buckets)
            if pretrained_path is not None else None,
            n_buckets=n_buckets)

        # self.node_embedder(node_type="node_", node_ids=torch.LongTensor([0]))
        # self.node_embedder(node_type="node_", node_ids=torch.LongTensor([13749]))
        # self.node_embedder(node_type="node_", node_ids=torch.LongTensor([13754]))

        # node_, 0 matplotlib
        # node_ 13749        Renderer
        # node_  13754 ▁renderer

        # print()

    @property
    def lr(self):
        return self.trainer_params['lr']

    @property
    def batch_size(self):
        return self.trainer_params['batch_size']

    @property
    def sampling_neighbourhood_size(self):
        return self.trainer_params['sampling_neighbourhood_size']

    @property
    def neg_sampling_factor(self):
        return self.trainer_params['neg_sampling_factor']

    @property
    def epochs(self):
        return self.trainer_params['epochs']

    @property
    def elem_emb_size(self):
        return self.trainer_params['elem_emb_size']

    @property
    def node_name_file(self):
        return self.trainer_params['node_name_file']

    @property
    def var_use_file(self):
        return self.trainer_params['var_use_file']

    @property
    def call_seq_file(self):
        return self.trainer_params['call_seq_file']

    @property
    def model_base_path(self):
        return self.trainer_params['model_base_path']

    @property
    def pretraining(self):
        return self.epoch >= self.trainer_params['pretraining_phase']

    @property
    def do_save(self):
        return self.trainer_params['save_checkpoints']

    # def _extract_embed(self, node_embed, input_nodes):
    #     emb = {}
    #     for node_type, nid in input_nodes.items():
    #         emb[node_type] = node_embed[node_type][nid]
    #     return emb

    def write_summary(self, scores, batch_step):
        # main_name = os.path.basename(self.model_base_path)
        for var, val in scores.items():
            # self.summary_writer.add_scalar(f"{main_name}/{var}", val, batch_step)
            self.summary_writer.add_scalar(var, val, batch_step)
        # self.summary_writer.add_scalars(main_name, scores, batch_step)

    def write_hyperparams(self, scores, epoch):
        params = copy(self.model_params)
        params["epoch"] = epoch
        params = {
            k: v
            for k, v in params.items()
            if type(v) in {int, float, str, bool, torch.Tensor}
        }

        main_name = os.path.basename(self.model_base_path)
        scores = {f"h_metric/{k}": v for k, v in scores.items()}
        self.summary_writer.add_hparams(params,
                                        scores,
                                        run_name=f"h_metric/{epoch}")

    def _extract_embed(self, input_nodes):
        emb = {}
        for node_type, nid in input_nodes.items():
            emb[node_type] = self.node_embedder(
                node_type=node_type,
                node_ids=nid,
                train_embeddings=self.pretraining).to(self.device)
        return emb

    def _logits_batch(self, input_nodes, blocks):

        cumm_logits = []

        if self.use_types:
            # emb = self._extract_embed(self.graph_model.node_embed(), input_nodes)
            emb = self._extract_embed(input_nodes)
        else:
            if self.ntypes is not None:
                # single node type
                key = next(iter(self.ntypes))
                input_nodes = {key: input_nodes}
                # emb = self._extract_embed(self.graph_model.node_embed(), input_nodes)
                emb = self._extract_embed(input_nodes)
            else:
                emb = self.node_embedder(node_ids=input_nodes,
                                         train_embeddings=self.pretraining)
                # emb = self.graph_model.node_embed()[input_nodes]

        logits = self.graph_model(emb, blocks)

        if self.use_types:
            for ntype in self.graph_model.g.ntypes:

                logits_ = logits.get(ntype, None)
                if logits_ is None:
                    continue

                cumm_logits.append(logits_)
        else:
            if self.ntypes is not None:
                # single node type
                key = next(iter(self.ntypes))
                logits_ = logits[key]
            else:
                logits_ = logits

            cumm_logits.append(logits_)

        return torch.cat(cumm_logits)

    def seeds_to_global(self, seeds):
        if type(seeds) is dict:
            indices = [
                self.graph_model.g.nodes[ntype].data["global_graph_id"][
                    seeds[ntype]] for ntype in seeds
            ]
            return torch.cat(indices, dim=0)
        else:
            return seeds

    def _logits_embedder(self,
                         node_embeddings,
                         elem_embedder,
                         link_predictor,
                         seeds,
                         negative_factor=1):
        k = negative_factor
        indices = self.seeds_to_global(seeds)
        batch_size = len(indices)

        node_embeddings_batch = node_embeddings
        element_embeddings = elem_embedder(elem_embedder[indices.tolist()].to(
            self.device))

        positive_batch = torch.cat([node_embeddings_batch, element_embeddings],
                                   1)
        labels_pos = torch.ones(batch_size, dtype=torch.long)

        node_embeddings_neg_batch = node_embeddings_batch.repeat(k, 1)
        negative_random = elem_embedder(
            elem_embedder.sample_negative(batch_size * k).to(self.device))

        negative_batch = torch.cat(
            [node_embeddings_neg_batch, negative_random], 1)
        labels_neg = torch.zeros(batch_size * k, dtype=torch.long)

        batch = torch.cat([positive_batch, negative_batch], 0)
        labels = torch.cat([labels_pos, labels_neg], 0).to(self.device)

        logits = link_predictor(batch)

        return logits, labels

    def _handle_non_unique(self, non_unique_ids):
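        # Deduplicate the ids and build a slice_map such that
        # unique_ids[slice_map] reproduces the original sequence, so callers
        # can embed each unique id once and scatter the results back.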
        id_list = non_unique_ids.tolist()
        unique_ids = list(set(id_list))
        new_position = dict(zip(unique_ids, range(len(unique_ids))))
        slice_map = torch.tensor(list(map(lambda x: new_position[x], id_list)),
                                 dtype=torch.long)
        return torch.tensor(unique_ids, dtype=torch.long), slice_map

    def _logits_nodes(self,
                      node_embeddings,
                      elem_embedder,
                      link_predictor,
                      create_dataloader,
                      src_seeds,
                      negative_factor=1):
        k = negative_factor
        indices = self.seeds_to_global(src_seeds)
        batch_size = len(indices)

        node_embeddings_batch = node_embeddings
        next_call_indices = elem_embedder[
            indices.tolist()]  # this assumes indices is torch tensor

        # dst targets are not unique
        unique_dst, slice_map = self._handle_non_unique(next_call_indices)
        assert unique_dst[slice_map].tolist() == next_call_indices.tolist()

        dataloader = create_dataloader(unique_dst)
        input_nodes, dst_seeds, blocks = next(iter(dataloader))
        blocks = [blk.to(self.device) for blk in blocks]
        assert dst_seeds.shape == unique_dst.shape
        assert dst_seeds.tolist() == unique_dst.tolist()
        unique_dst_embeddings = self._logits_batch(
            input_nodes, blocks)  # use_types, ntypes)
        next_call_embeddings = unique_dst_embeddings[slice_map.to(self.device)]
        positive_batch = torch.cat(
            [node_embeddings_batch, next_call_embeddings], 1)
        labels_pos = torch.ones(batch_size, dtype=torch.long)

        node_embeddings_neg_batch = node_embeddings_batch.repeat(k, 1)
        negative_indices = torch.tensor(
            elem_embedder.sample_negative(batch_size * k), dtype=torch.long
        )  # embeddings are sampled from 3/4 unigram distribution
        unique_negative, slice_map = self._handle_non_unique(negative_indices)
        assert unique_negative[slice_map].tolist() == negative_indices.tolist()

        dataloader = create_dataloader(unique_negative)
        input_nodes, dst_seeds, blocks = next(iter(dataloader))
        blocks = [blk.to(self.device) for blk in blocks]
        assert dst_seeds.shape == unique_negative.shape
        assert dst_seeds.tolist() == unique_negative.tolist()
        unique_negative_random = self._logits_batch(
            input_nodes, blocks)  # use_types, ntypes)
        negative_random = unique_negative_random[slice_map.to(self.device)]
        negative_batch = torch.cat(
            [node_embeddings_neg_batch, negative_random], 1)
        labels_neg = torch.zeros(batch_size * k, dtype=torch.long)

        batch = torch.cat([positive_batch, negative_batch], 0)
        labels = torch.cat([labels_pos, labels_neg], 0).to(self.device)

        logits = link_predictor(batch)

        return logits, labels

    def _logits_node_name(self, input_nodes, seeds, blocks):
        src_embs = self._logits_batch(input_nodes, blocks)
        logits, labels = self._logits_embedder(
            src_embs,
            self.ee_node_name,
            self.lp_node_name,
            seeds,
            negative_factor=self.neg_sampling_factor)
        return logits, labels

    def _logits_var_use(self, input_nodes, seeds, blocks):
        src_embs = self._logits_batch(input_nodes, blocks)
        logits, labels = self._logits_embedder(
            src_embs,
            self.ee_var_use,
            self.lp_var_use,
            seeds,
            negative_factor=self.neg_sampling_factor)
        return logits, labels

    def _logits_api_call(self, input_nodes, seeds, blocks):
        src_embs = self._logits_batch(input_nodes, blocks)
        logits, labels = self._logits_nodes(
            src_embs,
            self.ee_api_call,
            self.lp_api_call,
            self._create_api_call_loader,
            seeds,
            negative_factor=self.neg_sampling_factor)
        return logits, labels

    def _get_training_targets(self):
        if hasattr(self.graph_model.g, 'ntypes'):
            self.ntypes = self.graph_model.g.ntypes
            # labels = {ntype: self.graph_model.g.nodes[ntype].data['labels'] for ntype in self.ntypes}
            self.use_types = True

            if len(self.graph_model.g.ntypes) == 1:
                # key = next(iter(labels.keys()))
                # labels = labels[key]
                self.use_types = False

            train_idx = {
                ntype: torch.nonzero(
                    self.graph_model.g.nodes[ntype].data['train_mask'],
                    as_tuple=False).squeeze()
                for ntype in self.ntypes
            }
            val_idx = {
                ntype:
                torch.nonzero(self.graph_model.g.nodes[ntype].data['val_mask'],
                              as_tuple=False).squeeze()
                for ntype in self.ntypes
            }
            test_idx = {
                ntype: torch.nonzero(
                    self.graph_model.g.nodes[ntype].data['test_mask'],
                    as_tuple=False).squeeze()
                for ntype in self.ntypes
            }
        else:
            self.ntypes = None
            # labels = g.ndata['labels']
            train_idx = self.graph_model.g.ndata['train_mask']
            val_idx = self.graph_model.g.ndata['val_mask']
            test_idx = self.graph_model.g.ndata['test_mask']
            self.use_types = False

        return train_idx, val_idx, test_idx

    def _evaluate_embedder(self, ee, lp, loader, neg_sampling_factor=1):

        total_loss = 0
        total_acc = 0
        count = 0

        for input_nodes, seeds, blocks in loader:
            blocks = [blk.to(self.device) for blk in blocks]

            src_embs = self._logits_batch(input_nodes, blocks)
            logits, labels = self._logits_embedder(src_embs, ee, lp, seeds,
                                                   neg_sampling_factor)

            logp = nn.functional.log_softmax(logits, 1)
            loss = nn.functional.cross_entropy(logp, labels)
            acc = compute_accuracy(logp.argmax(dim=1), labels)

            total_loss += loss.item()
            total_acc += acc
            count += 1
        return total_loss / count, total_acc / count

    def _evaluate_nodes(self,
                        ee,
                        lp,
                        create_api_call_loader,
                        loader,
                        neg_sampling_factor=1):

        total_loss = 0
        total_acc = 0
        count = 0

        for input_nodes, seeds, blocks in loader:
            blocks = [blk.to(self.device) for blk in blocks]

            src_embs = self._logits_batch(input_nodes, blocks)
            logits, labels = self._logits_nodes(src_embs, ee, lp,
                                                create_api_call_loader, seeds,
                                                neg_sampling_factor)

            logp = nn.functional.log_softmax(logits, 1)
            loss = nn.functional.cross_entropy(logp, labels)
            acc = compute_accuracy(logp.argmax(dim=1), labels)

            total_loss += loss.item()
            total_acc += acc
            count += 1
        return total_loss / count, total_acc / count

    def _evaluate_objectives(self, loader_node_name, loader_var_use,
                             loader_api_call, neg_sampling_factor):

        node_name_loss, node_name_acc = self._evaluate_embedder(
            self.ee_node_name,
            self.lp_node_name,
            loader_node_name,
            neg_sampling_factor=neg_sampling_factor)

        var_use_loss, var_use_acc = self._evaluate_embedder(
            self.ee_var_use,
            self.lp_var_use,
            loader_var_use,
            neg_sampling_factor=neg_sampling_factor)

        api_call_loss, api_call_acc = self._evaluate_nodes(
            self.ee_api_call,
            self.lp_api_call,
            self._create_api_call_loader,
            loader_api_call,
            neg_sampling_factor=neg_sampling_factor)

        loss = node_name_loss + var_use_loss + api_call_loss

        return loss, node_name_acc, var_use_acc, api_call_acc

    def _idx_len(self, idx):
        if isinstance(idx, dict):
            length = 0
            for key in idx:
                length += len(idx[key])
        else:
            length = len(idx)
        return length

    def _get_loaders(self, train_idx, val_idx, test_idx, batch_size):
        layers = self.graph_model.num_layers
        # train sampler
        sampler = dgl.dataloading.MultiLayerNeighborSampler(
            [self.sampling_neighbourhood_size] * layers)
        loader = dgl.dataloading.NodeDataLoader(self.graph_model.g,
                                                train_idx,
                                                sampler,
                                                batch_size=batch_size,
                                                shuffle=False,
                                                num_workers=0)

        # validation sampler
        # we do not use full neighbor to save computation resources
        val_sampler = dgl.dataloading.MultiLayerNeighborSampler(
            [self.sampling_neighbourhood_size] * layers)
        val_loader = dgl.dataloading.NodeDataLoader(self.graph_model.g,
                                                    val_idx,
                                                    val_sampler,
                                                    batch_size=batch_size,
                                                    shuffle=False,
                                                    num_workers=0)

        # we do not use full neighbor to save computation resources
        test_sampler = dgl.dataloading.MultiLayerNeighborSampler(
            [self.sampling_neighbourhood_size] * layers)
        test_loader = dgl.dataloading.NodeDataLoader(self.graph_model.g,
                                                     test_idx,
                                                     test_sampler,
                                                     batch_size=batch_size,
                                                     shuffle=False,
                                                     num_workers=0)

        return loader, val_loader, test_loader

    def _create_loaders(self, train_idx, val_idx, test_idx):

        train_idx_node_name, val_idx_node_name, test_idx_node_name = self.ee_node_name.create_idx_pools(
            train_idx=train_idx, val_idx=val_idx, test_idx=test_idx)
        train_idx_var_use, val_idx_var_use, test_idx_var_use = self.ee_var_use.create_idx_pools(
            train_idx=train_idx, val_idx=val_idx, test_idx=test_idx)
        train_idx_api_call, val_idx_api_call, test_idx_api_call = self.ee_api_call.create_idx_pools(
            train_idx=train_idx, val_idx=val_idx, test_idx=test_idx)

        logging.info(
            f"Pool sizes : train {self._idx_len(train_idx_node_name)}, "
            f"val {self._idx_len(val_idx_node_name)}, "
            f"test {self._idx_len(test_idx_node_name)}.")
        logging.info(f"Pool sizes : train {self._idx_len(train_idx_var_use)}, "
                     f"val {self._idx_len(val_idx_var_use)}, "
                     f"test {self._idx_len(test_idx_var_use)}.")
        logging.info(
            f"Pool sizes : train {self._idx_len(train_idx_api_call)}, "
            f"val {self._idx_len(val_idx_api_call)}, "
            f"test {self._idx_len(test_idx_api_call)}.")

        self.loader_node_name, self.val_loader_node_name, self.test_loader_node_name = self._get_loaders(
            train_idx=train_idx_node_name,
            val_idx=val_idx_node_name,
            test_idx=test_idx_node_name,
            batch_size=self.batch_size  # batch_size_node_name
        )
        self.loader_var_use, self.val_loader_var_use, self.test_loader_var_use = self._get_loaders(
            train_idx=train_idx_var_use,
            val_idx=val_idx_var_use,
            test_idx=test_idx_var_use,
            batch_size=self.batch_size  # batch_size_var_use
        )
        self.loader_api_call, self.val_loader_api_call, self.test_loader_api_call = self._get_loaders(
            train_idx=train_idx_api_call,
            val_idx=val_idx_api_call,
            test_idx=test_idx_api_call,
            batch_size=self.batch_size  # batch_size_api_call
        )

    def _create_api_call_loader(self, indices):
        sampler = dgl.dataloading.MultiLayerNeighborSampler(
            [self.sampling_neighbourhood_size] * self.graph_model.num_layers)
        return dgl.dataloading.NodeDataLoader(self.graph_model.g,
                                              indices,
                                              sampler,
                                              batch_size=len(indices),
                                              num_workers=0)

    def _create_optimizer(self):
        # optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        optimizer = torch.optim.Adam(
            [
                {
                    'params': self.graph_model.parameters()
                },
                {
                    'params': self.node_embedder.parameters()
                },
                {
                    'params': self.ee_node_name.parameters()
                },
                {
                    'params': self.ee_var_use.parameters()
                },
                # {'params': self.ee_api_call.parameters()},
                {
                    'params': self.lp_node_name.parameters()
                },
                {
                    'params': self.lp_var_use.parameters()
                },
                {
                    'params': self.lp_api_call.parameters()
                },
            ],
            lr=self.lr)
        return optimizer

    def train_all(self):
        """
        Training procedure for the model with node classifier
        :return:
        """

        for epoch in range(self.epoch, self.epochs):
            self.epoch = epoch

            start = time()

            for i, ((input_nodes_node_name, seeds_node_name, blocks_node_name),
                    (input_nodes_var_use, seeds_var_use, blocks_var_use),
                    (input_nodes_api_call, seeds_api_call, blocks_api_call)) in \
                    enumerate(zip(
                        self.loader_node_name,
                        self.loader_var_use,
                        self.loader_api_call)):

                blocks_node_name = [
                    blk.to(self.device) for blk in blocks_node_name
                ]
                blocks_var_use = [
                    blk.to(self.device) for blk in blocks_var_use
                ]
                blocks_api_call = [
                    blk.to(self.device) for blk in blocks_api_call
                ]

                logits_node_name, labels_node_name = self._logits_node_name(
                    input_nodes_node_name, seeds_node_name, blocks_node_name)

                logits_var_use, labels_var_use = self._logits_var_use(
                    input_nodes_var_use, seeds_var_use, blocks_var_use)

                logits_api_call, labels_api_call = self._logits_api_call(
                    input_nodes_api_call, seeds_api_call, blocks_api_call)

                train_acc_node_name = compute_accuracy(
                    logits_node_name.argmax(dim=1), labels_node_name)
                train_acc_var_use = compute_accuracy(
                    logits_var_use.argmax(dim=1), labels_var_use)
                train_acc_api_call = compute_accuracy(
                    logits_api_call.argmax(dim=1), labels_api_call)

                train_logits = torch.cat(
                    [logits_node_name, logits_var_use, logits_api_call], 0)
                train_labels = torch.cat(
                    [labels_node_name, labels_var_use, labels_api_call], 0)

                logp = nn.functional.log_softmax(train_logits, 1)
                loss = nn.functional.nll_loss(logp, train_labels)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                self.write_summary(
                    {
                        "Loss": loss,
                        "Accuracy/train/node_name_vs_batch":
                        train_acc_node_name,
                        "Accuracy/train/var_use_vs_batch": train_acc_var_use,
                        "Accuracy/train/api_call_vs_batch": train_acc_api_call
                    }, self.batch)
                self.batch += 1

            self.eval()

            with torch.set_grad_enabled(False):

                _, val_acc_node_name, val_acc_var_use, val_acc_api_call = self._evaluate_objectives(
                    self.val_loader_node_name, self.val_loader_var_use,
                    self.val_loader_api_call, self.neg_sampling_factor)

                _, test_acc_node_name, test_acc_var_use, test_acc_api_call = self._evaluate_objectives(
                    self.test_loader_node_name, self.test_loader_var_use,
                    self.test_loader_api_call, self.neg_sampling_factor)

            self.train()

            end = time()

            self.best_score.track_best(epoch=epoch,
                                       loss=loss.item(),
                                       train_acc_node_name=train_acc_node_name,
                                       val_acc_node_name=val_acc_node_name,
                                       test_acc_node_name=test_acc_node_name,
                                       train_acc_var_use=train_acc_var_use,
                                       val_acc_var_use=val_acc_var_use,
                                       test_acc_var_use=test_acc_var_use,
                                       train_acc_api_call=train_acc_api_call,
                                       val_acc_api_call=val_acc_api_call,
                                       test_acc_api_call=test_acc_api_call,
                                       time=end - start)

            if self.do_save:
                self.save_checkpoint(self.model_base_path)

            self.write_summary(
                {
                    "Accuracy/test/node_name_vs_batch": test_acc_node_name,
                    "Accuracy/test/var_use_vs_batch": test_acc_var_use,
                    "Accuracy/test/api_call_vs_batch": test_acc_api_call,
                    "Accuracy/val/node_name_vs_batch": val_acc_node_name,
                    "Accuracy/val/var_use_vs_batch": val_acc_var_use,
                    "Accuracy/val/api_call_vs_batch": val_acc_api_call
                }, self.batch)

            self.write_hyperparams(
                {
                    "Loss/train_vs_epoch": loss,
                    "Accuracy/train/node_name_vs_epoch": train_acc_node_name,
                    "Accuracy/train/var_use_vs_epoch": train_acc_var_use,
                    "Accuracy/train/api_call_vs_epoch": train_acc_api_call,
                    "Accuracy/test/node_name_vs_epoch": test_acc_node_name,
                    "Accuracy/test/var_use_vs_epoch": test_acc_var_use,
                    "Accuracy/test/api_call_vs_epoch": test_acc_api_call,
                    "Accuracy/val/node_name_vs_epoch": val_acc_node_name,
                    "Accuracy/val/var_use_vs_epoch": val_acc_var_use,
                    "Accuracy/val/api_call_vs_epoch": val_acc_api_call
                }, self.epoch)

            self.lr_scheduler.step()

    def save_checkpoint(self,
                        checkpoint_path=None,
                        checkpoint_name=None,
                        **kwargs):

        checkpoint_path = join(checkpoint_path, "saved_state.pt")

        param_dict = {
            'graph_model': self.graph_model.state_dict(),
            'node_embedder': self.node_embedder.state_dict(),
            'ee_node_name': self.ee_node_name.state_dict(),
            'ee_var_use': self.ee_var_use.state_dict(),
            # 'ee_api_call': self.ee_api_call.state_dict(),
            "lp_node_name": self.lp_node_name.state_dict(),
            "lp_var_use": self.lp_var_use.state_dict(),
            "lp_api_call": self.lp_api_call.state_dict(),
            "epoch": self.epoch,
            "batch": self.batch
        }

        if len(kwargs) > 0:
            param_dict.update(kwargs)

        torch.save(param_dict, checkpoint_path)

    def restore_from_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(join(checkpoint_path, "saved_state.pt"))
        self.graph_model.load_state_dict(checkpoint['graph_model'])
        self.ee_node_name.load_state_dict(checkpoint['ee_node_name'])
        self.ee_var_use.load_state_dict(checkpoint['ee_var_use'])
        # self.ee_api_call.load_state_dict(checkpoint['ee_api_call'])
        self.lp_node_name.load_state_dict(checkpoint['lp_node_name'])
        self.lp_var_use.load_state_dict(checkpoint['lp_var_use'])
        self.lp_api_call.load_state_dict(checkpoint['lp_api_call'])
        self.epoch = checkpoint['epoch']
        self.batch = checkpoint['batch']
        logging.info(f"Restored from epoch {checkpoint['epoch']}")

    def final_evaluation(self):

        with torch.set_grad_enabled(False):

            loss, train_acc_node_name, train_acc_var_use, train_acc_api_call = self._evaluate_objectives(
                self.loader_node_name, self.loader_var_use,
                self.loader_api_call, 1)

            _, val_acc_node_name, val_acc_var_use, val_acc_api_call = self._evaluate_objectives(
                self.val_loader_node_name, self.val_loader_var_use,
                self.val_loader_api_call, 1)

            _, test_acc_node_name, test_acc_var_use, test_acc_api_call = self._evaluate_objectives(
                self.test_loader_node_name, self.test_loader_var_use,
                self.test_loader_api_call, 1)

        scores = {
            # "loss": loss.item(),
            "train_acc_node_name": train_acc_node_name,
            "val_acc_node_name": val_acc_node_name,
            "test_acc_node_name": test_acc_node_name,
            "train_acc_var_use": train_acc_var_use,
            "val_acc_var_use": val_acc_var_use,
            "test_acc_var_use": test_acc_var_use,
            "train_acc_api_call": train_acc_api_call,
            "val_acc_api_call": val_acc_api_call,
            "test_acc_api_call": test_acc_api_call,
        }

        print(
            f'Final Eval : node name Train Acc {scores["train_acc_node_name"]:.4f}, '
            f'node name Val Acc {scores["val_acc_node_name"]:.4f}, '
            f'node name Test Acc {scores["test_acc_node_name"]:.4f}, '
            f'var use Train Acc {scores["train_acc_var_use"]:.4f}, '
            f'var use Val Acc {scores["val_acc_var_use"]:.4f}, '
            f'var use Test Acc {scores["test_acc_var_use"]:.4f}, '
            f'api call Train Acc {scores["train_acc_api_call"]:.4f}, '
            f'api call Val Acc {scores["val_acc_api_call"]:.4f}, '
            f'api call Test Acc {scores["test_acc_api_call"]:.4f}')

        return scores

    def eval(self):
        self.graph_model.eval()
        self.ee_node_name.eval()
        self.ee_var_use.eval()
        # self.ee_api_call.eval()
        self.lp_node_name.eval()
        self.lp_var_use.eval()
        self.lp_api_call.eval()

    def train(self):
        self.graph_model.train()
        self.ee_node_name.train()
        self.ee_var_use.train()
        # self.ee_api_call.train()
        self.lp_node_name.train()
        self.lp_var_use.train()
        self.lp_api_call.train()

    def to(self, device):
        self.graph_model.to(device)
        self.ee_node_name.to(device)
        self.ee_var_use.to(device)
        # self.ee_api_call.to(device)
        self.lp_node_name.to(device)
        self.lp_var_use.to(device)
        self.lp_api_call.to(device)

    def get_embeddings(self):
        # self.graph_model.g.nodes["function"].data.keys()
        nodes = self.graph_model.g.nodes
        node_embs = {
            ntype: self.node_embedder(node_type=ntype,
                                      node_ids=nodes[ntype].data['typed_id'],
                                      train_embeddings=False)
            for ntype in self.graph_model.g.ntypes
        }

        h = self.graph_model.inference(batch_size=256,
                                       device='cpu',
                                       num_workers=0,
                                       x=node_embs)

        original_id = []
        global_id = []
        embeddings = []
        for ntype in self.graph_model.g.ntypes:
            embeddings.append(h[ntype])
            original_id.extend(nodes[ntype].data['original_id'].tolist())
            global_id.extend(nodes[ntype].data['global_graph_id'].tolist())

        embeddings = torch.cat(embeddings, dim=0).detach().numpy()

        return [Embedder(dict(zip(original_id, global_id)), embeddings)]

class TaskTrainerLogger():
    def __init__(self,
                 trainer_env: TrainerEnv,
                 model: PreTrainedModel,
                 tb_writer: Optional["SummaryWriter"] = None):
        self._env = trainer_env
        self.config: TrainerLoggerConfig = trainer_env.get_config(
            TrainerLoggerConfig)
        self._model = model
        self._tb_writer = tb_writer
        self.__init_logger()

    def __init_logger(self):
        if not self._tb_writer and is_tensorboard_available(
        ) and self._env.config.local_rank in [-1, 0]:
            self._tb_writer = SummaryWriter(log_dir=self.config.logging_dir)

        if not is_tensorboard_available():
            log.warning(
                "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
            )
        if is_wandb_available():
            self._setup_wandb()
        else:
            log.info(
                "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
                "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
            )

    def _setup_wandb(self):
        """
        Setup the optional Weights & Biases (`wandb`) integration.

        One can override this method to customize the setup if needed.  Find more information at https://docs.wandb.com/huggingface
        You can also override the following environment variables:

        Environment:
            WANDB_WATCH:
                (Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging
                or "all" to log gradients and parameters
            WANDB_PROJECT:
                (Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
            WANDB_DISABLED:
                (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
        """
        log.info(
            'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
        )
        wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"),
                   config=vars(self._env.args))
        # keep track of model topology and gradients
        if os.getenv("WANDB_WATCH") != "false":
            wandb.watch(self._model,
                        log=os.getenv("WANDB_WATCH", "gradients"),
                        log_freq=max(100, self.config.logging_steps))
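
    # Usage sketch for the environment variables documented above; the project
    # name "my-project" is illustrative, not taken from this code:
    #
    #   import os
    #   os.environ["WANDB_PROJECT"] = "my-project"
    #   os.environ["WANDB_WATCH"] = "all"      # log gradients and parameters
    #   os.environ["WANDB_DISABLED"] = "false"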

    def log_pre_train(self, train_dataloader,
                      train_scheduler: TaskTrainedScheduler):
        if self._tb_writer is not None:
            self._tb_writer.add_text("args", self._env.args.to_json_string())
            self._tb_writer.add_hparams(self._env.args.to_sanitized_dict(),
                                        metric_dict={})

        log.info("***** Running training *****")
        log.info("  Num examples = %d",
                 self._env.num_examples(train_dataloader))
        log.info("  Num Epochs = %d", train_scheduler.num_train_epochs)
        log.info("  Instantaneous batch size per device = %d",
                 self._env.config.per_gpu_train_batch_size)
        log.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            train_scheduler.total_train_batch_size)
        log.info("  Gradient Accumulation steps = %d",
                 train_scheduler.config.gradient_accumulation_steps)
        log.info("  Total optimization steps = %d", train_scheduler.t_total)

    def is_need_log_step(self, global_step):
        return (self.config.logging_steps > 0
                and global_step % self.config.logging_steps
                == 0) or (global_step == 1 and self.config.logging_first_step)

    def log_train_step(self,
                       epoch,
                       global_step,
                       logs: Dict[str, float],
                       iterator: Optional[tqdm] = None) -> None:
        if epoch is not None: logs["epoch"] = epoch
        if self._tb_writer:
            for k, v in logs.items():
                self._tb_writer.add_scalar(k, v, global_step)
        if is_wandb_available(): wandb.log(logs, step=global_step)
        output = json.dumps({**logs, **{"step": global_step}})

        (iterator.write(output) if iterator else print(output))

    def log_train_epoch(self):
        if self.config.tpu_metrics_debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            self._env.tpu_metrics()

    def log_train_end(self):
        if self._tb_writer: self._tb_writer.close()

        log.info(
            "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
        )
Example #11
0
    print(model)
    print('\n\nNumber of parameters: {}\n'.format(
        sum(p.numel() for p in model.parameters())))

if args.cuda:
    device = get_freer_gpu()
    model = model.to(device)

if args.logdir:
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(log_dir=args.logdir,
                           comment=args.model,
                           purge_step=0 if args.checkpoint_epoch is None else
                           int(args.checkpoint_epoch * len(train_loader)))
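    # purge_step hides any previously logged events at or after that step,
    # keeping the resumed run consistent with the checkpoint epoch.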
    args_dict = parse_args_for_log(args)
    writer.add_hparams(hparam_dict=args_dict, metric_dict={'best_acc': 0.0})
else:
    writer = None
    args_dict = None

optimizer = optim.Adam(model.parameters(),
                       lr=args.lr,
                       betas=(args.beta1, args.beta2),
                       weight_decay=args.l2)

trainer = TrainLoop(model,
                    optimizer,
                    train_loader,
                    valid_loader,
                    max_gnorm=args.max_gnorm,
                    label_smoothing=args.smoothing,
Example #12
0
class Tensorboard(Callback):
    writer: SummaryWriter

    def __init__(self, root=None, trainer=None, writer=None):
        self.root = root
        self.trainer = trainer
        self.writer = writer

        self.logged_model_graph = False

    def _init_writer(self):
        self.writer = SummaryWriter(self.root)

    def log(self, val, tag=None, step=None, **kwargs):
        if step is None and self.trainer:
            step = self.trainer.num_steps

        if isinstance(val, nn.Module):
            self.writer.add_graph(val,
                                  kwargs.get('input'),
                                  verbose=kwargs.get('verbose', False))

    def show(self, root=None):
        from IPython import get_ipython
        ipython = get_ipython()
        ipython.magic('load_ext tensorboard')

        if isinstance(self, (str, Path)):
            # static method (Tensorboard.show())
            ipython.magic(f'tensorboard --logdir {self}')
        else:
            ipython.magic(f'tensorboard --logdir {root or self.root or "./"}')

    def close(self):
        if self.writer:
            self.writer.close()

    def flush(self):
        if self.writer:
            self.writer.flush()

    def on_train_start(self, trainer=None, **kwargs):
        if self.root is None:
            self.root = trainer.paths.tensorboard
        self.trainer = trainer

        if self.writer is None:
            self._init_writer()

    def on_train_end(self, trainer=None):
        if self.trainer and self.trainer.params and self.trainer.history.val_metrics:
            try:
                self.writer.add_hparams(
                    dict(self.trainer.params),
                    self.trainer.history.val_metrics.summary())
            except ValueError as e:
                log.error(e)

    def on_step_start(self, inputs=None, **kwargs):
        if not self.logged_model_graph:
            self.writer.add_graph(self.sanitize_model(self.trainer.model),
                                  inputs)
            self.logged_model_graph = True

    def on_step_end(self,
                    index=None,
                    inputs=None,
                    targets=None,
                    outputs=None,
                    loss=None,
                    trainer=None):
        self.writer.add_scalar('train/loss',
                               loss,
                               global_step=trainer.num_steps)
        for metric, values in trainer.history.metrics.items():
            self.writer.add_scalar(f'train/{metric}',
                                   values[-1],
                                   global_step=len(values) - 1)

    def on_validation_end(self,
                          targets=None,
                          outputs=None,
                          loss=None,
                          trainer=None):
        for metric, values in trainer.history.val_metrics.items():
            self.writer.add_scalar(f'validation/{metric}',
                                   values[-1],
                                   global_step=len(values) - 1)

    def sanitize_model(self, model):
        if isinstance(model, nn.DataParallel):
            return model.module
        return model
Example #13
0
class Dense_U_Net_lidar_Agent:
    def __init__(self, config=None, torchvision_init=True, lidar=False):
        '''
        Handles everything
        - training, validation, testing
        - checkpoint loading and saving
        - logging | tensorboard summaries

        Accordingly everything is specified here
        - model
        - loss
        - optimizer
        - lr scheduling

        Arguments:
            torchvision_init: boolean
                - True:     load densenet state dict from torchvision
                - False:    load checkpoint; if no checkpoint just normal init
        '''

        self.logger = logging.getLogger('Agent')

        # model and config if lazy
        self.model = maskrcnn_resnet50_fpn(pretrained=True,
                                           progress=True,
                                           num_classes=91,  # must stay 91 when pretrained=True
                                           pretrained_backbone=True,
                                           trainable_backbone_layers=3)  # 0 trains none, 5 trains all

        '''
        # get number of input features for the classifier
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        # replace the pre-trained head with a new one
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
        
        # now get the number of input features for the mask classifier
        in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
        hidden_layer = 256
        # and replace the mask predictor with a new one
        model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                            hidden_layer,
                                                            num_classes)
        '''
        self.lidar = lidar
        if self.lidar:
            # add one channel to first layer
            self.model.backbone.body.conv1 = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3),
                                                       bias=False)
        # replace final layer to 4 classes: background, vehicle, pedestrian, cyclist
        self.model.roi_heads.mask_predictor.mask_fcn_logits = nn.Conv2d(256, 4, kernel_size=(1, 1),
                                                                        stride=(1, 1))

        # in case config is empty it is created in model
        if config is None:
            self.config = utils.get_config()
        else:
            self.config = config

        # dataloader
        self.data_loader = WaymoDataset_Loader(self.config)

        # pixel-wise cross-entropy loss
        self.loss = torch.nn.BCEWithLogitsLoss(reduction='none').cuda()

        # optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.config.optimizer.learning_rate,
                                          betas=(self.config.optimizer.beta1, self.config.optimizer.beta2),
                                          eps=self.config.optimizer.eps,
                                          weight_decay=self.config.optimizer.weight_decay,
                                          amsgrad=self.config.optimizer.amsgrad)

        # learning rate decay scheduler
        if self.config.optimizer.lr_scheduler.want:
            self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer,
                                                                step_size=self.config.optimizer.lr_scheduler.every_n_epochs,
                                                                gamma=self.config.optimizer.lr_scheduler.gamma)

        # initialize counters; updated in load_checkpoint
        self.current_epoch = 0
        self.current_train_iteration = 0
        self.current_val_iteration = 0
        self.best_val_iou = 0

        # if cuda is available export model to gpu
        self.cuda = torch.cuda.is_available()
        if self.cuda:
            self.device = torch.device('cuda')
            torch.cuda.manual_seed_all(self.config.agent.seed)
            self.logger.info('Operation will be on *****GPU-CUDA***** ')
        else:
            self.device = torch.device('cpu')
            torch.manual_seed(self.config.agent.seed)
            self.logger.info('Operation will be on *****CPU***** ')
        self.model = self.model.to(self.device)
        self.loss = self.loss.to(self.device)

        if not torchvision_init:
            self.load_checkpoint()

        # Tensorboard Writers
        Path(self.config.dir.current_run.summary).mkdir(exist_ok=True, parents=True)
        self.train_summary_writer = SummaryWriter(log_dir=self.config.dir.current_run.summary,
                                                  comment='FasterRCNNResNet50')
        self.val_summary_writer = SummaryWriter(log_dir=self.config.dir.current_run.summary,
                                                comment='FasterRCNNResNet50')

    def save_checkpoint(self, filename='checkpoint.pth.tar', is_best=False):
        '''
        Saving the latest checkpoint of the training

        Arguments:
            filename: filename which will contain the state
            is_best: flag indicating whether this is the best model
        '''

        # aggregate important data
        state = {
            self.config.agent.checkpoint.epoch: self.current_epoch,
            self.config.agent.checkpoint.train_iteration: self.current_train_iteration,
            self.config.agent.checkpoint.val_iteration: self.current_val_iteration,
            self.config.agent.checkpoint.best_val_iou: self.best_val_iou,
            self.config.agent.checkpoint.state_dict: self.model.state_dict(),
            self.config.agent.checkpoint.optimizer: self.optimizer.state_dict()
        }

        if is_best:
            filename = self.config.agent.best_checkpoint_name

        # create dir if not exists
        Path(self.config.dir.current_run.checkpoints).mkdir(exist_ok=True, parents=True)

        # Save the state
        torch.save(state, os.path.join(self.config.dir.current_run.checkpoints, filename))

    def load_checkpoint(self, filename=None):
        '''
        load checkpoint from file
        should contain following keys:
            'epoch', 'iteration', 'best_val_iou', 'state_dict', 'optimizer'
            where state_dict is the model state dict
            and optimizer is the optimizer state dict

        Arguments:
            filename: only name with file type extension | path in config.dir.current_run.checkpoints
        '''

        # use best if not specified
        if filename is None:
            filename = self.config.agent.best_checkpoint_name

        # load according to key
        filepath = os.path.join(self.config.dir.current_run.checkpoints, filename)
        try:
            self.logger.info('Loading checkpoint {}'.format(filename))
            checkpoint = torch.load(filepath)

            self.current_epoch = checkpoint[self.config.agent.checkpoint.epoch]
            self.current_train_iteration = checkpoint[
                self.config.agent.checkpoint.train_iteration]
            self.current_val_iteration = checkpoint[
                self.config.agent.checkpoint.val_iteration]
            self.best_val_iou = checkpoint[
                self.config.agent.checkpoint.best_val_iou]
            self.model.load_state_dict(checkpoint[
                                           self.config.agent.checkpoint.state_dict])
            self.optimizer.load_state_dict(checkpoint[
                                               self.config.agent.checkpoint.optimizer])

            self.logger.info('Checkpoint loaded successfully from {} at (epoch {}) at (iteration {})\n'
                             .format(self.config.dir.current_run.checkpoints, self.current_epoch,
                                     self.current_train_iteration))
        except OSError:
            warnings.warn('No checkpoint exists from {}. Skipping...'.format(filepath))
            self.logger.info('No checkpoint exists from {}. Skipping...'.format(filepath))
            self.logger.info('**First time to train**')

    def run(self):
        '''
        starts training or testing: specify under config.loader.mode
        can handle keyboard interrupt
        '''

        print('starting ' + self.config.loader.mode + ' at ' + str(datetime.now()))
        try:
            if self.config.loader.mode == 'test':
                with torch.no_grad():
                    self.validate()
            else:
                self.train()

        except KeyboardInterrupt:
            self.logger.info('You have entered CTRL+C.. Wait to finalize')

    def train(self):
        '''
        training one epoch at a time
        validating after each epoch
        saving checkpoint after each epoch
        check if val acc is best and store separately
        '''

        # add selected loss and optimizer to config  | not added in init as may be changed before training
        self.config.loss.func = str(self.loss)
        self.config.optimizer.func = str(self.optimizer)

        # make sure to remember the hyper params
        self.add_hparams_summary_writer()
        self.save_hparams_json()

        # Iterate epochs | train one epoch | validate | save checkpoint
        for epoch in range(self.current_epoch, self.config.agent.max_epoch):
            self.current_epoch = epoch
            self.train_one_epoch()

            with torch.no_grad():
                avg_val_iou_per_class = self.validate()

            val_iou = sum(avg_val_iou_per_class) / len(avg_val_iou_per_class)
            is_best = val_iou > self.best_val_iou
            if is_best:
                self.best_val_iou = val_iou
            self.save_checkpoint(is_best=is_best)

        self.train_summary_writer.close()
        self.val_summary_writer.close()

    def train_one_epoch(self):
        '''
        One epoch training function
        '''

        # Initialize progress visualization and get batch
        tqdm_batch = tqdm(self.data_loader.train_loader, total=self.data_loader.train_iterations,
                          desc='Epoch-{}-'.format(self.current_epoch))

        # Set the model to be in training mode
        self.model.train()

        # metric counters
        current_batch = 0
        number_of_batches = self.data_loader.train_iterations  # batches, not samples
        epoch_loss = torch.zeros(number_of_batches).to(self.device)

        for image, lidar, _, targets in tqdm_batch:

            # push to gpu if possible
            if self.cuda:
                image = image.cuda(non_blocking=self.config.loader.async_loading)
                lidar = lidar.cuda(non_blocking=self.config.loader.async_loading)

            # forward pass
            '''
            During training, the model expects both the input tensors, as well as a targets (list of dictionary),
            containing:
            - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format,  with values of ``x``
              between ``0`` and ``W`` and values of ``y`` between ``0`` and ``H``
            - labels (``Int64Tensor[N]``): the class label for each ground-truth box
            - masks (``UInt8Tensor[N, H, W]``): the segmentation binary masks for each instance
    
            The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
            losses for both the RPN and the R-CNN, and the mask loss.
            '''
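            # Hedged illustration of one entry in `targets` (the values are
            # placeholders, not taken from this dataset):
            #   {'boxes':  torch.tensor([[x1, y1, x2, y2]], dtype=torch.float32),
            #    'labels': torch.tensor([1], dtype=torch.int64),
            #    'masks':  torch.zeros((1, H, W), dtype=torch.uint8)}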

            model_input = torch.cat((image, lidar), dim=1) if self.lidar else image
            loss_dict = self.model(model_input, targets)

            losses = sum(loss for loss in loss_dict.values())

            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()

            epoch_loss[current_batch] = losses.item()

            self.train_summary_writer.add_scalars(
                'Training/Loss', {k: v.detach() for k, v in loss_dict.items()},
                self.current_train_iteration)

            # counters
            self.current_train_iteration += 1
            current_batch += 1

        tqdm_batch.close()

        # learning rate decay update; after validate; after each epoch
        if self.config.optimizer.lr_scheduler.want:
            self.lr_scheduler.step()

        # log
        avg_epoch_loss = torch.mean(epoch_loss, axis=0).tolist()
        self.logger.info('Training at Epoch-' + str(self.current_epoch) + ' | ' + 'Average Loss: ' + str(
            avg_epoch_loss))

    def validate(self):
        '''
        One epoch validation

        return:
            average IoU per class
        '''

        # Initialize progress visualization and get batch
        # !self.data_loader.valid_loader works for both valid and test
        tqdm_batch = tqdm(self.data_loader.valid_loader, total=self.data_loader.valid_iterations,
                          desc='Validation at -{}-'.format(self.current_epoch))

        # set the model in evaluation mode
        self.model.eval()

        # metric counters
        current_batch = 0
        number_of_batches = self.data_loader.valid_iterations  # batches, not samples
        epoch_loss = torch.zeros((number_of_batches, self.config.model.num_classes)).to(self.device)
        epoch_iou = torch.zeros((number_of_batches, self.config.model.num_classes))
        epoch_iou_nans = torch.zeros((number_of_batches, self.config.model.num_classes))
        epoch_acc = torch.zeros((number_of_batches, self.config.model.num_classes)).to(self.device)

        for image, lidar, ht_map, _ in tqdm_batch:

            # push to gpu if possible
            if self.cuda:
                image = image.cuda(non_blocking=self.config.loader.async_loading)
                lidar = lidar.cuda(non_blocking=self.config.loader.async_loading)
                ht_map = ht_map.cuda(non_blocking=self.config.loader.async_loading)

            # forward pass
            '''
            During inference, the model requires only the input tensors, and returns the post-processed
            predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
            follows:
            - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format,  with values of ``x``
              between ``0`` and ``W`` and values of ``y`` between ``0`` and ``H``
            - labels (``Int64Tensor[N]``): the predicted labels for each image
            - scores (``Tensor[N]``): the scores or each prediction
            - masks (``UInt8Tensor[N, 1, H, W]``): the predicted masks for each instance, in ``0-1`` range. In order to
              obtain the final segmentation masks, the soft masks can be thresholded, generally
              with a value of 0.5 (``mask >= 0.5``)
            '''
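            # For reference, binary instance masks could be derived from the
            # soft masks with something like (illustrative, not used below):
            #   (pred['masks'] >= 0.5).squeeze(1)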
            model_input = torch.cat((image, lidar), dim=1) if self.lidar else image
            prediction_list = self.model(model_input)

            # TODO alt version thresholding masks and then same
            # TODO expand masks into ht maps as before, rest should be the same
            # in 2nd dim change values into dimensions
            # -> join masks of same type -> torch.max(masks[indeces_predicted_class], dim=...)
            prediction = torch.zeros_like(ht_map)
            for sample_i, sample_prediction in enumerate(prediction_list):
                for obj_class in [0, 1, 2]:
                    class_idx = sample_prediction['labels'] == obj_class
                    # torch.max over a dim returns (values, indices); keep the values
                    prediction[sample_i, obj_class] = torch.max(
                        sample_prediction['masks'][class_idx], dim=0).values.squeeze(0)

            # pixel-wise loss
            current_loss = self.loss(prediction, ht_map)
            loss_per_class = torch.sum(current_loss.detach(), dim=(0, 2, 3))
            epoch_loss[current_batch, :] = loss_per_class

            # whole image IoU per class; not taking nans into acc for the mean value; counting the nans separately
            iou_per_instance_per_class = utils.compute_IoU_whole_img_batch(prediction.detach(), ht_map.detach(),
                                                                           self.config.agent.iou_threshold)
            iou_per_class = torch.tensor(np.nanmean(iou_per_instance_per_class, axis=0))
            iou_per_class[torch.isnan(iou_per_class)] = 0
            epoch_iou[current_batch, :] = iou_per_class
            epoch_iou_nans[current_batch, :] = torch.sum(torch.isnan(iou_per_instance_per_class), axis=0)

            # compute class-wise accuracy of current batch
            acc_per_class = utils.compute_accuracy(ht_map.detach(), prediction.detach(),
                                                   self.config.agent.iou_threshold)
            epoch_acc[current_batch, :] = acc_per_class

            # logging for visualization during training: separate plots for loss, acc, iou | each-classwise + overall
            loss_dict = {
                'Vehicle': loss_per_class[0],
                'Pedestrian': loss_per_class[1],
                'Cyclist': loss_per_class[2],
                'Overall': torch.mean(loss_per_class)
            }
            self.val_summary_writer.add_scalars('Validation/Loss', loss_dict, self.current_val_iteration)
            acc_dict = {
                'Vehicle': acc_per_class[0],
                'Pedestrian': acc_per_class[1],
                'Cyclist': acc_per_class[2],
                'Overall': torch.mean(acc_per_class)
            }
            self.val_summary_writer.add_scalars('Validation/Accuracy', acc_dict, self.current_val_iteration)
            iou_dict = {
                'Vehicle': iou_per_class[0],
                'Pedestrian': iou_per_class[1],
                'Cyclist': iou_per_class[2],
                'Overall': torch.mean(iou_per_class)
            }
            self.val_summary_writer.add_scalars('Validation/IoU', iou_dict, self.current_val_iteration)

            # counters
            self.current_val_iteration += 1
            current_batch += 1

        # log
        avg_epoch_loss = torch.mean(epoch_loss, axis=0).tolist()
        avg_epoch_iou = torch.mean(epoch_iou, axis=0).tolist()
        cum_epoch_nans = torch.sum(epoch_iou_nans, axis=0).tolist()
        avg_epoch_acc = torch.mean(epoch_acc, axis=0).tolist()
        self.logger.info('Validation at Epoch-' + str(self.current_epoch) + ' | ' + 'Average Loss: ' + str(
            avg_epoch_loss) + ' | ' + 'Average IoU: ' + str(avg_epoch_iou) + ' | ' + 'Number of NaNs: ' + str(
            cum_epoch_nans) + ' | ' + 'Average Accuracy: ' + str(avg_epoch_acc))

        tqdm_batch.close()

        return avg_epoch_iou

    def add_hparams_summary_writer(self):
        '''
        Add Hyperparameters to tensorboard summary writers using .add_hparams
        Can be accessed under the Hyperparameter tab in Tensorboard
        '''

        hyper_params = {
            'loss_func': self.config.loss.func,
            'loss_alpha': self.config.loss.alpha,
            'loss_gamma': self.config.loss.gamma,
            'loss_skip_v_every_n_its': self.config.loss.skip_v_every_n_its,
            'loss_skip_p_every_n_its': self.config.loss.skip_p_every_n_its,
            'loss_skip_b_every_n_its': self.config.loss.skip_b_every_n_its,
            'optimizer': self.config.optimizer.func,
            'learning_rate': self.config.optimizer.learning_rate,
            'beta1': self.config.optimizer.beta1,
            'beta2': self.config.optimizer.beta2,
            'eps': self.config.optimizer.eps,
            'amsgrad': self.config.optimizer.amsgrad,
            'weight_decay': self.config.optimizer.weight_decay,
            'lr_scheduler': self.config.optimizer.lr_scheduler.want,
            'lr_scheduler_every_n_epochs': self.config.optimizer.lr_scheduler.every_n_epochs,
            'lr_scheduler_gamma': self.config.optimizer.lr_scheduler.gamma,
        }

        self.train_summary_writer.add_hparams(hyper_params, {})
        self.val_summary_writer.add_hparams(hyper_params, {})
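        # Note: every add_hparams call writes a timestamped sub-run under
        # log_dir, so the train and val writers above each appear as their own
        # entry in TensorBoard's HPARAMS tab.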

    def save_hparams_json(self):
        '''
        Uses config information to generate a hyperparameter dict and saves it as a json file
        into the current_run directory
        '''

        hparams = {
            'loss': self.config.loss.__dict__,
            'optimizer': self.config.optimizer.__dict__
        }

        utils.save_json_file(os.path.join(self.config.dir.current_run.summary, 'hyperparams.json'),
                             hparams, indent=4)

    def finalize(self):
        '''
        Close all Writers and print time
        '''

        self.logger.info('Please wait while finalizing the operation.. Thank you')
        self.train_summary_writer.close()
        self.val_summary_writer.close()
        print('ending ' + self.config.loader.mode + ' at ' + str(datetime.now()))
Example #14
0
    if Config.token_weight_file is not None:
        with open(Config.data_path + Config.token_weight_file, 'rb') as f:
            token_weights = pickle.load(f)
        token_weights = torch.Tensor(token_weights).to(Config.device)
    else:
        token_weights = None
    criterion = NLLLoss(ignore_index=Config.PAD_ID, weight=token_weights)

    progress_bar = ProgressBar(len(train_loader), ema=0)

    # Tensorboard
    summary_writer = SummaryWriter(comment=Config.model_name)
    summary_writer.add_hparams(
        {
            k: str(v)
            for k, v in Config.__dict__.items() if not k.startswith('__')
        }, {})
    summary_writer.add_hparams({"epoch_steps": epoch_steps}, {})

    if Config.load_state is not None:
        load_model(model.decoder, optimizer, Config.load_state)

    running_acc = None
    adaptive_summary_len = 3
    increase_sum_len = 0

    for epoch in range(0 if Config.load_state is None else Config.load_state,
                       Config.num_epochs):
        model.train()
        progress_bar.start()
Example #15
0
def test_clustertask(operateconfig: Dict, dataconfig: Dict, trainingconfig: Dict, modelconfig: Dict):
    # set registered hyper parameters
    logger.info("Register Hyper Parameter")
    hparams = generate_register_hparams(modelconfig,trainingconfig,dataconfig)

    dir_path = dataconfig['data_dir_path']
    if not dir_path:
        raise KeyError("dataconfig['data_dir_path'] is missing")
    comment = '_' + dir_path.name + '_' + modelconfig['name'] + '_' + modelconfig['version']
    metric_dict = {}
    w = SummaryWriter(comment=comment) if args.p else None
    logger.info("Load Embedding Vector")
    datasetdir = DataSetDir(dir_path,word_emb_select=dataconfig['word_emb_select'])
    # combine model
    embedding_layer = Embedding_layer.from_pretrained(datasetdir.embedding_vec)
    
    embedding_layer.freeze_parameters()
    attention_layer = Attention_layer(embedding_layer.dim, modelconfig['attention_hidden_size'])
    modelconfig['attention'] = attention_layer
    modelconfig['embedding'] = embedding_layer
    model = BinarySynClassifierBaseOnAttention(
                config = modelconfig
            )
    optimizer = optim.Adam(filter(lambda x: x.requires_grad, model.parameters()),
                           lr=trainingconfig['lr'], amsgrad=True)
    trainingconfig['optim'] = optimizer
    trainingconfig['loss_fn'] = torch.nn.BCELoss()
    wrapper = ModelWrapper(model,trainingconfig)
    
    if operateconfig['resume']:
        wrapper.load_check_point()
        # continue training

    if operateconfig['train']:
        logger.info("Generate DataLoader")
        train_datasetitem = DataItemSet(
                    dataset=datasetdir.train_dataset,
                    sampler = select_sampler(dataconfig['sample_strategy']),
                    negative_sample_size = dataconfig['negative_sample_size']
                ) 
        dev_datasetitem = DataItemSet(
                    dataset=datasetdir.test_dataset,
                    sampler = select_sampler(dataconfig['test_sample_strategy']),
                    negative_sample_size = dataconfig['test_negative_sample_size']
                )
        train_dataloader = Dataloader(
                    dataitems=train_datasetitem, 
                    word2id=datasetdir.word2id,
                    batch_size=trainingconfig['batch_size']
                )
        dev_dataloader = Dataloader(
                    dataitems=dev_datasetitem,
                    word2id=datasetdir.word2id,
                    batch_size=trainingconfig['batch_size']
                )
        logger.info("Start to Train !! ")

        #Plot in Tensorboard
        for ix,item in enumerate(wrapper.train(train_dataloader=train_dataloader,dev_dataloader=dev_dataloader)):
            ep_loss, t_ac, t_p, t_r, t_f1, v_loss, v_ac, v_p, v_r, v_f1, cluster_unit, b_score = item
            if w:
                w.add_scalar("Training/Loss", ep_loss ,ix)
                w.add_scalar("Training/Accuracy", t_ac, ix )
                w.add_scalar("Training/Precision", t_p, ix)
                w.add_scalar("Training/Recall", t_r, ix)
                w.add_scalar("Training/F1_score", t_f1, ix)
                w.add_scalar("Validation/Loss",v_loss, ix)
                w.add_scalar("Validation/Accuracy", v_ac, ix)
                w.add_scalar("Validation/Precision", v_p, ix)
                w.add_scalar("Validation/Recall", v_r, ix)
                w.add_scalar("Validation/F1_score", v_f1, ix)
                w.add_scalar("Validation/FMI", cluster_unit['FMI'], ix)
                w.add_scalar("Validation/ARI",  cluster_unit['ARI'], ix)
                w.add_scalar("Validation/NMI",cluster_unit['NMI'], ix)
                w.add_scalar("Best Score Update", b_score, ix)
        
    if operateconfig['test']:
        test_datasetitem = DataItemSet(
                    dataset=datasetdir.test_dataset,
                    sampler = select_sampler(dataconfig['test_sample_strategy']),
                    negative_sample_size = dataconfig['test_negative_sample_size']
                )

        test_dataloader = Dataloader(
                    dataitems=test_datasetitem,
                    word2id=datasetdir.word2id,
                    batch_size=trainingconfig['batch_size']
                )
        d = wrapper.test_performance(test_dataloader=test_dataloader)
        metric_dict = { **metric_dict, **d}
    
    if operateconfig['predict']:
        pred_word_set = wrapper.cluster_predict(
                    dataset=datasetdir.test_dataset,
                    word2id=datasetdir.word2id,
                    outputfile=trainingconfig['result_out_dir'].joinpath(datasetdir.name+'_result.txt')
                )
        ans = wrapper.evaluate(datasetdir.test_dataset, pred_word_set)
        logger.info("{} DataSet Cluster Prediction".format(datasetdir.train_dataset.name))
        for name,f in ans:
            logger.info("{} : {:.5f}".format(name,f))
        
        if w:
            d = {i:j for i,j in ans}
            metric_dict = {**metric_dict, **d}
            w.add_hparams(hparams, metric_dict = metric_dict)
            w.close()
    wrapper.save(config.WRAPPER_DIR_PATH.joinpath(datasetdir.name))
Example #16
0
 train_loss, train_mse, train_kld, train_recon_loss, valid_loss, valid_mse, valid_kld, valid_recon_loss = \
     main(writer, first, loss_fn, ftype, optim_type, nflows, lr, pruning_ratio, wd,
          l1, z_dim=z_dim)
 if train_loss is None:
     break
 writer.add_hparams({
     'Loss_fn': loss_fn,
     'Ftype': ftype,
     'Optim_type': optim_type,
     'Nflows': nflows,
     'z_dim': z_dim,
     'LR': lr,
     'Pruning Ratio (L1 Unstructured)': pruning_ratio,
     'L1_reg': l1,
     'Weight Decay': wd
 },
     {
         'Train Loss': train_loss,
         'Train MSE Loss': train_mse,
         'Train KLD Loss': train_kld,
         'Train Recon Loss': train_recon_loss,
         'Valid Loss': valid_loss,
         'Valid MSE Loss': valid_mse,
         'Valid KLD Loss': valid_kld,
         'Valid Recon Loss': valid_recon_loss,
     })
 if valid_loss < best_valid_loss:
     best_loss_fn = loss_fn
     best_valid_loss = valid_loss
     best_ftype = ftype
     best_optim_type = optim_type
Example #17
0
def train(model, datasets, params):
    """Train a model on a given dataset."""
    # Prepare training
    patience = 0
    best_loss = 9999
    idx_dataloader = 0
    if params["use_std"]:
        criterion = masked_nllloss
    else:
        criterion = masked_l1
    optimizer = optim.Adam(model.parameters(), lr=params["learning_rate"])
    train_dls, val_dl = datasets
    if not params["debug"]:
        writer = SummaryWriter(params["log_dir"])

    for epoch in range(params["n_epochs"]):
        ########### Training ##########
        running_loss = 0
        pbar = tqdm(train_dls[idx_dataloader])
        for i, ((inputs, time, context), labels) in enumerate(pbar):
            optimizer.zero_grad()
            # Forward pass
            outputs = model(inputs.float(), time.float(), context.float())

            # Backward pass
            pad_labels, _ = pad_packed_sequence(labels,
                                                batch_first=True,
                                                padding_value=-999)
            pad_outputs, _ = pad_packed_sequence(outputs,
                                                 batch_first=True,
                                                 padding_value=-999)
            loss = criterion(pad_outputs, pad_labels.float())
            loss.backward()
            optimizer.step()

            # Update progress bar
            running_loss += loss.item()
            pbar.set_description(
                f"Epoch #{epoch+1} - Loss = {running_loss / (i+1):.5f}")

        ########### Validation ##########
        val_running_loss = 0
        pbar = tqdm(val_dl)
        with torch.no_grad():
            for j, ((inputs, time, context), labels) in enumerate(pbar):
                # Evaluate
                outputs = model(inputs.float(), time.float(), context.float())

                # Compute loss
                pad_labels, _ = pad_packed_sequence(labels,
                                                    batch_first=True,
                                                    padding_value=-999)
                pad_outputs, _ = pad_packed_sequence(outputs,
                                                     batch_first=True,
                                                     padding_value=-999)
                loss = criterion(pad_outputs, pad_labels.float())

                # Update progress bar
                val_running_loss += loss.item()  # accumulate batch loss
                pbar.set_description(
                    f"Validation loss = {val_running_loss / (j+1):.5f}")

        ########### Callbacks at the end of each epoch ##########
        # Tensorboard
        if params["debug"]:
            continue
        writer.add_scalars("loss", {
            "train": running_loss / (i + 1),
            "val": val_running_loss / (j + 1),
        }, epoch + 1)
        # Keep best model
        if (val_running_loss / (j + 1)) < best_loss:
            logging.info(
                f"Validation loss improved from {best_loss} to {(val_running_loss / (j+1))}"
            )
            best_loss = (val_running_loss / (j + 1))
            torch.save(model.state_dict(), join(params["log_dir"], "model.pt"))
            patience = 0
        else:
            logging.info(
                f"Val loss did not improve. Patience: {patience} (max: {params['max_patience']})"
            )
            patience += 1
        # Early stopping
        if patience >= params["max_patience"]:
            logging.info("Triggered early stopping.")
            idx_dataloader += 1
            if idx_dataloader == len(train_dls):
                break
            patience = 0
            logging.info(
                f"Using next dataloader w/ bounds {params['label_bounds'][idx_dataloader]}"
            )

    if not params["debug"]:
        writer.add_hparams(
            {
                k: v.__str__() if isinstance(v, list) else v
                for k, v in params.items()
            }, {"val_loss": best_loss})
    logging.info("Training done.")
Example #18
0
                          test_set=test_set,
                          model_path=model_path,
                          writer=writer,
                          device=device)

    # Write hyperparamters to TensorBoard
    hparams = {
        'embedding_dim': args.embedding_dim,
        'use_glove': args.use_glove,
        'freeze_embeddings': args.freeze_embeddings,
        'num_filters': args.num_filters,
        'hidden_dim': args.hidden_dim,
        'dropout_p': args.dropout_p,
        'learning_rate': args.learning_rate
    }
    writer.add_hparams(hparam_dict=hparams,
                       metric_dict={'best_val_loss': best_val_loss})

    # Evaluation
    test_loss, test_acc, y_pred, y_target = test_step(model=model,
                                                      dataset=test_set,
                                                      device=device)
    config.logger.info(
        "→ Test performance:\n"
        f"  test_loss: {test_loss:.2f}, test_acc: {test_acc:.1f}")

    # Per-class performance analysis
    performance = get_performance(y_pred, y_target, classes)
    plot_confusion_matrix(y_pred=y_pred,
                          y_target=y_target,
                          classes=classes,
                          fp=os.path.join(experiment_dir,
Example #19
0
class ModelTrainer:
    def __init__(
        self,
        model,
        n_epochs=100,
        batch_size=16,
        lr=1e-3,
        beta_function=None,
        checkpoint_path=None,
        tb_label=None,
        tb_dir=None,
        trial=None,
    ):

        self.model = model

        ## Can only train models with defined loss
        assert hasattr(self.model, "get_loss") and callable(
            self.model.get_loss
        ), "Model needs to have implemented a .get_loss method"

        ## Training parameters
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.lr = lr
        self.beta_function = beta_function

        ## Path to saved trained model to
        self.checkpoint_path = checkpoint_path

        self.best_params = None
        self.best_loss = None

        self.init_tb_writer(tb_dir=tb_dir, tb_label=tb_label)
        self.trial = trial

    def init_tb_writer(self, tb_dir=None, tb_label=None):

        ## Label for tensorboard integration
        if tb_dir is None and tb_label is None:
            self.tb_writer = None
            return
        elif tb_dir is None and tb_label is not None:
            now = datetime.now().strftime("%b%d_%H-%M-%S")
            log_dir = os.path.join("runs", tb_label, now)
        elif tb_dir is not None:
            log_dir = tb_dir

        self.tb_writer = SummaryWriter(flush_secs=5, log_dir=log_dir)

    def train_setup(
        self,
        train_dataset,
        validation_dataset,
        random_state=123,
    ):

        if random_state is not None:
            torch.manual_seed(random_state)

        self.train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.batch_size,
        )
        self.validation_loader = torch.utils.data.DataLoader(
            validation_dataset,
            batch_size=self.batch_size,
        )

        self.train_loss = [None] * self.n_epochs
        self.validation_loss = [None] * self.n_epochs
        self.kl_divergence = [None] * self.n_epochs

        self.epoch_train_loss = [None] * len(self.train_loader)
        self.epoch_validation_loss = [None] * len(self.validation_loader)
        self.epoch_kl_divergence = [None] * len(self.validation_loader)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def train(
        self,
        train_dataset,
        validation_dataset,
        random_state=123,
        progress_bar=True,
    ):

        self.train_setup(
            train_dataset=train_dataset,
            validation_dataset=validation_dataset,
            random_state=random_state,
        )

        iter_ = trange(self.n_epochs) if progress_bar else range(self.n_epochs)

        for epoch in iter_:
            self.per_epoch(epoch)

        self.after_training()
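
    # Usage sketch (the dataset objects and values below are illustrative):
    #
    #   trainer = ModelTrainer(model, n_epochs=50, batch_size=32,
    #                          tb_label="experiment")
    #   trainer.train(train_dataset, validation_dataset)
    #   print(trainer.best_loss)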

    def train_loop(self, epoch):

        beta = 1. if self.beta_function is None else self.beta_function(epoch)

        self.model.train()
        for i, batch in enumerate(self.train_loader):

            # Forward pass
            loss = self.model.get_loss(batch, beta=beta)
            self.optimizer.zero_grad()
            loss.backward()

            self.optimizer.step()
            self.epoch_train_loss[i] = loss.item()

    def validation_loop(self, epoch):

        self.model.eval()
        with torch.no_grad():
            for i, batch in enumerate(self.validation_loader):
                # Forward pass
                loss, kl_term = self.model.get_loss(batch, return_kl=True)
                self.epoch_validation_loss[i] = loss.item()
                self.epoch_kl_divergence[i] = kl_term.item()

    def per_epoch(self, epoch):

        # Perfom loops
        self.train_loop(epoch)
        self.validation_loop(epoch)

        ## Calculate statistics
        self.train_loss[epoch] = sum(self.epoch_train_loss) / len(
            self.epoch_train_loss)
        self.validation_loss[epoch] = sum(self.epoch_validation_loss) / len(
            self.epoch_validation_loss)
        self.kl_divergence[epoch] = sum(self.epoch_kl_divergence) / len(
            self.epoch_kl_divergence)

        ## Record best model
        if self.best_loss is None or self.validation_loss[
                epoch] < self.best_loss:
            self.best_loss = self.validation_loss[epoch]
            self.best_params = self.model.state_dict()

        ## Log to tensorboard
        if self.tb_writer is not None:
            self.tb_writer.add_scalar("Loss/Train", self.train_loss[epoch],
                                      epoch)
            self.tb_writer.add_scalar("Loss/Validation",
                                      self.validation_loss[epoch], epoch)
            self.tb_writer.add_scalar("Average KL-term",
                                      self.kl_divergence[epoch], epoch)

        if self.trial is not None:
            self.trial.report(self.validation_loss[epoch], epoch)
            if self.trial.should_prune():
                raise optuna.TrialPruned()

    def after_training(self):

        ## Log hyper parameters to tensorboard
        if self.tb_writer is not None:
            hparams = {
                "lr": self.lr,
                "batch_size": self.batch_size,
                "beta_0": self.beta_function.beta_0
            }
            hparams.update(get_hparams(self.model))
            metrics = {"neg_ELBO": min(self.validation_loss)}
            self.tb_writer.add_hparams(hparams, metrics)
            self.tb_writer.close()

        if self.checkpoint_path is not None:
            Path(self.checkpoint_path).parent.mkdir(exist_ok=True,
                                                    parents=True)
            torch.save(self.model.state_dict(), self.checkpoint_path)

            best_model_path = self.checkpoint_path.with_name(
                f"{self.checkpoint_path.stem}_best.pt")
            torch.save(self.best_params, best_model_path)
Example #20
0
def train(cfg, trail):
    global optim, optim_critic, config, writer, model, critic

    # cfg.lr_critic_real = trail.suggest_loguniform("lr_critic_real", 0.0001, 1)
    # cfg.lr_critic_fake = trail.suggest_loguniform("lr_critic_fake", 0.0001, 1)
    # cfg.lr_entropy = trail.suggest_loguniform("lr_entropy", 0.0001, 1)
    # cfg.lr_gen = trail.suggest_loguniform("lr_gen", 0.0001, 1)

    config = cfg

    writer = SummaryWriter(
        f"{BASE_DIR}/app/vision/iris_detector/L2/runs/iris_real_{cfg.lr_critic_real}_fake_{cfg.lr_critic_fake}_entropy_{cfg.lr_entropy}_gen_{cfg.lr_gen}_{int(datetime.now().timestamp())}"
    )

    train_loader = DataLoader(dataset,
                              batch_size=config.batch_size,
                              shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size)
    global_step = 0

    model = IrisGenerator()
    critic = IrisCritic()

    optim = torch.optim.Adam(model.parameters())
    optim_critic = torch.optim.Adam(critic.parameters())

    loss_entropy, loss_gen, loss_real, loss_fake = (0, 0, 0, 0)

    for epoch in range(config.epochs):

        for batch_i, (images, masks) in enumerate(train_loader):
            global_step += 1

            loss_entropy, loss_gen = train_gen(images, masks)
            loss_real, loss_fake = train_critic(images, masks)

            writer.add_scalar("loss_real", loss_real, global_step=global_step)
            writer.add_scalar("loss_fake", loss_fake, global_step=global_step)
            writer.add_scalar("loss_entropy",
                              loss_entropy,
                              global_step=global_step)
            writer.add_scalar("loss_gen", loss_gen, global_step=global_step)

            print(".", end="")
        print()

        _, (test_images, _) = next(enumerate(test_loader))
        with torch.no_grad():
            test_generated = model(test_images)
        plot_one_hot_mask(test_generated[0])

        # y = model(images)
        #
        # loss = criterion(y, masks.type(torch.LongTensor))
        # loss = torch.mean(loss * get_weight_map(masks))
        #
        # optim.zero_grad()
        # loss.backward()
        # optim.step()

        # plot_one_hot_mask(y)

    final_loss = float(np.mean([loss_entropy, loss_gen, loss_real, loss_fake]))
    writer.add_hparams(
        {
            key: config[key]
            for key in
            ["lr_critic_real", "lr_critic_fake", "lr_entropy", "lr_gen"]
        },
        {"final_loss": final_loss},
    )
    # Close the writer only after the hyperparameters have been logged.
    writer.close()
    return final_loss
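Since train() takes an Optuna trial (the commented-out suggest_loguniform calls show how the learning rates would be sampled), the driving study is presumably something along these lines; the optimization direction and n_trials are assumptions, and cfg, dataset, etc. come from the surrounding script:

import optuna

# Minimize the combined final loss returned by train(); 20 trials is an arbitrary choice.
study = optuna.create_study(direction="minimize")
study.optimize(lambda trial: train(cfg, trial), n_trials=20)
print(study.best_params)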
Example #21
0
class Train(object):
    def __init__(self, args):
        self.args = args

        # -----------
        # Data folder
        # -----------
        data_folder = dataset.Flickr8kFolder(args.data_root)

        # -------------------
        # Building vocabulary
        # -------------------
        logging.info('Building vocabulary...')
        self.vocab = vocabulary.build_flickr8k_vocabulary(
            data_folder.ann_file, min_freq=args.vocab_min_freq)
        logging.debug('Vocabulary size: {}'.format(len(self.vocab)))

        # ----------
        # Transforms
        # ----------
        logging.info('Building transforms...')
        train_transforms = torchvision.transforms.Compose([
            torchvision.transforms.RandomResizedCrop(224),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                             std=(0.229, 0.224, 0.225))
        ])

        val_transforms = torchvision.transforms.Compose([
            torchvision.transforms.Resize(256),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                             std=(0.229, 0.224, 0.225))
        ])

        # --------
        # Datasets
        # --------
        logging.info('Building datasets...')
        flickr_trainset = dataset.Flickr8kDataset(
            data_folder,
            split='train',
            transform=train_transforms,
            target_transform=utils.Word2Idx(self.vocab))

        flickr_valset = dataset.Flickr8kDataset(
            data_folder,
            split='eval',
            transform=val_transforms,
            target_transform=utils.Word2Idx(self.vocab))

        # -----------
        # Data loader
        # -----------
        logging.info('Building data loader...')
        self.train_loader = torch.utils.data.DataLoader(
            flickr_trainset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            collate_fn=dataset.flickr_collate_fn)

        self.val_loader = torch.utils.data.DataLoader(
            flickr_valset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            collate_fn=dataset.flickr_collate_fn)

        # --------------
        # Building model
        # --------------
        logging.info('Building model...')
        encoder = EncoderFactory.get_encoder(args.encoder_type,
                                             args.encoder_size)
        decoder = DecoderFactory.get_decoder(args.attention_type,
                                             args.embedding_size,
                                             len(self.vocab),
                                             args.encoder_size,
                                             encoder.num_pixels,
                                             args.hidden_size,
                                             args.attention_size)

        self.model = model.ImageCaptioningNet(encoder, decoder)
        self.model.to(args.device)

        # ------------------
        # Loss and optimizer
        # ------------------
        self.criterion = nn.CrossEntropyLoss(
            ignore_index=vocabulary.SpecialToken.PAD.value.index)

        # Only the parameters of the final encoder layer are being optimized.
        params = self.model.trainable_parameters()
        self.optimizer = optim.Adam(params, lr=args.learning_rate)

        # ------
        # Others
        # ------
        self.idx2word_fn = utils.IdxToWord(self.vocab)

    @property
    def hparams(self):
        return {
            'encoder_type': self.args.encoder_type,
            'attention_type': self.args.attention_type,
            'num_epochs': self.args.num_epochs,
            'batch_size': self.args.batch_size,
            'learning_rate': self.args.learning_rate,
            'vocab_min_freq': self.args.vocab_min_freq,
            'embedding_size': self.args.embedding_size,
            'hidden_size': self.args.hidden_size,
            'attention_size': self.args.attention_size
        }

    def _dummy_input(self):
        """
        Returns a tuple with a dummy input (random) for the model. This method is used to ease the
        call of the add_graph method of the tensorboard summary writer.
        """
        dummy_imgs = torch.randn(self.args.batch_size,
                                 3,
                                 224,
                                 224,
                                 dtype=torch.float32)
        dummy_caps = torch.randint(low=0,
                                   high=len(self.vocab) - 1,
                                   size=(self.args.batch_size,
                                         self.args.max_seq_length),
                                   dtype=torch.int64)
        dummy_lens = torch.randint(low=1,
                                   high=self.args.max_seq_length,
                                   size=(self.args.batch_size, ),
                                   dtype=torch.int64)
        dummy_lens, _ = torch.sort(dummy_lens, descending=True)

        return (dummy_imgs.to(self.args.device),
                dummy_caps.to(self.args.device), dummy_lens)

    def _compute_accuracy(self, predicted, target):
        """
        Computes accuracy based on BLEU.

            :param predicted: Predicted captions. Shape: (batch_size, max_length).
            :param target: Target captions. Shape: (batch_size, max_length).
            :returns: average BLEU score over the predicted captions.
        """
        total_bleu = 0
        for predicted_cap, target_cap in zip(predicted, target):
            predicted_cap = self.idx2word_fn(predicted_cap.tolist())
            target_cap = self.idx2word_fn(target_cap.tolist())

            bleu = torchtext.data.metrics.bleu_score([predicted_cap],
                                                     [[target_cap]])
            total_bleu += bleu
        return (total_bleu / self.args.batch_size) * 100.0

    def _train_epoch(self, epoch):
        """
        Training step for one epoch.

            :param epoch: current epoch (int)
            :return: average loss and accuracy for the current epoch.
        """
        self.model.train()

        total_loss = 0
        total_accuracy = 0
        for i, (data, train_caps, loss_caps,
                lengths) in enumerate(self.train_loader):

            imgs = data.to(self.args.device)  # (batch_size, channels, h, w)
            train_caps = train_caps.to(
                self.args.device)  # (batch_size, max_length)
            loss_caps = loss_caps.to(
                self.args.device)  # (batch_size, max_length)

            # 0. Clear gradients.
            self.optimizer.zero_grad()

            # 1. Forward the data through the network.
            out, _ = self.model(imgs, train_caps, lengths)

            # 2. Compute loss.
            loss = self.criterion(out.view(-1, len(self.vocab)),
                                  loss_caps.view(-1))

            # 3. Backprop with respect to the loss function.
            loss.backward()

            # 4. Apply an optimizer step.
            self.optimizer.step()

            # 5. Computing loss and accuracy.
            _, predicted_caps = out.max(
                2)  # predicted_caps = (batch_size, max_length)
            loss_value = loss.item()
            acc_value = self._compute_accuracy(predicted_caps, loss_caps)

            if self.args.log_interval > 0 and i % self.args.log_interval == 0:
                print('Epoch [{}/{}] - [{}/{}] [TRAIN] Loss: {} | Acc: {}'.
                      format(epoch + 1, self.args.num_epochs, i + 1,
                             len(self.train_loader), loss_value, acc_value))

                # # Writing scalars to tensorboard.
                # step = epoch * (len(self.train_loader)) + i
                # self.writer.add_scalar('Loss/train', loss_value, step)
                # self.writer.add_scalar('Accuracy/train', acc_value, step)

            # Adding loss and accuracy to totals.
            total_loss += loss_value
            total_accuracy += acc_value

        return total_loss / len(self.train_loader), total_accuracy / (len(
            self.train_loader))

    def _validate_epoch(self, epoch):
        """
        Validation step for one epoch.

            :param epoch: current epoch (int)
            :return: average loss and accuracy for the current epoch.
        """
        self.model.eval()

        with torch.no_grad():
            total_loss = 0
            total_accuracy = 0
            for i, (data, train_caps, loss_caps,
                    lengths) in enumerate(self.val_loader):

                imgs = data.to(
                    self.args.device)  # (batch_size, channels, h, w)
                train_caps = train_caps.to(
                    self.args.device)  # (batch_size, max_length)
                loss_caps = loss_caps.to(
                    self.args.device)  # (batch_size, max_length)

                # 1. Forward the data through the network.
                out, _ = self.model(imgs, train_caps, lengths)

                # 2. Compute loss.
                loss = self.criterion(out.view(-1, len(self.vocab)),
                                      loss_caps.view(-1))

                # 3. Computing loss and accuracy.
                _, predicted_caps = out.max(
                    2)  # predicted_caps = (batch_size, max_length)
                loss_value = loss.item()
                acc_value = self._compute_accuracy(predicted_caps, loss_caps)

                if self.args.log_interval > 0 and i % self.args.log_interval == 0:
                    print('Epoch [{}/{}] - [{}/{}] [EVAL] Loss: {} | Acc: {}'.
                          format(epoch + 1, self.args.num_epochs, i + 1,
                                 len(self.val_loader), loss_value, acc_value))

                    # # Writing scalars to tensorboard.
                    # step = epoch * (len(self.val_loader)) + i
                    # self.writer.add_scalar('Loss/eval', loss_value, step)
                    # self.writer.add_scalar('Accuracy/eval', acc_value, step)

                # Adding loss and accuracy to totals.
                total_loss += loss_value
                total_accuracy += acc_value

        return total_loss / len(self.val_loader), total_accuracy / (len(
            self.val_loader))

    def train(self):
        """
        Training loop
        """
        # Starting tensorboard writer.
        if self.args.log_interval > 0:
            if self.args.session_name is not None:
                self.writer = SummaryWriter(
                    os.path.join('runs', self.args.session_name))
            else:
                self.writer = SummaryWriter()

            self.writer.add_graph(self.model, self._dummy_input())

        for epoch in range(self.args.num_epochs):

            train_loss, train_acc = self._train_epoch(epoch)
            eval_loss, eval_acc = self._validate_epoch(epoch)

            # Log loss and accuracy (the writer only exists when logging is enabled).
            if self.args.log_interval > 0:
                self.writer.add_scalar('Loss/eval', eval_loss, epoch + 1)
                self.writer.add_scalar('Loss/train', train_loss, epoch + 1)
                self.writer.add_scalar('Accuracy/eval', eval_acc, epoch + 1)
                self.writer.add_scalar('Accuracy/train', train_acc, epoch + 1)

            # Save checkpoint of the model.
            if self.args.save_checkpoints:
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'loss': train_loss
                    }, 'checkpoints/checkpoints-{}.tar'.format(epoch + 1))

        if self.args.log_interval > 0:
            logging.info('Logging hparams...')
            self.writer.add_hparams(
                self.hparams, {
                    'hparam/train-loss': train_loss,
                    'hparam/train-accuracy': train_acc,
                    'hparam/eval-loss': eval_loss,
                    'hparam/eval-accuracy': eval_acc
                })

            logging.info('Logging embeddings...')
            self.writer.add_embedding(self.model.decoder.embedding.weight,
                                      metadata=self.vocab.get_words(),
                                      global_step=0)

            self.writer.close()

        if not self.args.no_save_model:
            model_name = self.args.session_name if self.args.session_name is not None else 'model'

            logging.info('Saving model as {}...'.format(model_name))
            torch.save(self.model.state_dict(),
                       os.path.join('models', model_name + '.pt'))

            logging.info('Saving arguments of the model...')
            arguments = {
                'encoder_type': self.args.encoder_type,
                'attention_type': self.args.attention_type,
                'vocab_min_freq': self.args.vocab_min_freq,
                'encoder_size': self.args.encoder_size,
                'hidden_size': self.args.hidden_size,
                'embedding_size': self.args.embedding_size,
                'attention_size': self.args.attention_size,
                'overfitting': self.args.overfitting
            }
            with open(os.path.join('models', model_name + '.json'), 'w') as f:
                json.dump(arguments, f)
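Restoring what the training loop saves above is the mirror image of the save step; a minimal sketch, assuming the model has been rebuilt with the same factories and vocabulary and that model_name matches the training session:

with open(os.path.join('models', model_name + '.json')) as f:
    saved_args = json.load(f)  # encoder/attention types, sizes, etc.

state_dict = torch.load(os.path.join('models', model_name + '.pt'),
                        map_location='cpu')
model.load_state_dict(state_dict)
model.eval()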
Example #22
0
args = parser.parse_args()

experiment_id = '/' + str(args.experiment_id)
experiment_name = '/' + args.experiment_name if args.experiment_name else ''
model_path = 'tmp/' + args.model + experiment_name + experiment_id

config = config_mappings[args.model]
config = config._replace(hparams={**config.hparams, **args.hparams})

print(config, model_path)
train_data, test_data, model, opt = DSprites.setup(config,
                                                   iid=config.dataset['iid'])
writer = SummaryWriter(log_dir=model_path)
for epoch in range(config.model['epochs']):
    DSprites.train(model,
                   train_data,
                   epoch,
                   opt,
                   writer=writer,
                   verbose=True,
                   metrics_labels=config.model['metrics_labels'])

_, metrics = DSprites.test(model,
                           test_data,
                           verbose=True,
                           metrics_labels=config.model['metrics_labels'],
                           writer=writer)
torch.save(model.state_dict(), model_path + ".pt")
metrics_labels = ['hparam/' + x for x in config.model['metrics_labels']]
writer.add_hparams(hparam_dict=config.hparams,
                   metric_dict=dict(zip(metrics_labels, metrics)))
Example #23
0
class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch,
    optimized for Transformers.
    """

    model: PreTrainedModel
    args: TrainingArguments
    data_collator: DataCollator
    train_dataset: Optional[Dataset]
    eval_dataset: Optional[Dataset]
    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
    prediction_loss_only: bool
    tb_writer: Optional["SummaryWriter"] = None
    optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None
    global_step: Optional[int] = None
    epoch: Optional[float] = None

    def __init__(
        self,
        model: PreTrainedModel,
        args: TrainingArguments,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        prediction_loss_only=False,
        tb_writer: Optional["SummaryWriter"] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None,
    ):
        """
        Trainer is a simple but feature-complete training and eval loop for PyTorch,
        optimized for Transformers.

        Args:
            prediction_loss_only:
                (Optional) in evaluation and prediction, only return the loss
        """
        self.model = model.to(args.device)
        self.args = args
        if data_collator is not None:
            self.data_collator = data_collator
        else:
            self.data_collator = DefaultDataCollator()
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.compute_metrics = compute_metrics
        self.prediction_loss_only = prediction_loss_only
        self.optimizers = optimizers
        if tb_writer is not None:
            self.tb_writer = tb_writer
        elif is_tensorboard_available() and self.is_world_master():
            self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
        if not is_tensorboard_available():
            logger.warning(
                "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
            )
        if is_wandb_available():
            self._setup_wandb()
        else:
            logger.info(
                "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
                "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
            )
        set_seed(self.args.seed)
        # Create output directory if needed
        if self.is_world_master():
            os.makedirs(self.args.output_dir, exist_ok=True)
        if is_tpu_available():
            # Set an xla_device flag on the model's config.
            # We'll find a more elegant way and will not need to do this in the future.
            self.model.config.xla_device = True

    def get_train_dataloader(self) -> DataLoader:
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        if is_tpu_available():
            train_sampler = get_tpu_sampler(self.train_dataset)
        else:
            train_sampler = (
                RandomSampler(self.train_dataset)
                if self.args.local_rank == -1
                else DistributedSampler(self.train_dataset)
            )

        data_loader = DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator.collate_batch,
        )

        return data_loader

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

        if is_tpu_available():
            sampler = SequentialDistributedSampler(
                eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
            )
        elif self.args.local_rank != -1:
            sampler = SequentialDistributedSampler(eval_dataset)
        else:
            sampler = SequentialSampler(eval_dataset)

        data_loader = DataLoader(
            eval_dataset,
            sampler=sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator.collate_batch,
        )

        return data_loader

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        # We use the same batch_size as for eval.
        if is_tpu_available():
            sampler = SequentialDistributedSampler(
                test_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
            )
        elif self.args.local_rank != -1:
            sampler = SequentialDistributedSampler(test_dataset)
        else:
            sampler = SequentialSampler(test_dataset)

        data_loader = DataLoader(
            test_dataset,
            sampler=sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator.collate_batch,
        )

        return data_loader

    def get_optimizers(
        self, num_training_steps: int
    ) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well.
        If you want to use something else, you can pass a tuple in the Trainer's init,
        or override this method in a subclass.
        """
        if self.optimizers is not None:
            return self.optimizers
        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.args.weight_decay,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
        )
        self.optimizers = optimizer, scheduler
        return optimizer, scheduler
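
    # A hedged usage sketch (not part of the original Trainer): instead of the default
    # AdamW + linear-warmup pair built above, a caller can pass its own pair to the
    # constructor, e.g.
    #   optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
    #   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)
    #   trainer = Trainer(model, args, train_dataset=train_dataset,
    #                     optimizers=(optimizer, scheduler))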

    def _setup_wandb(self):
        """
        Setup the optional Weights & Biases (`wandb`) integration.

        One can override this method to customize the setup if needed.  Find more information at https://docs.wandb.com/huggingface
        You can also override the following environment variables:

        Environment:
            WANDB_WATCH:
                (Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging
                or "all" to log gradients and parameters
            WANDB_PROJECT:
                (Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
            WANDB_DISABLED:
                (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
        """
        logger.info('Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"')
        wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args))
        # keep track of model topology and gradients
        if os.getenv("WANDB_WATCH") != "false":
            wandb.watch(
                self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
            )
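
    # A hedged usage sketch: the environment variables documented above are read when
    # _setup_wandb runs, so they can simply be exported before the Trainer is built, e.g.
    #   os.environ["WANDB_PROJECT"] = "my-project"   # illustrative project name
    #   os.environ["WANDB_WATCH"] = "all"            # log gradients and parameters
    #   os.environ["WANDB_DISABLED"] = "true"        # or skip wandb logging entirely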

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get num of examples from a DataLoader, by accessing its Dataset.
        """
        return len(dataloader.dataset)

    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Args:
            model_path:
                (Optional) Local path to model if model to train has been instantiated from a local path
                If present, we will try reloading the optimizer/scheduler states from there.
        """
        train_dataloader = self.get_train_dataloader()
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (
                self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
            )
        else:
            t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
            )
            scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        if self.args.fp16:
            if not is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

        # Train!
        if is_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d", self.args.per_gpu_train_batch_size)
        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = self.global_step % (
                    len(train_dataloader) // self.args.gradient_accumulation_steps
                )

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", self.global_step)
                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        logging_loss = 0.0
        best_score = 0.0
        model.zero_grad()
        train_iterator = trange(
            epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_master()
        )
        for epoch in train_iterator:
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                    self.args.device
                )
                epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_master())
            else:
                epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())

            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                tr_loss += self._training_step(model, inputs, optimizer)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)
                ):
                    if self.args.fp16:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                    if not hasattr(self.optimizers[0], 'accumulate_grad'):
                        if is_tpu_available():
                            xm.optimizer_step(optimizer)
                        else:
                            optimizer.step()
                        scheduler.step()

                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
                        self.global_step == 1 and self.args.logging_first_step
                    ):
                        logs: Dict[str, float] = {}
                        logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
                        # backward compatibility for pytorch schedulers
                        logs["learning_rate"] = (
                            scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >= version.parse("1.4")
                            else scheduler.get_lr()[0]
                        )
                        logging_loss = tr_loss

                        self._log(logs)

                        if self.args.evaluate_during_training:
                            results = self.evaluate()
                            for key, value in results.items():
                                eval_key = "eval_{}".format(key)
                                if 'acc' in eval_key or 'mcc' in eval_key or 'corr' in eval_key:
                                    if hasattr(self.model, 'classifiers') and int(key.split('_')[-1]) < len(self.model.classifiers) - 1:
                                        continue
                                    if best_score < value:
                                        best_score = value
                                        self.save_model(self.args.output_dir)
                                    self._log({'best_score': best_score})

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
                        # to the model we want to save.
                        if hasattr(model, "module"):
                            assert model.module is self.model
                        else:
                            assert model is self.model
                        # Save model checkpoint
                        output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")

                        self.save_model(output_dir)

                        if self.is_world_master():
                            self._rotate_checkpoints()

                        if is_tpu_available():
                            xm.rendezvous("saving_optimizer_states")
                            xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                        elif self.is_world_master():
                            torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

                if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                train_iterator.close()
                break
            if self.args.tpu_metrics_debug:
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())

        if self.tb_writer:
            self.tb_writer.close()

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        return TrainOutput(self.global_step, tr_loss / self.global_step), best_score

    def _log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
        if self.epoch is not None:
            logs["epoch"] = self.epoch
        if self.tb_writer:
            for k, v in logs.items():
                self.tb_writer.add_scalar(k, v, self.global_step)
        if is_wandb_available():
            wandb.log(logs, step=self.global_step)
        output = json.dumps({**logs, **{"step": self.global_step}})
        if iterator is not None:
            iterator.write(output)
        else:
            print(output)

    def _training_step(
        self, model: nn.Module, inputs: Dict[str, torch.Tensor], optimizer: torch.optim.Optimizer
    ) -> float:
        model.train()
        for k, v in inputs.items():
            inputs[k] = v.to(self.args.device)

        outputs = model(**inputs)
        loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        if hasattr(self.optimizers[0], 'accumulate_grad'):
            self.optimizers[0].accumulate_grad()

        return loss.item()

    def is_local_master(self) -> bool:
        if is_tpu_available():
            return xm.is_master_ordinal(local=True)
        else:
            return self.args.local_rank in [-1, 0]

    def is_world_master(self) -> bool:
        """
        This will be True only in one process, even in distributed mode,
        even when training on multiple machines.
        """
        if is_tpu_available():
            return xm.is_master_ordinal(local=False)
        else:
            return self.args.local_rank == -1 or torch.distributed.get_rank() == 0

    def save_model(self, output_dir: Optional[str] = None):
        """
        Saving best-practices: if you use default names for the model,
        you can reload it using from_pretrained().

        Will only save from the world_master process (unless in TPUs).
        """

        if is_tpu_available():
            self._save_tpu(output_dir)
        elif self.is_world_master():
            self._save(output_dir)
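
    # A hedged usage sketch: because _save()/_save_tpu() rely on save_pretrained(), a
    # checkpoint written here can later be restored with the matching model class, e.g.
    #   model = MyModelClass.from_pretrained(output_dir)   # hypothetical class name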

    def _save_tpu(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        logger.info("Saving model checkpoint to %s", output_dir)

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            raise ValueError("Trainer.model appears to not be a PreTrainedModel")

        xm.rendezvous("saving_checkpoint")
        self.model.save_pretrained(output_dir)

    def _save(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            raise ValueError("Trainer.model appears to not be a PreTrainedModel")
        self.model.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        #torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

    def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
        ordering_and_checkpoint_path = []

        glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")]

        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
                if regex_match and regex_match.groups():
                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
        return checkpoints_sorted

    def _rotate_checkpoints(self, use_mtime=False) -> None:
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
            shutil.rmtree(checkpoint)

    def evaluate(
        self, eval_dataset: Optional[Dataset] = None, prediction_loss_only: Optional[bool] = None,
    ) -> Dict[str, float]:
        """
        Run evaluation and return metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are
        task-dependent.

        Args:
            eval_dataset: (Optional) Pass a dataset if you wish to override
            the one on the instance.
        Returns:
            A dict containing:
                - the eval loss
                - the potential metrics computed from the predictions
        """
        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        output = self._prediction_loop(eval_dataloader, description="Evaluation")

        self._log(output.metrics)

        if self.args.tpu_metrics_debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        return output.metrics
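
    # A hedged sketch of a caller-supplied compute_metrics function (the metric name and
    # argmax decoding are illustrative, not prescribed by this Trainer):
    #   def compute_metrics(p: EvalPrediction) -> Dict[str, float]:
    #       preds = p.predictions.argmax(-1)
    #       return {"acc": float((preds == p.label_ids).mean())}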

    def predict(self, test_dataset: Dataset) -> PredictionOutput:
        """
        Run prediction and return predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels.
        In that case, this method will also return metrics, like in evaluate().
        """
        test_dataloader = self.get_test_dataloader(test_dataset)

        return self._prediction_loop(test_dataloader, description="Prediction")

    def _prediction_loop(
        self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.

        Works both with or without labels.
        """

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        else:
            model = self.model
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        batch_size = dataloader.batch_size
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", self.num_examples(dataloader))
        logger.info("  Batch size = %d", batch_size)
        eval_losses: List[float] = []
        preds: torch.Tensor = None
        label_ids: torch.Tensor = None
        model.eval()

        if is_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

        for inputs in tqdm(dataloader, desc=description):
            has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])

            for k, v in inputs.items():
                inputs[k] = v.to(self.args.device)

            with torch.no_grad():
                outputs = model(**inputs)
                if has_labels:
                    step_eval_loss, logits = outputs[:2]
                    eval_losses += [step_eval_loss.mean().item()]
                else:
                    logits = outputs[0]

            if not prediction_loss_only:
                if preds is None:
                    preds = logits.detach()
                else:
                    preds = torch.cat((preds, logits.detach()), dim=0)
                if inputs.get("labels") is not None:
                    if label_ids is None:
                        label_ids = inputs["labels"].detach()
                    else:
                        label_ids = torch.cat((label_ids, inputs["labels"].detach()), dim=0)

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            if preds is not None:
                preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
            if label_ids is not None:
                label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
        elif is_tpu_available():
            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
            if preds is not None:
                preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
            if label_ids is not None:
                label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = preds.cpu().numpy()
        if label_ids is not None:
            label_ids = label_ids.cpu().numpy()

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            metrics["eval_loss"] = np.mean(eval_losses)

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

    def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
        assert self.args.local_rank != -1

        output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(output_tensors, tensor)

        concat = torch.cat(output_tensors, dim=0)

        # truncate the dummy elements added by SequentialDistributedSampler
        output = concat[:num_total_examples]
        return output
Example #24
0
def main(args):
    dataset_name = args.dataset
    model_name = args.model
    n_inner_iter = args.adaptation_steps
    meta_learning_rate = args.meta_learning_rate
    learning_rate = args.learning_rate
    batch_size = args.batch_size
    save_model_file = args.save_model_file
    load_model_file = args.load_model_file
    lower_trial = args.lower_trial
    upper_trial = args.upper_trial
    task_size = args.task_size
    noise_level = args.noise_level
    noise_type = args.noise_type
    epochs = args.epochs
    loss_fcn_str = args.loss
    modulate_task_net = args.modulate_task_net
    weight_vrae = args.weight_vrae
    stopping_patience = args.stopping_patience

    meta_info = {"POLLUTION": [5, 14], "HR": [32, 13], "BATTERY": [20, 3]}

    assert model_name in ("FCN", "LSTM"), "Model was not correctly specified"
    assert dataset_name in ("POLLUTION", "HR", "BATTERY")

    window_size, input_dim = meta_info[dataset_name]

    grid = [0., noise_level]

    train_data_ML = pickle.load(
        open(
            "../../Data/TRAIN-" + dataset_name + "-W" + str(window_size) +
            "-T" + str(task_size) + "-ML.pickle", "rb"))
    validation_data_ML = pickle.load(
        open(
            "../../Data/VAL-" + dataset_name + "-W" + str(window_size) + "-T" +
            str(task_size) + "-ML.pickle", "rb"))
    test_data_ML = pickle.load(
        open(
            "../../Data/TEST-" + dataset_name + "-W" + str(window_size) +
            "-T" + str(task_size) + "-ML.pickle", "rb"))

    total_tasks = len(train_data_ML)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    loss_fn = mae if loss_fcn_str == "MAE" else nn.SmoothL1Loss()

    ## multimodal learner parameters
    # parameters to increase the capacity of the model
    n_layers_task_net = 2
    n_layers_task_encoder = 2
    n_layers_task_decoder = 2

    hidden_dim_task_net = 120
    hidden_dim_encoder = 120
    hidden_dim_decoder = 120

    # fixed values
    input_dim_task_net = input_dim
    input_dim_task_encoder = input_dim + 1
    output_dim_task_net = 1
    output_dim_task_decoder = input_dim + 1

    first_order = False
    inner_loop_grad_clip = 20

    for trial in range(lower_trial, upper_trial):

        output_directory = "../../Models/" + dataset_name + "_" + model_name + "_MMAML/" + str(
            trial) + "/"
        save_model_file_ = output_directory + save_model_file
        save_model_file_encoder = output_directory + "encoder_" + save_model_file
        load_model_file_ = output_directory + load_model_file
        checkpoint_file = output_directory + "checkpoint_" + save_model_file.split(
            ".")[0]

        writer = SummaryWriter()

        try:
            os.mkdir(output_directory)
        except OSError as error:
            print(error)

        task_net = LSTMModel(batch_size=batch_size,
                             seq_len=window_size,
                             input_dim=input_dim_task_net,
                             n_layers=n_layers_task_net,
                             hidden_dim=hidden_dim_task_net,
                             output_dim=output_dim_task_net)

        task_encoder = LSTMModel(batch_size=batch_size,
                                 seq_len=task_size,
                                 input_dim=input_dim_task_encoder,
                                 n_layers=n_layers_task_encoder,
                                 hidden_dim=hidden_dim_encoder,
                                 output_dim=1)

        task_decoder = LSTMDecoder(batch_size=1,
                                   n_layers=n_layers_task_decoder,
                                   seq_len=task_size,
                                   output_dim=output_dim_task_decoder,
                                   hidden_dim=hidden_dim_encoder,
                                   latent_dim=hidden_dim_decoder,
                                   device=device)

        lmbd = Lambda(hidden_dim_encoder, hidden_dim_task_net)

        multimodal_learner = MultimodalLearner(task_net, task_encoder,
                                               task_decoder, lmbd,
                                               modulate_task_net)
        multimodal_learner.to(device)

        output_layer = LinearModel(120, 1)
        opt = torch.optim.Adam(list(multimodal_learner.parameters()) +
                               list(output_layer.parameters()),
                               lr=meta_learning_rate)

        meta_learner = MetaLearner(output_layer, opt, learning_rate, loss_fn,
                                   first_order, n_inner_iter,
                                   inner_loop_grad_clip, device)

        early_stopping = EarlyStopping(patience=stopping_patience,
                                       model_file=save_model_file_,
                                       verbose=True)
        early_stopping_encoder = EarlyStopping(
            patience=stopping_patience,
            model_file=save_model_file_encoder,
            verbose=True)

        task_data_train = torch.FloatTensor(
            get_task_encoder_input(train_data_ML))
        task_data_validation = torch.FloatTensor(
            get_task_encoder_input(validation_data_ML))
        task_data_test = torch.FloatTensor(
            get_task_encoder_input(test_data_ML))

        val_loss_hist = []
        test_loss_hist = []

        for epoch in range(epochs):

            multimodal_learner.train()

            batch_idx = np.random.randint(0, total_tasks - 1, batch_size)
            task = task_data_train[batch_idx].cuda()

            x_spt, y_spt = train_data_ML[batch_idx]
            x_qry, y_qry = train_data_ML[batch_idx + 1]

            x_spt, y_spt = to_torch(x_spt), to_torch(y_spt)
            x_qry = to_torch(x_qry)
            y_qry = to_torch(y_qry)

            # data augmentation
            epsilon = grid[np.random.randint(0, len(grid))]

            if noise_type == "additive":
                y_spt = y_spt + epsilon
                y_qry = y_qry + epsilon
            else:
                y_spt = y_spt * (1 + epsilon)
                y_qry = y_qry * (1 + epsilon)

            x_spt_encodings = []
            x_qry_encodings = []
            vrae_loss_accum = 0.0
            for i in range(batch_size):
                x_spt_encoding, (vrae_loss, kl_loss,
                                 rec_loss) = multimodal_learner(
                                     x_spt[i],
                                     task[i:i + 1],
                                     output_encoding=True)
                x_spt_encodings.append(x_spt_encoding)
                vrae_loss_accum += vrae_loss

                x_qry_encoding, _ = multimodal_learner(x_qry[i],
                                                       task[i:i + 1],
                                                       output_encoding=True)
                x_qry_encodings.append(x_qry_encoding)

            train_tasks = [
                Task(x_spt_encodings[i], y_spt[i])
                for i in range(x_spt.shape[0])
            ]
            val_tasks = [
                Task(x_qry_encodings[i], y_qry[i])
                for i in range(x_qry.shape[0])
            ]

            # print(vrae_loss)

            adapted_params = meta_learner.adapt(train_tasks)
            mean_loss = meta_learner.step(adapted_params,
                                          val_tasks,
                                          is_training=True,
                                          additional_loss_term=weight_vrae *
                                          vrae_loss_accum / batch_size)

            ## plotting gradients of the output layer
            for tag, parm in output_layer.linear.named_parameters():
                writer.add_histogram("Grads_output_layer_" + tag,
                                     parm.grad.data.cpu().numpy(), epoch)

            multimodal_learner.eval()
            val_loss = test(validation_data_ML, multimodal_learner,
                            meta_learner, task_data_validation)
            test_loss = test(test_data_ML, multimodal_learner, meta_learner,
                             task_data_test)

            print("Epoch:", epoch)
            print("Train loss:", mean_loss)
            print("Val error:", val_loss)
            print("Test error:", test_loss)

            early_stopping(val_loss, meta_learner)
            early_stopping_encoder(val_loss, multimodal_learner)

            val_loss_hist.append(val_loss)
            test_loss_hist.append(test_loss)

            if early_stopping.early_stop:
                print("Early stopping")
                break

            writer.add_scalar("Loss/train",
                              mean_loss.cpu().detach().numpy(), epoch)
            writer.add_scalar("Loss/val", val_loss, epoch)
            writer.add_scalar("Loss/test", test_loss, epoch)

        multimodal_learner.load_state_dict(torch.load(save_model_file_encoder))
        output_layer.load_state_dict(
            torch.load(save_model_file_)["model_state_dict"])
        meta_learner = MetaLearner(output_layer, opt, learning_rate, loss_fn,
                                   first_order, n_inner_iter,
                                   inner_loop_grad_clip, device)

        val_loss = test(validation_data_ML, multimodal_learner, meta_learner,
                        task_data_validation)
        test_loss = test(test_data_ML, multimodal_learner, meta_learner,
                         task_data_test)

        with open(output_directory + "/results3.txt", "a+") as f:
            f.write("Dataset :%s \n" % dataset_name)
            f.write("Test error: %f \n" % test_loss)
            f.write("Val error: %f \n" % val_loss)
            f.write("\n")

        writer.add_hparams(
            {
                "fast_lr": learning_rate,
                "slow_lr": meta_learning_rate,
                "adaption_steps": n_inner_iter,
                "patience": stopping_patience,
                "weight_vrae": weight_vrae,
                "noise_level": noise_level,
                "dataset": dataset_name,
                "trial": trial
            }, {
                "val_loss": val_loss,
                "test_loss": test_loss
            })
Example #25
0
class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch,
    optimized for 🤗 Transformers.

    Args:
        model (:class:`~transformers.PreTrainedModel`):
            The model to train, evaluate or use for predictions.
        args (:class:`~transformers.TrainingArguments`):
            The arguments to tweak training.
        data_collator (:obj:`DataCollator`, `optional`, defaults to :func:`~transformers.default_data_collator`):
            The function to use to form a batch from a list of elements of :obj:`train_dataset` or
            :obj:`eval_dataset`.
        train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
            The dataset to use for training.
        eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
            The dataset to use for evaluation.
        compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
            The function that will be used to compute metrics at evaluation. Must take a
            :class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
        prediction_loss_only (:obj:`bool`, `optional`, defaults to `False`):
            When performing evaluation and predictions, only returns the loss.
        tb_writer (:obj:`SummaryWriter`, `optional`):
            Object to write to TensorBoard.
        optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, `optional`):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of
            :class:`~transformers.AdamW` on your model and a scheduler given by
            :func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
        kwargs:
            Deprecated keyword arguments.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        args: TrainingArguments,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        tb_writer: Optional["SummaryWriter"] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        **kwargs,
    ):
        self.model = model.to(args.device)
        self.args = args
        self.data_collator = data_collator if data_collator is not None else default_data_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.compute_metrics = compute_metrics
        self.optimizer, self.lr_scheduler = optimizers
        self.tb_writer = tb_writer
        if "prediction_loss_only" in kwargs:
            warnings.warn(
                "Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.",
                FutureWarning,
            )
            self.args.prediction_loss_only = kwargs.pop("prediction_loss_only")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        if tb_writer is None and is_tensorboard_available() and self.is_world_process_zero():
            self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
        if not is_tensorboard_available():
            logger.warning(
                "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
            )
        if is_wandb_available():
            self.setup_wandb()
        elif os.environ.get("WANDB_DISABLED") != "true":
            logger.info(
                "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
                "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
            )
        if is_comet_available():
            self.setup_comet()
        elif os.environ.get("COMET_MODE") != "DISABLED":
            logger.info(
                "To use comet_ml logging, run `pip/conda install comet_ml` "
                "see https://www.comet.ml/docs/python-sdk/huggingface/"
            )
        set_seed(self.args.seed)
        # Create output directory if needed
        if self.is_world_process_zero():
            os.makedirs(self.args.output_dir, exist_ok=True)
        if is_torch_tpu_available():
            # Set an xla_device flag on the model's config.
            # We'll find a more elegant solution and won't need to do this in the future.
            self.model.config.xla_device = True
        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
            self.data_collator = self.data_collator.collate_batch
            warnings.warn(
                (
                    "The `data_collator` should now be a simple callable (function, class with `__call__`), classes "
                    + "with a `collate_batch` are deprecated and won't be supported in a future version."
                ),
                FutureWarning,
            )
        self.global_step = None
        self.epoch = None
        if self.args.fp16 and _use_native_amp:
            self.scaler = torch.cuda.amp.GradScaler()

    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
        if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
            return None
        elif is_torch_tpu_available():
            return get_tpu_sampler(self.train_dataset)
        else:
            return (
                RandomSampler(self.train_dataset)
                if self.args.local_rank == -1
                else DistributedSampler(self.train_dataset)
            )

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`self.train_dataset` is a :obj:`torch.utils.data.IterableDataset`, a random sampler
        (adapted to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_sampler = self._get_train_sampler()

        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )
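        # Subclassing sketch (hedged, not part of the library source above): override
        # this hook to plug in a custom sampler, e.g. torch.utils.data.WeightedRandomSampler:
        #
        #   class WeightedTrainer(Trainer):
        #       def get_train_dataloader(self):
        #           sampler = WeightedRandomSampler(example_weights, len(example_weights))
        #           return DataLoader(self.train_dataset, sampler=sampler,
        #                             batch_size=self.args.train_batch_size,
        #                             collate_fn=self.data_collator)
        #
        # where `example_weights` is a hypothetical per-example weight vector.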

    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
            return None
        elif is_torch_tpu_available():
            return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
        elif self.args.local_rank != -1:
            return SequentialDistributedSampler(eval_dataset)
        else:
            return SequentialSampler(eval_dataset)

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`self.eval_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
        sampler (adapted to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
                If provided, will override :obj:`self.eval_dataset`.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        eval_sampler = self._get_eval_sampler(eval_dataset)

        return DataLoader(
            eval_dataset,
            sampler=eval_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`test_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
        sampler (adapted to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (:obj:`torch.utils.data.dataset.Dataset`):
                The test dataset to use.
        """
        test_sampler = self._get_eval_sampler(test_dataset)

        # We use the same batch_size as for eval.
        return DataLoader(
            test_dataset,
            sampler=test_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
        """
        if self.optimizer is None:
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                    "weight_decay": 0.0,
                },
            ]
            self.optimizer = AdamW(
                optimizer_grouped_parameters,
                lr=self.args.learning_rate,
                betas=(self.args.adam_beta1, self.args.adam_beta2),
                eps=self.args.adam_epsilon,
            )
        if self.lr_scheduler is None:
            self.lr_scheduler = get_linear_schedule_with_warmup(
                self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
            )
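        # Hedged alternative: rather than overriding this method, a custom pair can be
        # handed to the Trainer at construction time through the `optimizers` argument,
        # e.g. (values illustrative):
        #
        #   optimizer = AdamW(model.parameters(), lr=args.learning_rate)
        #   scheduler = get_linear_schedule_with_warmup(
        #       optimizer, num_warmup_steps=0, num_training_steps=10_000)
        #   trainer = Trainer(model, args, optimizers=(optimizer, scheduler))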

    def setup_wandb(self):
        """
        Setup the optional Weights & Biases (`wandb`) integration.

        One can subclass and override this method to customize the setup if needed. Find more information
        `here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:

        Environment:
            WANDB_WATCH:
                (Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging
                or "all" to log gradients and parameters
            WANDB_PROJECT:
                (Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
            WANDB_DISABLED:
                (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
        """
        if hasattr(self, "_setup_wandb"):
            warnings.warn(
                "The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
                FutureWarning,
            )
            return self._setup_wandb()

        if self.is_world_process_zero():
            logger.info(
                'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
            )
            combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
            wandb.init(
                project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
            )
            # keep track of model topology and gradients, unsupported on TPU
            if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
                wandb.watch(
                    self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
                )
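        # Hypothetical invocations showing how the environment variables documented
        # above steer the integration without code changes:
        #   WANDB_PROJECT=my-project WANDB_WATCH=false python run_training.py
        #   WANDB_DISABLED=true python run_training.py   # disable wandb entirely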

    def setup_comet(self):
        """
        Setup the optional Comet.ml integration.

        Environment:
            COMET_MODE:
                (Optional): str - "OFFLINE", "ONLINE", or "DISABLED"
            COMET_PROJECT_NAME:
                (Optional): str - Comet.ml project name for experiments
            COMET_OFFLINE_DIRECTORY:
                (Optional): str - folder to use for saving offline experiments when `COMET_MODE` is "OFFLINE"

        For a number of configurable items in the environment,
        see `here <https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__
        """
        if self.is_world_master():
            comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
            args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
            experiment = None
            if comet_mode == "ONLINE":
                experiment = comet_ml.Experiment(**args)
                logger.info("Automatic Comet.ml online logging enabled")
            elif comet_mode == "OFFLINE":
                args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
                experiment = comet_ml.OfflineExperiment(**args)
                logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
            if experiment is not None:
                experiment._set_model_graph(self.model, framework="transformers")
                experiment._log_parameters(self.args, prefix="args/", framework="transformers")
                experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")
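        # Hypothetical invocation: offline Comet.ml logging to a local folder, selected
        # purely through the environment variables documented above:
        #   COMET_MODE=OFFLINE COMET_OFFLINE_DIRECTORY=./comet_runs python run_training.py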

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
        """
        return len(dataloader.dataset)

    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Args:
            model_path (:obj:`str`, `optional`):
                Local path to the model if the model to train has been instantiated from a local path. If present,
                training will resume from the optimizer/scheduler states loaded here.
        """
        train_dataloader = self.get_train_dataloader()
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (
                self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
            )
        else:
            t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        self.create_optimizer_and_scheduler(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            self.optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
            )
            self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        if self.args.fp16 and _use_apex:
            if not is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = self.global_step % (
                    len(train_dataloader) // self.args.gradient_accumulation_steps
                )

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", self.global_step)
                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        logging_loss = 0.0
        model.zero_grad()
        train_iterator = trange(
            epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_process_zero()
        )
        for epoch in train_iterator:
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                    self.args.device
                )
                epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_process_zero())
            else:
                epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_process_zero())

            # Reset the past mems state at the beginning of each epoch if necessary.
            if self.args.past_index >= 0:
                self._past = None

            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                tr_loss += self.training_step(model, inputs)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)
                ):
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                    if is_torch_tpu_available():
                        xm.optimizer_step(self.optimizer)
                    elif self.args.fp16 and _use_native_amp:
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        self.optimizer.step()

                    self.lr_scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
                        self.global_step == 1 and self.args.logging_first_step
                    ):
                        logs: Dict[str, float] = {}
                        logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
                        # backward compatibility for pytorch schedulers
                        logs["learning_rate"] = (
                            self.lr_scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >= version.parse("1.4")
                            else self.lr_scheduler.get_lr()[0]
                        )
                        logging_loss = tr_loss

                        self.log(logs)

                    if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
                        self.evaluate()

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
                        # to the model we want to save.
                        if hasattr(model, "module"):
                            assert (
                                model.module is self.model
                            ), f"Module {model.module} should be a reference to self.model"
                        else:
                            assert model is self.model, f"Model {model} should be a reference to self.model"
                        # Save model checkpoint
                        output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")

                        self.save_model(output_dir)

                        if self.is_world_process_zero():
                            self._rotate_checkpoints()

                        if is_torch_tpu_available():
                            xm.rendezvous("saving_optimizer_states")
                            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                        elif self.is_world_process_zero():
                            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

                if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                train_iterator.close()
                break
            if self.args.tpu_metrics_debug or self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )

        if self.tb_writer:
            self.tb_writer.close()
        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        return TrainOutput(self.global_step, tr_loss / self.global_step)
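        # Hedged usage note: calling `trainer.train(model_path="output/checkpoint-500")`
        # (a hypothetical checkpoint folder) reloads the optimizer and scheduler states
        # saved during training and continues from that global step.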

    def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
        """
        Log :obj:`logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (:obj:`Dict[str, float]`):
                The values to log.
            iterator (:obj:`tqdm`, `optional`):
                A potential tqdm progress bar to write the logs on.
        """
        if hasattr(self, "_log"):
            warnings.warn(
                "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
                FutureWarning,
            )
            return self._log(logs, iterator=iterator)

        if self.epoch is not None:
            logs["epoch"] = self.epoch
        if self.global_step is None:
            # when logging evaluation metrics without training
            self.global_step = 0
        if self.tb_writer:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, self.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        '"%s" of type %s for key "%s" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )
            self.tb_writer.flush()
        if is_wandb_available():
            if self.is_world_process_zero():
                wandb.log(logs, step=self.global_step)
        if is_comet_available():
            if self.is_world_process_zero():
                experiment = comet_ml.config.get_global_experiment()
                if experiment is not None:
                    experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers")
        output = {**logs, **{"step": self.global_step}}
        if iterator is not None:
            iterator.write(output)
        else:
            print(output)

    def _prepare_inputs(
        self, inputs: Dict[str, Union[torch.Tensor, Any]], model: nn.Module
    ) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
        handling potential state.
        """
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.args.device)

        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past

        return inputs

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> float:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to train.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.

        Return:
            :obj:`float`: The training loss on this batch.
        """
        if hasattr(self, "_training_step"):
            warnings.warn(
                "The `_training_step` method is deprecated and won't be called in a future version, define `training_step` in your subclass.",
                FutureWarning,
            )
            return self._training_step(model, inputs, self.optimizer)

        model.train()
        inputs = self._prepare_inputs(inputs, model)

        if self.args.fp16 and _use_native_amp:
            with autocast():
                outputs = model(**inputs)
                loss = outputs[0]
        else:
            outputs = model(**inputs)
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs[0]

        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16 and _use_native_amp:
            self.scaler.scale(loss).backward()
        elif self.args.fp16 and _use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss.item()
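        # Subclassing sketch (hedged): override this hook for custom per-step behaviour,
        # e.g. logging every step's scalar loss to an external tracker:
        #
        #   class VerboseTrainer(Trainer):
        #       def training_step(self, model, inputs):
        #           step_loss = super().training_step(model, inputs)
        #           logger.info("step loss: %f", step_loss)
        #           return step_loss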

    def is_local_master(self) -> bool:
        """
        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
        several machines) main process.

        .. warning::

            This method is deprecated, use :meth:`~transformers.Trainer.is_local_process_zero` instead.
        """
        warnings.warn("This method is deprecated, use `Trainer.is_local_process_zero()` instead.", FutureWarning)
        return self.is_local_process_zero()

    def is_local_process_zero(self) -> bool:
        """
        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
        several machines) main process.
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=True)
        else:
            return self.args.local_rank in [-1, 0]

    def is_world_master(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on
        several machines, this is only going to be :obj:`True` for one process).

        .. warning::

            This method is deprecated, use :meth:`~transformers.Trainer.is_world_process_zero` instead.
        """
        warnings.warn("This method is deprecated, use `Trainer.is_world_process_zero()` instead.", FutureWarning)
        return self.is_world_process_zero()

    def is_world_process_zero(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on
        several machines, this is only going to be :obj:`True` for one process).
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=False)
        else:
            return self.args.local_rank == -1 or torch.distributed.get_rank() == 0

    def save_model(self, output_dir: Optional[str] = None):
        """
        Will save the model, so you can reload it using :obj:`from_pretrained()`.

        Will only save from the world_master process (unless in TPUs).
        """

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif self.is_world_process_zero():
            self._save(output_dir)

    def _save_tpu(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        logger.info("Saving model checkpoint to %s", output_dir)

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            raise ValueError("Trainer.model appears to not be a PreTrainedModel")

        xm.rendezvous("saving_checkpoint")
        self.model.save_pretrained(output_dir)

    def _save(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            raise ValueError("Trainer.model appears to not be a PreTrainedModel")
        self.model.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

    def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
        ordering_and_checkpoint_path = []

        glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")]

        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
                if regex_match and regex_match.groups():
                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
        return checkpoints_sorted

    def _rotate_checkpoints(self, use_mtime=False) -> None:
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
            shutil.rmtree(checkpoint)

    def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
        """
        Run evaluation and returns metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are
        task-dependent (pass it to the init :obj:`compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (:obj:`Dataset`, `optional`):
                Pass a dataset if you wish to override :obj:`self.eval_dataset`.

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
        """
        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        output = self.prediction_loop(eval_dataloader, description="Evaluation")

        self.log(output.metrics)

        if self.args.tpu_metrics_debug or self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        return output.metrics
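        # A minimal compute_metrics sketch (hedged): accuracy derived from an
        # EvalPrediction, wired in at construction time:
        #
        #   def accuracy_metrics(p: EvalPrediction) -> Dict[str, float]:
        #       preds = np.argmax(p.predictions, axis=-1)
        #       return {"accuracy": float((preds == p.label_ids).mean())}
        #
        #   trainer = Trainer(model, args, eval_dataset=eval_ds,  # eval_ds: any Dataset
        #                     compute_metrics=accuracy_metrics)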

    def predict(self, test_dataset: Dataset) -> PredictionOutput:
        """
        Run prediction and returns predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels.
        In that case, this method will also return metrics, like in :obj:`evaluate()`.

        Args:
            test_dataset (:obj:`Dataset`):
                Dataset to run the predictions on.

        Returns:
            `NamedTuple`:
            predictions (:obj:`np.ndarray`):
                The predictions on :obj:`test_dataset`.
            label_ids (:obj:`np.ndarray`, `optional`):
                The labels (if the dataset contained some).
            metrics (:obj:`Dict[str, float]`, `optional`):
                The potential dictionary of metrics (if the dataset contained labels).
        """
        test_dataloader = self.get_test_dataloader(test_dataset)

        return self.prediction_loop(test_dataloader, description="Prediction")

    def prediction_loop(
        self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

        Works both with or without labels.
        """
        if hasattr(self, "_prediction_loop"):
            warnings.warn(
                "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
                FutureWarning,
            )
            return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only)

        prediction_loss_only = (
            prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
        )

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        batch_size = dataloader.batch_size
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", self.num_examples(dataloader))
        logger.info("  Batch size = %d", batch_size)
        eval_losses: List[float] = []
        preds: torch.Tensor = None
        label_ids: torch.Tensor = None
        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

        if self.args.past_index >= 0:
            self._past = None

        for inputs in tqdm(dataloader, desc=description):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
            if loss is not None:
                eval_losses.append(loss)
            if logits is not None:
                preds = logits if preds is None else torch.cat((preds, logits), dim=0)
            if labels is not None:
                label_ids = labels if label_ids is None else torch.cat((label_ids, labels), dim=0)

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            if preds is not None:
                preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
            if label_ids is not None:
                label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
        elif is_torch_tpu_available():
            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
            if preds is not None:
                preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
            if label_ids is not None:
                label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = preds.cpu().numpy()
        if label_ids is not None:
            label_ids = label_ids.cpu().numpy()

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            metrics["eval_loss"] = np.mean(eval_losses)

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

    def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
        assert self.args.local_rank != -1

        output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(output_tensors, tensor)

        concat = torch.cat(output_tensors, dim=0)

        # truncate the dummy elements added by SequentialDistributedSampler
        output = concat[:num_total_examples]
        return output

    def prediction_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
            A tuple with the loss, logits and labels (each being optional).
        """
        has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])

        inputs = self._prepare_inputs(inputs, model)

        with torch.no_grad():
            outputs = model(**inputs)
            if has_labels:
                loss, logits = outputs[:2]
                loss = loss.mean().item()
            else:
                loss = None
                logits = outputs[0]
            if self.args.past_index >= 0:
                self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        labels = inputs.get("labels")
        if labels is not None:
            labels = labels.detach()
        return (loss, logits.detach(), labels)
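
A hedged end-to-end sketch of using the Trainer above on a tiny in-memory dataset. The class shown is the library's `transformers.Trainer`, so the sketch imports it from the package; the model name, texts, and hyperparameters are illustrative placeholders:

import torch
from torch.utils.data import Dataset
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)

class ToyDataset(Dataset):
    """A tiny labelled text dataset, only meant to exercise the training loop."""
    def __init__(self, tokenizer):
        texts = ["great movie", "terrible movie"] * 8
        self.labels = [1, 0] * 8
        self.encodings = tokenizer(texts, truncation=True, padding=True)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        item = {k: torch.tensor(v[i]) for k, v in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[i])
        return item

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
args = TrainingArguments(output_dir="./toy_out", num_train_epochs=1,
                         per_device_train_batch_size=4, logging_steps=5)
trainer = Trainer(model=model, args=args,
                  train_dataset=ToyDataset(tokenizer),
                  eval_dataset=ToyDataset(tokenizer))
trainer.train()
print(trainer.evaluate())  # returns a dict such as {"eval_loss": ...}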
Example #26
            total_loss += loss.item()
            total_correct += get_num_correct(preds, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        tb.add_scalar("Loss", total_loss, epoch)
        tb.add_scalar("Correct", total_correct, epoch)
        tb.add_scalar("Accuracy", total_correct / len(train_set), epoch)

        print("batch_size:", batch_size, "lr:", lr, "shuffle:", shuffle)
        print("epoch:", epoch, "total_correct:", total_correct, "loss:",
              total_loss)
    print(
        "___________________________________________________________________")

    tb.add_hparams(
        {
            "lr": lr,
            "bsize": batch_size,
            "shuffle": shuffle
        },
        {
            "accuracy": total_correct / len(train_set),
            "loss": total_loss,
        },
    )

tb.close()
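
The per-run pattern above usually sits inside an outer hyperparameter sweep; a hedged minimal sketch of that wrapper (the training body is elided and zero placeholders stand in for the real metrics):

from itertools import product
from torch.utils.tensorboard import SummaryWriter

param_grid = dict(lr=[0.01, 0.001], batch_size=[100, 1000], shuffle=[True, False])
for lr, batch_size, shuffle in product(*param_grid.values()):
    # One run directory per combination, named after the hyperparameters.
    tb = SummaryWriter(comment=f" lr={lr} bs={batch_size} shuffle={shuffle}")
    # ... per-epoch training and scalar logging as in the loop above ...
    tb.add_hparams({"lr": lr, "bsize": batch_size, "shuffle": shuffle},
                   {"accuracy": 0.0, "loss": 0.0})  # placeholder metrics
    tb.close()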
Example #27
class Trainer:
    def __init__(self,
                 model=None,
                 optimiser=None,
                 tensorboard_path=None,
                 data_path=None,
                 model_path=None,
                 K=0):
        self.history = {
            "A": deque([]),
            "X": deque([]),
            "done": deque([]),
            "expert": deque([]),
            "context": deque([])
        }
        self.iter = 0
        self.model = model
        self.optimiser = optimiser
        self.tensorboard_path = tensorboard_path
        self.data_path = data_path
        self.model_path = model_path
        self.K = K
        self.valid_idxs = []
        self.start_idxs = []
        self.sample_weights = torch.tensor([])
        self.sample_idxs = []
        self.steps_since_start = -1
        self.loss = None
        if self.tensorboard_path is not None:
            self.tensorboard = SummaryWriter(self.tensorboard_path)
        if self.model_path is not None:
            load_model(self.model, self.model_path)
        if self.data_path is not None:
            self.load_trainer()

    def save_trainer_onexit(self, path=None):
        atexit.register(self.save_trainer, path=path)

    def save_trainer(self, path=None):
        # Create state dict
        if len(self.history["done"]) >= 1:
            self.history["done"][-1] = True
        data = self.__dict__.copy()
        if data["model"] is not None:
            save_model(self.model, self.model_path)
        if data["optimiser"] is not None:
            data["optimiser_state"] = self.optimiser.state_dict()
        for name in [
                "model", "optimiser", "K", "tensorboard_path", "data_path",
                "model_path", "valid_idxs", "start_idxs", "steps_since_start",
                "tensorboard", "loss", "loss_fn", "tensorboard_fn"
        ]:
            if name in data:
                del data[name]
        # Save file
        if path is None:
            path = self.data_path
        with open(path, 'wb') as fp:
            torch.save(data, fp)
        print("Saved Trainer " + path)
        return data

    def load_trainer(self, path=None):
        if path is None:
            path = self.data_path
        try:
            fp = open(path, 'rb')
        except OSError:
            print("No Trainer to Load")
            return False
        data = torch.load(fp, map_location=device)
        self.load_trainer_dict(data)
        self.update_valid_idxs()
        print("Loaded Trainer " + path)
        return True

    def load_trainer_dict(self, data):
        if "optimiser_state" in data:
            self.optimiser.load_state_dict(data["optimiser_state"])
        for name in [
                "model", "optimiser", "K", "tensorboard_path", "data_path",
                "model_path", "valid_idxs", "start_idxs", "steps_since_start",
                "tensorboard", "loss", "loss_fn", "tensorboard_fn"
        ]:
            if name in data:
                del data[name]
        for name, val in data.items():
            setattr(self, name, val)

    def set_state(self, A, X, done=False, expert=None, context={}):
        # Update data
        self.history["done"].append(done)
        self.history["A"].append(A.to(device))
        self.history["X"].append(X.to(device))
        self.history["context"].append(context)
        if expert is not None:
            self.history["expert"].append(expert.to(device))
        else:
            self.history["expert"].append(expert)
        # Update valid idxs
        if done:
            self.steps_since_start = -1  # steps since first episode step
        else:
            self.steps_since_start += 1
        idx = len(self.history["X"]) - 1
        if self.steps_since_start >= self.K:
            self.valid_idxs.append(idx)
        if self.steps_since_start == 0:
            self.start_idxs.append(idx)

    def get_state(self, idx=-1, K=None):
        if K is None:
            K = self.K
        idx = self.valid_idxs[idx]
        # Compile data
        batch = {}
        X = torch.stack([self.history["X"][idx - k] for k in range(K + 1)],
                        dim=2)
        batch["X"] = X
        A = torch.stack([self.history["A"][idx - k] for k in range(K + 1)],
                        dim=2)
        batch["A"] = A
        expert = torch.stack(
            [self.history["expert"][idx - k] for k in range(K + 1)], dim=2)
        batch["expert"] = expert
        context = [self.history["context"][idx - k] for k in range(K + 1)]
        batch["context"] = context
        return batch

    def update_valid_idxs(self):
        self.valid_idxs = []
        self.start_idxs = []
        for idx in range(len(self.history["done"])):
            if self.history["done"][idx]:
                self.steps_since_start = -1
            else:
                self.steps_since_start += 1
            if self.steps_since_start >= self.K:
                self.valid_idxs.append(idx)
            if self.steps_since_start == 0:
                self.start_idxs.append(idx)

    def train(self, loss):
        loss.backward()
        self.optimiser.step()
        self.optimiser.zero_grad()
        # self.iter = list(self.optimiser.state_dict()["state"].values())[0]['step']
        self.iter += 1

    def get_batch(self, batch_size=16, data_split=1.0, weighting=False):
        # data_split = 0.9 for train set, data_split = -0.1 for test set
        # Init weighting
        if weighting and self.sample_weights.shape[0] != len(self.valid_idxs):
            self.sample_weights = torch.ones(len(self.valid_idxs))
        # Get indices
        size = min(batch_size, len(self.valid_idxs) - 1)
        idxs = self.get_random_idxs(num=size,
                                    data_split=data_split,
                                    weighting=weighting)
        self.sample_idxs = idxs
        # Compile data
        batch = {}
        X = torch.stack([
            torch.stack(
                [self.history["X"][idx - k] for k in range(self.K + 1)], dim=2)
            for idx in idxs
        ],
                        dim=0)
        batch["X"] = X
        A = torch.stack([
            torch.stack(
                [self.history["A"][idx - k] for k in range(self.K + 1)], dim=2)
            for idx in idxs
        ],
                        dim=0)
        batch["A"] = A
        expert = torch.stack([self.history["expert"][idx] for idx in idxs],
                             dim=0)
        batch["expert"] = expert
        context = [self.history["context"][idx] for idx in idxs]
        batch["context"] = context
        return batch

    def get_random_idxs(self, num=1, data_split=1.0, weighting=False):
        N = len(self.valid_idxs)
        N_set = int(N * abs(data_split))
        if weighting:
            weights = self.sample_weights[:N_set]
        else:
            weights = torch.ones(N_set)
        if data_split > 0:
            idx_idxs = torch.multinomial(weights, num, replacement=False)
        else:
            idx_idxs = (N - N_set) + torch.multinomial(
                weights, num, replacement=False)
        return [self.valid_idxs[i] for i in idx_idxs]

    def get_episodes(self):
        episodes = [self.get_episode(idx) for idx in self.start_idxs]
        Xs = [episode["X"] for episode in episodes]
        As = [episode["A"] for episode in episodes]
        experts = [episode["expert"] for episode in episodes]
        contexts = [episode["context"] for episode in episodes]
        max_length = max(map(lambda episode: episode["X"].shape[0], episodes))
        N = Xs[0].shape[1]
        D = Xs[0].shape[2]
        Xs = [
            torch.cat([
                X.float(),
                torch.full(
                    (max_length - X.shape[0], N, D), np.nan, device=device)
            ],
                      dim=0) for X in Xs
        ]
        As = [
            torch.cat([
                A.float(),
                torch.full(
                    (max_length - A.shape[0], N, N), np.nan, device=device)
            ],
                      dim=0) for A in As
        ]
        if len(experts[0].shape) == 3:
            OUT_DIM = experts[0].shape[2]
            experts = [
                torch.cat([
                    expert.float(),
                    torch.full((max_length - expert.shape[0], N, OUT_DIM),
                               np.nan,
                               device=device)
                ],
                          dim=0) for expert in experts
            ]
        elif len(experts[0].shape) == 2:
            experts = [
                torch.cat([
                    expert.float(),
                    torch.full((max_length - expert.shape[0], N),
                               np.nan,
                               device=device)
                ],
                          dim=0) for expert in experts
            ]
        contexts = [
            np.concatenate([
                context,
                np.array([{} for _ in range(max_length - context.shape[0])])
            ],
                           axis=0) for context in contexts
        ]
        Xs = torch.stack(Xs, dim=0)
        As = torch.stack(As, dim=0)
        experts = torch.stack(experts, dim=0)
        contexts = np.array(contexts)
        data = {"X": Xs, "A": As, "expert": experts, "context": contexts}
        return data  # num_episodes x episode_length x N x D

    def get_episode(self, idx=None):
        # if idx is given, returns the rest of that episode
        # if idx is None, returns an entire random episode
        if idx is None:
            N = len(self.start_idxs)
            idx = self.start_idxs[torch.randint(N, size=())]
        A = []
        X = []
        expert = []
        context = []
        k = 0
        while idx + k != len(self.history["done"]) and (
                not self.history["done"][idx + k] or k == 0):
            A.append(self.history["A"][idx + k])
            X.append(self.history["X"][idx + k])
            expert.append(self.history["expert"][idx + k])
            context.append(self.history["context"][idx + k])
            k += 1
        A = torch.stack(A, dim=0)
        X = torch.stack(X, dim=0)
        expert = torch.stack(expert, dim=0)
        context = np.array(context)
        return {"A": A, "X": X, "expert": expert, "context": context}

    def get_X(self):
        X = torch.stack(list(self.history["X"]), dim=0)
        X = X[self.valid_idxs]
        return X

    def update_tensorboard(self, value, t=None, datatype=None):
        if not hasattr(self, "tensorboard"):
            print("No tensorboard object")
            return
        if t is None:
            t = self.iter
        if datatype is None:
            for name, val in value.items():
                self.update_tensorboard(val, t=t, datatype=name)
        else:
            if datatype == 'scalar':  # given as a dict {name: value}
                for name, val in value.items():
                    if isinstance(val, dict):
                        self.tensorboard.add_scalars(
                            name, val, global_step=t)  # adds to the same plot
                    else:
                        self.tensorboard.add_scalar(name, val, global_step=t)
            elif datatype == 'graph':  # given as a tuple (net, inputs)
                if t == 0:
                    net = value[0]
                    inputs = value[1]
                    self.tensorboard.add_graph(net, inputs)
            elif datatype == 'embedding':  # given as a tuple (features, labels)
                features = value[0]  # N x D, each row is the feature vector of a data point
                labels = value[1]  # N, vector of int labels
                self.tensorboard.add_embedding(features,
                                               metadata=labels,
                                               global_step=t)
            elif datatype == 'hyperparameter':  # given as a tuple (hparam_dict, metric_dict)
                hparam_dict = value[0]  # {hyperparameter name: hyperparameter value}
                metric_dict = value[1]  # {metric name: metric value}
                self.tensorboard.add_hparams(hparam_dict, metric_dict)
            elif datatype == 'histogram':
                for name, val in value.items():
                    self.tensorboard.add_histogram(name, val, global_step=t)

    def show_tensorboard(self, openbrowser=True):
        from tensorboard import program
        tb = program.TensorBoard()
        tb.configure(argv=[None, '--logdir', self.tensorboard.log_dir])
        url = tb.launch()
        if openbrowser:
            import webbrowser
            webbrowser.open(url)
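
# --- Added usage sketch (not part of the original example) ---
# A minimal, self-contained illustration of the per-datatype value formats that
# `update_tensorboard` above dispatches on; all tag names and numbers here are
# made up for illustration.
def _tensorboard_format_demo():
    import torch
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter()
    step = 0
    # 'scalar': {name: value}, or {name: {sub_name: value}} for grouped plots
    writer.add_scalar("loss/train", 0.42, global_step=step)
    writer.add_scalars("reward", {"mean": 1.3, "max": 2.0}, global_step=step)
    # 'histogram': {name: tensor_of_values}
    writer.add_histogram("weights/layer0", torch.randn(128), global_step=step)
    # 'hyperparameter': (hparam_dict, metric_dict)
    writer.add_hparams({"lr": 1e-3, "batch_size": 64}, {"final_loss": 0.42})
    writer.close()
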
def do_pretrain(args):
    if is_main_process(args) and args.tensorboard_dir:
        tb_writer = SummaryWriter(log_dir=args.tensorboard_dir)
        tb_writer.add_text("args", args.to_json_string())
        tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})
    else:
        tb_writer = None

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    ort.set_seed(args.seed)

    device, args = setup_training(args)

    model = prepare_model(args, device)

    logger.info("Running training: Batch size = %d, initial LR = %f",
                args.train_batch_size, args.learning_rate)

    most_recent_ckpts_paths = []
    average_loss = 0.0
    epoch = 0
    training_steps = 0

    pool = ProcessPoolExecutor(1)
    while True:
        files = [
            os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
            if os.path.isfile(os.path.join(args.input_dir, f))
            and 'training' in f
        ]
        files.sort()
        random.shuffle(files)

        f_id = 0
        train_dataloader, data_file = create_pretraining_dataset(
            get_data_file(f_id, args.world_rank, args.world_size, files),
            args.max_predictions_per_seq, args)

        for f_id in range(1, len(files)):
            logger.info("data file %s" % (data_file))

            dataset_future = pool.submit(
                create_pretraining_dataset,
                get_data_file(f_id, args.world_rank, args.world_size, files),
                args.max_predictions_per_seq, args)

            train_iter = tqdm(train_dataloader, desc="Iteration"
                              ) if is_main_process(args) else train_dataloader
            for step, batch in enumerate(train_iter):
                training_steps += 1
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch

                loss, _, _ = model.train_step(input_ids, input_mask,
                                              segment_ids, masked_lm_labels,
                                              next_sentence_labels)
                average_loss += loss.item()

                global_step = model._train_step_info.optimization_step
                if training_steps % (args.log_freq *
                                     args.gradient_accumulation_steps) == 0:
                    if is_main_process(args):
                        divisor = args.log_freq * args.gradient_accumulation_steps
                        if tb_writer:
                            lr = model.options.lr_scheduler.get_last_lr()[0]
                            tb_writer.add_scalar(
                                'train/summary/scalar/Learning_Rate', lr,
                                global_step)
                            if args.fp16:
                                tb_writer.add_scalar(
                                    'train/summary/scalar/loss_scale_25', loss,
                                    global_step)
                                # TODO: ORTTrainer to expose all_finite
                                # tb_writer.add_scalar('train/summary/scalar/all_fp16_gradients_finite_859', all_finite, global_step)
                            tb_writer.add_scalar('train/summary/total_loss',
                                                 average_loss / divisor,
                                                 global_step)

                        print("Step:{} Average Loss = {}".format(
                            global_step, average_loss / divisor))

                    if global_step >= args.max_steps or global_step >= force_to_stop_max_steps:
                        if tb_writer:
                            tb_writer.close()

                    if global_step >= args.max_steps:
                        if args.save_checkpoint:
                            experimental_save_checkpoint(
                                model, args.output_dir)
                        final_loss = average_loss / (
                            args.log_freq * args.gradient_accumulation_steps)
                        return final_loss

                    average_loss = 0

            del train_dataloader

            train_dataloader, data_file = dataset_future.result(timeout=None)

        epoch += 1
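
# --- Added note (a sketch, not from the original) ---
# TensorBoard groups scalars by the tag prefix before the first "/", so tags
# like 'train/summary/scalar/Learning_Rate' and 'train/summary/total_loss'
# above land in the same 'train' section of the UI. Minimal illustration:
def _tag_grouping_demo(log_dir="runs/tag_demo"):  # log_dir is an arbitrary choice
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(log_dir=log_dir)
    for step in range(3):
        writer.add_scalar("train/summary/scalar/Learning_Rate",
                          1e-4 * (step + 1), step)
        writer.add_scalar("train/summary/total_loss", 1.0 / (step + 1), step)
    writer.close()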
Example #29
0
def train_agent(train_params, train_env_params, eval_env_params, obs_params):
    # Environment parameters
    n_agents = train_env_params.n_agents
    x_dim = train_env_params.x_dim
    y_dim = train_env_params.y_dim
    n_cities = train_env_params.n_cities
    max_rails_between_cities = train_env_params.max_rails_between_cities
    max_rails_in_city = train_env_params.max_rails_in_city
    seed = train_env_params.seed

    # Unique ID for this training
    now = datetime.now()
    training_id = now.strftime('%y%m%d%H%M%S')

    # Observation parameters
    observation_tree_depth = obs_params.observation_tree_depth
    observation_radius = obs_params.observation_radius
    observation_max_path_depth = obs_params.observation_max_path_depth

    # Training parameters
    eps_start = train_params.eps_start
    eps_end = train_params.eps_end
    eps_decay = train_params.eps_decay
    n_episodes = train_params.n_episodes
    checkpoint_interval = train_params.checkpoint_interval
    n_eval_episodes = train_params.n_evaluation_episodes
    restore_replay_buffer = train_params.restore_replay_buffer
    save_replay_buffer = train_params.save_replay_buffer

    # Set the seeds
    random.seed(seed)
    np.random.seed(seed)

    # Observation builder
    predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
    tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth,
                                         predictor=predictor)

    # Setup the environments
    train_env = create_rail_env(train_env_params, tree_observation)
    train_env.reset(regenerate_schedule=True, regenerate_rail=True)
    eval_env = create_rail_env(eval_env_params, tree_observation)
    eval_env.reset(regenerate_schedule=True, regenerate_rail=True)

    # Setup renderer
    if train_params.render:
        env_renderer = RenderTool(train_env, gl="PGL")

    # Calculate the state size given the depth of the tree observation and the number of features
    n_features_per_node = train_env.obs_builder.observation_dim
    n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)])
    state_size = n_features_per_node * n_nodes
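    # For example, with observation_tree_depth = 2 the tree has
    # 1 + 4 + 16 = 21 nodes, so state_size = 21 * n_features_per_node.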

    # The action space of flatland is 5 discrete actions
    action_size = 5

    # Max number of steps per episode
    # This is the official formula used during evaluations
    # See details in flatland.envs.schedule_generators.sparse_schedule_generator
    # max_steps = int(4 * 2 * (env.height + env.width + (n_agents / n_cities)))
    max_steps = train_env._max_episode_steps

    action_count = [0] * action_size
    action_dict = dict()
    agent_obs = [None] * n_agents
    agent_prev_obs = [None] * n_agents
    agent_prev_action = [2] * n_agents
    update_values = [False] * n_agents

    # Smoothed values used as target for hyperparameter tuning
    smoothed_normalized_score = -1.0
    smoothed_eval_normalized_score = -1.0
    smoothed_completion = 0.0
    smoothed_eval_completion = 0.0

    # Double Dueling DQN policy
    policy = DDDQNPolicy(state_size, action_size, train_params)

    # Loads existing replay buffer
    if restore_replay_buffer:
        try:
            policy.load_replay_buffer(restore_replay_buffer)
            policy.test()
        except RuntimeError as e:
            print(
                "\n🛑 Could't load replay buffer, were the experiences generated using the same tree depth?"
            )
            print(e)
            exit(1)

    print("\n💾 Replay buffer status: {}/{} experiences".format(
        len(policy.memory.memory), train_params.buffer_size))

    hdd = psutil.disk_usage('/')
    if save_replay_buffer and (hdd.free / (2**30)) < 500.0:
        print(
            "⚠️  Careful! Saving replay buffers will quickly consume a lot of disk space. You have {:.2f}gb left."
            .format(hdd.free / (2**30)))

    # TensorBoard writer
    writer = SummaryWriter()
    writer.add_hparams(vars(train_params), {})
    writer.add_hparams(vars(train_env_params), {})
    writer.add_hparams(vars(obs_params), {})
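    # Note (added): with torch.utils.tensorboard, each add_hparams() call is
    # recorded as a separate hparams entry (its own sub-run under the log dir),
    # so the three parameter groups above appear as three rows in the HPARAMS
    # tab; merging the dicts into a single call would keep them together.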

    training_timer = Timer()
    training_timer.start()

    print(
        "\n🚉 Training {} trains on {}x{} grid for {} episodes, evaluating on {} episodes every {} episodes. Training id '{}'.\n"
        .format(train_env.get_num_agents(), x_dim, y_dim, n_episodes,
                n_eval_episodes, checkpoint_interval, training_id))

    for episode_idx in range(n_episodes + 1):
        step_timer = Timer()
        reset_timer = Timer()
        learn_timer = Timer()
        preproc_timer = Timer()
        inference_timer = Timer()

        # Reset environment
        reset_timer.start()
        obs, info = train_env.reset(regenerate_rail=True,
                                    regenerate_schedule=True)
        reset_timer.end()

        if train_params.render:
            env_renderer.set_new_rail()

        score = 0
        nb_steps = 0
        actions_taken = []

        # Build initial agent-specific observations
        for agent in train_env.get_agent_handles():
            if obs[agent]:
                agent_obs[agent] = normalize_observation(
                    obs[agent],
                    observation_tree_depth,
                    observation_radius=observation_radius)
                agent_prev_obs[agent] = agent_obs[agent].copy()

        # Run episode
        for step in range(max_steps - 1):
            inference_timer.start()
            for agent in train_env.get_agent_handles():
                if info['action_required'][agent]:
                    update_values[agent] = True
                    action = policy.act(agent_obs[agent], eps=eps_start)

                    action_count[action] += 1
                    actions_taken.append(action)
                else:
                    # An action is not required if the train hasn't joined the railway network,
                    # if it has already reached its target, or if it is currently malfunctioning.
                    update_values[agent] = False
                    action = 0
                action_dict.update({agent: action})
            inference_timer.end()

            # Environment step
            step_timer.start()
            next_obs, all_rewards, done, info = train_env.step(action_dict)
            step_timer.end()

            # Render an episode at some interval
            if train_params.render and episode_idx % checkpoint_interval == 0:
                env_renderer.render_env(show=True,
                                        frames=False,
                                        show_observations=False,
                                        show_predictions=False)

            # Update replay buffer and train agent
            for agent in train_env.get_agent_handles():
                if update_values[agent] or done['__all__']:
                    # Only learn from timesteps where something happened
                    learn_timer.start()
                    policy.step(agent_prev_obs[agent],
                                agent_prev_action[agent], all_rewards[agent],
                                agent_obs[agent], done[agent])
                    learn_timer.end()

                    agent_prev_obs[agent] = agent_obs[agent].copy()
                    agent_prev_action[agent] = action_dict[agent]

                # Preprocess the new observations
                if next_obs[agent]:
                    preproc_timer.start()
                    agent_obs[agent] = normalize_observation(
                        next_obs[agent],
                        observation_tree_depth,
                        observation_radius=observation_radius)
                    preproc_timer.end()

                score += all_rewards[agent]

            nb_steps = step

            if done['__all__']:
                break

        # Epsilon decay
        eps_start = max(eps_end, eps_decay * eps_start)

        # Collect information about training
        tasks_finished = sum(done[idx]
                             for idx in train_env.get_agent_handles())
        completion = tasks_finished / max(1, train_env.get_num_agents())
        normalized_score = score / (max_steps * train_env.get_num_agents())
        action_probs = action_count / np.sum(action_count)
        action_count = [1] * action_size

        smoothing = 0.99
        smoothed_normalized_score = smoothed_normalized_score * smoothing + normalized_score * (
            1.0 - smoothing)
        smoothed_completion = smoothed_completion * smoothing + completion * (
            1.0 - smoothing)
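        # The two updates above are an exponential moving average,
        # smoothed <- smoothing * smoothed + (1 - smoothing) * new_value,
        # i.e. roughly an average over the last 1 / (1 - 0.99) = 100 episodes.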

        # Print logs
        if episode_idx % checkpoint_interval == 0:
            torch.save(
                policy.qnetwork_local, './checkpoints/' + training_id + '-' +
                str(episode_idx) + '.pth')

            if save_replay_buffer:
                policy.save_replay_buffer('./replay_buffers/' + training_id +
                                          '-' + str(episode_idx) + '.pkl')

            if train_params.render:
                env_renderer.close_window()

        print('\r🚂 Episode {}'
              '\t 🏆 Score: {:.3f}'
              ' Avg: {:.3f}'
              '\t 💯 Done: {:.2f}%'
              ' Avg: {:.2f}%'
              '\t 🎲 Epsilon: {:.3f} '
              '\t 🔀 Action Probs: {}'.format(episode_idx, normalized_score,
                                             smoothed_normalized_score,
                                             100 * completion,
                                             100 * smoothed_completion,
                                             eps_start,
                                             format_action_prob(action_probs)),
              end=" ")

        # Evaluate policy and log results at some interval
        if episode_idx % checkpoint_interval == 0 and n_eval_episodes > 0:
            scores, completions, nb_steps_eval = eval_policy(
                eval_env, policy, train_params, obs_params)

            writer.add_scalar("evaluation/scores_min", np.min(scores),
                              episode_idx)
            writer.add_scalar("evaluation/scores_max", np.max(scores),
                              episode_idx)
            writer.add_scalar("evaluation/scores_mean", np.mean(scores),
                              episode_idx)
            writer.add_scalar("evaluation/scores_std", np.std(scores),
                              episode_idx)
            writer.add_histogram("evaluation/scores", np.array(scores),
                                 episode_idx)
            writer.add_scalar("evaluation/completions_min",
                              np.min(completions), episode_idx)
            writer.add_scalar("evaluation/completions_max",
                              np.max(completions), episode_idx)
            writer.add_scalar("evaluation/completions_mean",
                              np.mean(completions), episode_idx)
            writer.add_scalar("evaluation/completions_std",
                              np.std(completions), episode_idx)
            writer.add_histogram("evaluation/completions",
                                 np.array(completions), episode_idx)
            writer.add_scalar("evaluation/nb_steps_min", np.min(nb_steps_eval),
                              episode_idx)
            writer.add_scalar("evaluation/nb_steps_max", np.max(nb_steps_eval),
                              episode_idx)
            writer.add_scalar("evaluation/nb_steps_mean",
                              np.mean(nb_steps_eval), episode_idx)
            writer.add_scalar("evaluation/nb_steps_std", np.std(nb_steps_eval),
                              episode_idx)
            writer.add_histogram("evaluation/nb_steps",
                                 np.array(nb_steps_eval), episode_idx)

            smoothing = 0.9
            smoothed_eval_normalized_score = smoothed_eval_normalized_score * smoothing + np.mean(
                scores) * (1.0 - smoothing)
            smoothed_eval_completion = smoothed_eval_completion * smoothing + np.mean(
                completions) * (1.0 - smoothing)
            writer.add_scalar("evaluation/smoothed_score",
                              smoothed_eval_normalized_score, episode_idx)
            writer.add_scalar("evaluation/smoothed_completion",
                              smoothed_eval_completion, episode_idx)

        # Save logs to tensorboard
        writer.add_scalar("training/score", normalized_score, episode_idx)
        writer.add_scalar("training/smoothed_score", smoothed_normalized_score,
                          episode_idx)
        writer.add_scalar("training/completion", np.mean(completion),
                          episode_idx)
        writer.add_scalar("training/smoothed_completion",
                          np.mean(smoothed_completion), episode_idx)
        writer.add_scalar("training/nb_steps", nb_steps, episode_idx)
        writer.add_histogram("actions/distribution", np.array(actions_taken),
                             episode_idx)
        writer.add_scalar("actions/nothing",
                          action_probs[RailEnvActions.DO_NOTHING], episode_idx)
        writer.add_scalar("actions/left",
                          action_probs[RailEnvActions.MOVE_LEFT], episode_idx)
        writer.add_scalar("actions/forward",
                          action_probs[RailEnvActions.MOVE_FORWARD],
                          episode_idx)
        writer.add_scalar("actions/right",
                          action_probs[RailEnvActions.MOVE_RIGHT], episode_idx)
        writer.add_scalar("actions/stop",
                          action_probs[RailEnvActions.STOP_MOVING],
                          episode_idx)
        writer.add_scalar("training/epsilon", eps_start, episode_idx)
        writer.add_scalar("training/buffer_size", len(policy.memory),
                          episode_idx)
        writer.add_scalar("training/loss", policy.loss, episode_idx)
        writer.add_scalar("timer/reset", reset_timer.get(), episode_idx)
        writer.add_scalar("timer/step", step_timer.get(), episode_idx)
        writer.add_scalar("timer/learn", learn_timer.get(), episode_idx)
        writer.add_scalar("timer/preproc", preproc_timer.get(), episode_idx)
        writer.add_scalar("timer/total", training_timer.get_current(),
                          episode_idx)
Example #30
0
class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch,
    optimized for 🤗 Transformers.

    Args:
        model (:class:`~transformers.PreTrainedModel`):
            The model to train, evaluate or use for predictions.
        args (:class:`~transformers.TrainingArguments`):
            The arguments to tweak training.
        data_collator (:obj:`DataCollator`, `optional`, defaults to :func:`~transformers.default_data_collator`):
            The function to use to form a batch from a list of elements of :obj:`train_dataset` or
            :obj:`eval_dataset`.
        train_dataset (:obj:`Dataset`, `optional`):
            The dataset to use for training.
        eval_dataset (:obj:`Dataset`, `optional`):
            The dataset to use for evaluation.
        compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
            The function that will be used to compute metrics at evaluation. Must take a
            :class:`~transformers.EvalPrediction` and return a dictionary mapping metric names to metric values.
        prediction_loss_only (:obj:`bool`, `optional`, defaults to `False`):
            When performing evaluation and predictions, only returns the loss.
        tb_writer (:obj:`SummaryWriter`, `optional`):
            Object to write to TensorBoard.
        optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, `optional`):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of
            :class:`~transformers.AdamW` on your model and a scheduler given by
            :func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
    """

    model: PreTrainedModel
    args: TrainingArguments
    data_collator: DataCollator
    train_dataset: Optional[Dataset]
    eval_dataset: Optional[Dataset]
    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
    prediction_loss_only: bool
    tb_writer: Optional["SummaryWriter"] = None
    optimizers: Tuple[torch.optim.Optimizer,
                      torch.optim.lr_scheduler.LambdaLR] = None
    global_step: Optional[int] = None
    epoch: Optional[float] = None

    def __init__(
        self,
        model: PreTrainedModel,
        args: TrainingArguments,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        prediction_loss_only=False,
        tb_writer: Optional["SummaryWriter"] = None,
        optimizers: Tuple[torch.optim.Optimizer,
                          torch.optim.lr_scheduler.LambdaLR] = None,
    ):
        self.model = model.to(args.device)
        self.args = args
        self.data_collator = data_collator if data_collator is not None else default_data_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.compute_metrics = compute_metrics
        self.prediction_loss_only = prediction_loss_only
        self.optimizers = optimizers
        if tb_writer is not None:
            self.tb_writer = tb_writer
        elif is_tensorboard_available() and self.is_world_master():
            self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
        if not is_tensorboard_available():
            logger.warning(
                "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
            )
        if is_wandb_available():
            self._setup_wandb()
        else:
            logger.info(
                "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
                "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
            )
        set_seed(self.args.seed)
        # Create output directory if needed
        if self.is_world_master():
            os.makedirs(self.args.output_dir, exist_ok=True)
        if is_torch_tpu_available():
            # Set an xla_device flag on the model's config.
            # We'll find a more elegant way that doesn't require this in the future.
            self.model.config.xla_device = True
        if not callable(self.data_collator) and callable(
                getattr(self.data_collator, "collate_batch", None)):
            self.data_collator = self.data_collator.collate_batch
            warnings.warn(
                ("The `data_collator` should now be a simple callable (function, class with `__call__`), classes "
                 +
                 "with a `collate_batch` are deprecated and won't be supported in a future version."
                 ),
                FutureWarning,
            )

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training :class:`~torch.utils.data.DataLoader`.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        if is_torch_tpu_available():
            train_sampler = get_tpu_sampler(self.train_dataset)
        else:
            train_sampler = (RandomSampler(self.train_dataset)
                             if self.args.local_rank == -1 else
                             DistributedSampler(self.train_dataset))

        data_loader = DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

        return data_loader

    def get_eval_dataloader(self,
                            eval_dataset: Optional[Dataset] = None
                            ) -> DataLoader:
        """
        Returns the evaluation :class:`~torch.utils.data.DataLoader`.

        Args:
            eval_dataset (:obj:`Dataset`, `optional`):
                If provided, will override `self.eval_dataset`.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

        if is_torch_tpu_available():
            sampler = SequentialDistributedSampler(
                eval_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal())
        elif self.args.local_rank != -1:
            sampler = SequentialDistributedSampler(eval_dataset)
        else:
            sampler = SequentialSampler(eval_dataset)

        data_loader = DataLoader(
            eval_dataset,
            sampler=sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

        return data_loader

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test :class:`~torch.utils.data.DataLoader`.

        Args:
            test_dataset (obj:`Dataset`): The test dataset to use.
        """
        # We use the same batch_size as for eval.
        if is_torch_tpu_available():
            sampler = SequentialDistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal())
        elif self.args.local_rank != -1:
            sampler = SequentialDistributedSampler(test_dataset)
        else:
            sampler = SequentialSampler(test_dataset)

        data_loader = DataLoader(
            test_dataset,
            sampler=sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

        return data_loader

    def get_optimizers(
        self, num_training_steps: int
    ) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through :obj:`optimizers`, or override this method in a subclass.
        """
        if self.optimizers is not None:
            return self.optimizers
        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                self.args.weight_decay,
            },
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0,
            },
        ]
        if self.args.use_habana and self.args.use_fused_adam:
            try:
                from hb_custom import FusedAdamW
            except ImportError:
                raise ImportError("Please install hb_custom.")
            optimizer = FusedAdamW(optimizer_grouped_parameters,
                                   lr=self.args.learning_rate,
                                   eps=self.args.adam_epsilon)
        else:
            optimizer = AdamW(optimizer_grouped_parameters,
                              lr=self.args.learning_rate,
                              eps=self.args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.args.warmup_steps,
            num_training_steps=num_training_steps)
        return optimizer, scheduler

    def _setup_wandb(self):
        """
        Setup the optional Weights & Biases (`wandb`) integration.

        One can override this method to customize the setup if needed. Find more information at
        https://docs.wandb.com/huggingface. You can also override the following environment variables:

        Environment:
            WANDB_WATCH:
                (Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging
                or "all" to log gradients and parameters
            WANDB_PROJECT:
                (Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
            WANDB_DISABLED:
                (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
        """
        if self.is_world_master():
            logger.info(
                'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
            )
            wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"),
                       config=vars(self.args))
            # keep track of model topology and gradients, unsupported on TPU
            if not is_torch_tpu_available() and os.getenv(
                    "WANDB_WATCH") != "false":
                wandb.watch(self.model,
                            log=os.getenv("WANDB_WATCH", "gradients"),
                            log_freq=max(100, self.args.logging_steps))

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its Dataset.
        """
        return len(dataloader.dataset)

    def enable_tracing(self):
        torch._C._debug_set_autodiff_subgraph_inlining(False)
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_set_profiling_mode(False)
        try:
            import hb_torch
        except ImportError:
            assert False, "Could Not import hb_torch"
        hb_torch.enable()
        hb_torch.remove_inplace_ops()

    def compute_position_ids(self, input_ids):
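        # Builds int32 position ids [0, 1, ..., seq_length - 1] broadcast over
        # the batch: for input_ids of shape (batch, seq_length) the result has
        # the same shape, e.g. (2, 4) -> [[0, 1, 2, 3], [0, 1, 2, 3]].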
        input_shape = input_ids.size()
        seq_length = input_shape[1]
        position_ids_seq = torch.arange(seq_length, dtype=torch.int32)
        position_ids_ = position_ids_seq.unsqueeze(0).expand(input_shape)
        position_ids = position_ids_.contiguous()
        return position_ids

    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Args:
            model_path (:obj:`str`, `optional`):
                Local path to the model if the model to train has been instantiated from a local path. If present,
                training will resume from the optimizer/scheduler states loaded here.
        """
        train_dataloader = self.get_train_dataloader()
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (self.args.max_steps //
                                (len(train_dataloader) //
                                 self.args.gradient_accumulation_steps) + 1)
        else:
            t_total = int(
                len(train_dataloader) //
                self.args.gradient_accumulation_steps *
                self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (model_path is not None
                and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
                and os.path.isfile(os.path.join(model_path, "scheduler.pt"))):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt")))
            scheduler.load_state_dict(
                torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        if self.args.fp16:
            if not is_apex_available():
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=self.args.fp16_opt_level)

        if self.args.hmp:
            print(self.args.hmp_bf16)
            from hmp import hmp
            hmp.convert(opt_level=self.args.hmp_opt_level,
                        bf16_file_path=self.args.hmp_bf16,
                        fp32_file_path=self.args.hmp_fp32,
                        isVerbose=self.args.hmp_verbose)

        if self.args.use_jit_trace:
            model.train()
            self.enable_tracing()

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            if self.args.use_habana:
                if not self.args.use_jit_trace:
                    model = torch.nn.parallel.DistributedDataParallel(
                        model, find_unused_parameters=True)
            else:
                model = torch.nn.parallel.DistributedDataParallel(
                    model,
                    device_ids=[self.args.local_rank],
                    output_device=self.args.local_rank,
                    find_unused_parameters=True,
                )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(),
                                       metric_dict={})

        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size(
            )
        else:
            total_train_batch_size = (self.args.train_batch_size *
                                      self.args.gradient_accumulation_steps *
                                      (torch.distributed.get_world_size()
                                       if self.args.local_rank != -1 else 1))
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d",
                    self.args.per_device_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d",
                    self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        is_model_traced = False
        tensor_dummy = torch.zeros(1).to(self.args.device)
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = self.global_step // (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = self.global_step % (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)

                logger.info(
                    "  Continuing training from checkpoint, will skip to saved global_step"
                )
                logger.info("  Continuing training from epoch %d",
                            epochs_trained)
                logger.info("  Continuing training from global step %d",
                            self.global_step)
                logger.info(
                    "  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        logging_loss = 0.0
        model.zero_grad()
        train_iterator = trange(epochs_trained,
                                int(num_train_epochs),
                                desc="Epoch",
                                disable=not self.is_local_master())
        for epoch in train_iterator:
            if isinstance(train_dataloader, DataLoader) and isinstance(
                    train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(
                    train_dataloader,
                    [self.args.device]).per_device_loader(self.args.device)
                epoch_iterator = tqdm(parallel_loader,
                                      desc="Iteration",
                                      disable=not self.is_local_master())
            else:
                epoch_iterator = tqdm(train_dataloader,
                                      desc="Iteration",
                                      disable=not self.is_local_master())

            # Reset the past mems state at the beginning of each epoch if necessary.
            if self.args.past_index >= 0:
                self._past = None

            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                device = self.args.device
                position_ids_cpu = self.compute_position_ids(
                    inputs['input_ids'])
                if self.args.use_habana:
                    inputs['input_ids'] = inputs['input_ids'].to(
                        dtype=torch.int32)
                    inputs['attention_mask'] = inputs['attention_mask'].to(
                        dtype=torch.int32)
                    inputs['token_type_ids'] = inputs['token_type_ids'].to(
                        dtype=torch.int32)
                    inputs['labels'] = inputs['labels'].to(dtype=torch.int32)
                    inputs['position_ids'] = position_ids_cpu

                if self.args.use_jit_trace and not is_model_traced:
                    input_ids = inputs['input_ids'].to(device)
                    attention_mask = inputs['attention_mask'].to(device)
                    token_type_ids = inputs['token_type_ids'].to(device)
                    labels = inputs['labels'].to(device)
                    position_ids = position_ids_cpu.to(device)
                    model_trace = torch.jit.trace(
                        model, (input_ids, attention_mask, token_type_ids,
                                position_ids, tensor_dummy, tensor_dummy,
                                labels, tensor_dummy, tensor_dummy),
                        check_trace=False)
                    is_model_traced = True
                    model = model_trace
                    if self.args.local_rank != -1:
                        if self.args.use_habana:
                            model = torch.nn.parallel.DistributedDataParallel(
                                model, find_unused_parameters=True)

                tr_loss += self._training_step(model, inputs, optimizer,
                                               tensor_dummy)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                        # last step in epoch but step is always smaller than gradient_accumulation_steps
                        len(epoch_iterator) <=
                        self.args.gradient_accumulation_steps and
                    (step + 1) == len(epoch_iterator)):
                    if self.args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer),
                            self.args.max_grad_norm)
                    else:
                        if self.args.use_habana:
                            if self.args.use_fused_clip_norm:
                                try:
                                    from hb_custom import FusedClipNorm
                                except ImportError:
                                    raise ImportError(
                                        "Please install hb_custom.")

                                FusedClipNorm(model.parameters(),
                                              self.args.max_grad_norm)
                            else:
                                if self.args.hmp:
                                    from hmp import hmp
                                    with hmp.disable_casts():
                                        torch.nn.utils.clip_grad_norm_(
                                            model.parameters(),
                                            self.args.max_grad_norm)
                                else:
                                    torch.nn.utils.clip_grad_norm_(
                                        model.parameters(),
                                        self.args.max_grad_norm)
                        else:
                            torch.nn.utils.clip_grad_norm_(
                                model.parameters(), self.args.max_grad_norm)

                    if is_torch_tpu_available():
                        xm.optimizer_step(optimizer)
                    else:
                        if self.args.use_habana and self.args.hmp and not (
                                self.args.use_fused_adam):
                            from hmp import hmp
                            with hmp.disable_casts():
                                optimizer.step()
                        else:
                            optimizer.step()

                    scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0
                            and self.global_step % self.args.logging_steps
                            == 0) or (self.global_step == 1
                                      and self.args.logging_first_step):
                        logs: Dict[str, float] = {}
                        logs["loss"] = (tr_loss -
                                        logging_loss) / self.args.logging_steps
                        # backward compatibility for pytorch schedulers
                        logs["learning_rate"] = (
                            scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >=
                            version.parse("1.4") else scheduler.get_lr()[0])
                        logging_loss = tr_loss

                        self._log(logs)

                    if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
                        self.evaluate()

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
                        # to the model we want to save.
                        if hasattr(model, "module"):
                            assert model.module is self.model
                        else:
                            if not self.args.use_jit_trace:
                                assert model is self.model
                        # Save model checkpoint
                        output_dir = os.path.join(
                            self.args.output_dir,
                            f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")

                        self.save_model(output_dir)

                        if self.is_world_master():
                            self._rotate_checkpoints()

                        if is_torch_tpu_available():
                            xm.rendezvous("saving_optimizer_states")
                            xm.save(optimizer.state_dict(),
                                    os.path.join(output_dir, "optimizer.pt"))
                            xm.save(scheduler.state_dict(),
                                    os.path.join(output_dir, "scheduler.pt"))
                        elif self.is_world_master():
                            if self.args.use_habana:
                                # shallow copy of dict followed by dma
                                import copy
                                optim_dict = dict()
                                for state in optimizer.state.values():
                                    for k, v in state.items():
                                        if isinstance(v, torch.Tensor):
                                            optim_dict[k] = v.to('cpu')
                                        else:
                                            optim_dict[k] = v
                                torch.save(
                                    optim_dict,
                                    os.path.join(output_dir, "optimizer.pt"))
                            else:
                                torch.save(
                                    optimizer.state_dict(),
                                    os.path.join(output_dir, "optimizer.pt"))

                            torch.save(
                                scheduler.state_dict(),
                                os.path.join(output_dir, "scheduler.pt"))

                if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                train_iterator.close()
                break
            if self.args.tpu_metrics_debug or self.args.debug:
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())

        if self.tb_writer:
            self.tb_writer.close()
        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info(
            "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
        )
        return TrainOutput(self.global_step, tr_loss / self.global_step)

    def _log(self,
             logs: Dict[str, float],
             iterator: Optional[tqdm] = None) -> None:
        if self.epoch is not None:
            logs["epoch"] = self.epoch
        if self.global_step is None:
            # when logging evaluation metrics without training
            self.global_step = 0
        if self.tb_writer:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, self.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        '"%s" of type %s for key "%s" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )
            self.tb_writer.flush()
        if is_wandb_available():
            if self.is_world_master():
                wandb.log(logs, step=self.global_step)
        output = {**logs, **{"step": self.global_step}}
        if iterator is not None:
            iterator.write(output)
        else:
            logger.info(output)

    def _training_step(self,
                       model: nn.Module,
                       inputs: Dict[str, Union[torch.Tensor, Any]],
                       optimizer: torch.optim.Optimizer,
                       tensor_dummy=None) -> float:
        model.train()
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.args.device)

        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past

        if self.args.use_jit_trace:
            outputs = model(inputs['input_ids'], inputs['attention_mask'],
                            inputs['token_type_ids'], inputs['position_ids'],
                            tensor_dummy, tensor_dummy, inputs['labels'],
                            tensor_dummy, tensor_dummy)
        else:
            outputs = model(**inputs)
        loss = outputs[
            0]  # model outputs are always tuple in transformers (see doc)

        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if self.args.n_gpu > 1:
            loss = loss.mean(
            )  # mean() to average on multi-gpu parallel training
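        # Dividing the loss by gradient_accumulation_steps below keeps the
        # accumulated gradient equivalent to that of a single batch of size
        # train_batch_size * gradient_accumulation_steps.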
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss.item()

    def is_local_master(self) -> bool:
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=True)
        else:
            return self.args.local_rank in [-1, 0]

    def is_world_master(self) -> bool:
        """
        This will be True only in one process, even in distributed mode,
        even when training on multiple machines.
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=False)
        else:
            return self.args.local_rank == -1 or torch.distributed.get_rank(
            ) == 0

    def save_model(self, output_dir: Optional[str] = None):
        """
        Will save the model, so you can reload it using :obj:`from_pretrained()`.

        Will only save from the world_master process (unless in TPUs).
        """

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif self.is_world_master():
            self._save(output_dir)

    def _save_tpu(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        logger.info("Saving model checkpoint to %s", output_dir)

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir,
                                               "training_args.bin"))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            raise ValueError(
                "Trainer.model appears to not be a PreTrainedModel")

        xm.rendezvous("saving_checkpoint")
        self.model.save_pretrained(output_dir)

    def _save(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            raise ValueError(
                "Trainer.model appears to not be a PreTrainedModel")
        self.model.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

    def _sorted_checkpoints(self,
                            checkpoint_prefix=PREFIX_CHECKPOINT_DIR,
                            use_mtime=False) -> List[str]:
        ordering_and_checkpoint_path = []

        glob_checkpoints = [
            str(x)
            for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")
        ]

        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append(
                    (os.path.getmtime(path), path))
            else:
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
                if regex_match and regex_match.groups():
                    ordering_and_checkpoint_path.append(
                        (int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [
            checkpoint[1] for checkpoint in checkpoints_sorted
        ]
        return checkpoints_sorted

    def _rotate_checkpoints(self, use_mtime=False) -> None:
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        number_of_checkpoints_to_delete = max(
            0,
            len(checkpoints_sorted) - self.args.save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:
                                                       number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info(
                "Deleting older checkpoint [{}] due to args.save_total_limit".
                format(checkpoint))
            shutil.rmtree(checkpoint)

    def evaluate(self,
                 eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
        """
        Run evaluation and returns metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are
        task-dependent (pass it to the init :obj:`compute_metrics` argument).

        Args:
            eval_dataset (:obj:`Dataset`, `optional`):
                Pass a dataset if you wish to override :obj:`self.eval_dataset`.
        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
        """
        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        output = self._prediction_loop(eval_dataloader,
                                       description="Evaluation")

        self._log(output.metrics)

        if self.args.tpu_metrics_debug or self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        return output.metrics

    def predict(self, test_dataset: Dataset) -> PredictionOutput:
        """
        Run prediction and returns predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels.
        In that case, this method will also return metrics, like in :obj:`evaluate()`.

        Args:
            test_dataset (:obj:`Dataset`):
                Dataset to run the predictions on.
        Returns:
            `NamedTuple`:
            predictions (:obj:`np.ndarray`):
                The predictions on :obj:`test_dataset`.
            label_ids (:obj:`np.ndarray`, `optional`):
                The labels (if the dataset contained some).
            metrics (:obj:`Dict[str, float]`, `optional`):
                The potential dictionary of metrics (if the dataset contained labels).
        """
        test_dataloader = self.get_test_dataloader(test_dataset)

        return self._prediction_loop(test_dataloader, description="Prediction")

    def _prediction_loop(
            self,
            dataloader: DataLoader,
            description: str,
            prediction_loss_only: Optional[bool] = None) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.

        Works both with or without labels.
        """

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only

        if self.args.use_jit_trace:
            self.enable_tracing()

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        batch_size = dataloader.batch_size
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", self.num_examples(dataloader))
        logger.info("  Batch size = %d", batch_size)
        eval_losses: List[float] = []
        preds: Optional[torch.Tensor] = None
        label_ids: Optional[torch.Tensor] = None
        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(
                dataloader,
                [self.args.device]).per_device_loader(self.args.device)

        if self.args.past_index >= 0:
            past = None

        is_eval_traced = False
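        # Placeholder tensor passed in the positional slots that the traced
        # forward signature expects but that this task does not use.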
        tensor_dummy = torch.zeros(1).to(self.args.device)

        for inputs in tqdm(dataloader, desc=description):
            has_labels = any(
                inputs.get(k) is not None
                for k in ["labels", "lm_labels", "masked_lm_labels"])

            ## Habana doesn't support Long (int64) tensors,
            ## hence we cast the integer input tensors to int32.
            if self.args.use_habana:
                inputs["input_ids"] = inputs["input_ids"].to(dtype=torch.int32)
                inputs["attention_mask"] = inputs["attention_mask"].to(
                    dtype=torch.int32)
                inputs["token_type_ids"] = inputs["token_type_ids"].to(
                    dtype=torch.int32)

            position_ids_cpu = self.compute_position_ids(inputs["input_ids"])

            for k, v in inputs.items():
                if isinstance(v, torch.Tensor):
                    inputs[k] = v.to(self.args.device)
            if self.args.past_index >= 0:
                inputs["mems"] = past

            inputs["position_ids"] = position_ids_cpu.to(self.args.device)

            with torch.no_grad():
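                # Trace the model once on the first evaluation batch and reuse
                # the traced module for all subsequent batches.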
                if self.args.use_jit_trace and not is_eval_traced:
                    model_trace = torch.jit.trace(
                        model,
                        (inputs['input_ids'], inputs['attention_mask'],
                         inputs['token_type_ids'], inputs["position_ids"],
                         tensor_dummy, tensor_dummy, inputs['labels'],
                         tensor_dummy, tensor_dummy),
                        check_trace=False)
                    model_trace.eval()
                    is_eval_traced = True
                    model = model_trace
                if self.args.use_jit_trace:
                    outputs = model(inputs['input_ids'],
                                    inputs['attention_mask'],
                                    inputs['token_type_ids'],
                                    inputs["position_ids"], tensor_dummy,
                                    tensor_dummy, inputs['labels'],
                                    tensor_dummy, tensor_dummy)
                else:
                    outputs = model(**inputs)
                if has_labels:
                    step_eval_loss, logits = outputs[:2]
                    eval_losses += [step_eval_loss.mean().item()]
                else:
                    logits = outputs[0]
                if self.args.past_index >= 0:
                    past = outputs[self.args.past_index if has_labels
                                   else self.args.past_index - 1]

            if not prediction_loss_only:
                if preds is None:
                    preds = logits.detach()
                else:
                    preds = torch.cat((preds, logits.detach()), dim=0)
                if inputs.get("labels") is not None:
                    if label_ids is None:
                        label_ids = inputs["labels"].detach()
                    else:
                        label_ids = torch.cat(
                            (label_ids, inputs["labels"].detach()), dim=0)

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            if preds is not None:
                preds = self.distributed_concat(
                    preds, num_total_examples=self.num_examples(dataloader))
            if label_ids is not None:
                # Workaround: cast label_ids to int32 (via CPU) before gathering,
                # since Habana doesn't support Long tensors.
                label_ids = label_ids.to("cpu").type(
                    torch.IntTensor).to("habana")
                label_ids = self.distributed_concat(
                    label_ids,
                    num_total_examples=self.num_examples(dataloader))
        elif is_torch_tpu_available():
            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
            if preds is not None:
                preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
            if label_ids is not None:
                label_ids = xm.mesh_reduce("eval_label_ids", label_ids,
                                           torch.cat)

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = preds.float().cpu().numpy()
        if label_ids is not None:
            label_ids = label_ids.float().cpu().numpy()

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(
                EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            metrics["eval_loss"] = np.mean(eval_losses)

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds,
                                label_ids=label_ids,
                                metrics=metrics)

    def distributed_concat(self, tensor: torch.Tensor,
                           num_total_examples: int) -> torch.Tensor:
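        """
        Gather :obj:`tensor` from all distributed processes, concatenate the
        results along the first dimension and truncate to
        :obj:`num_total_examples` to drop the dummy elements added by the
        distributed sampler.
        """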
        assert self.args.local_rank != -1

        output_tensors = [
            tensor.clone() for _ in range(torch.distributed.get_world_size())
        ]
        torch.distributed.all_gather(output_tensors, tensor)

        concat = torch.cat(output_tensors, dim=0)

        # truncate the dummy elements added by SequentialDistributedSampler
        output = concat[:num_total_examples]
        return output