Example #1
    def init_training(self):
        model = self.elements["model"]
        start_epoch = self.params["start_epoch"]
        exist_model = self.params["exist_model"]
        model_dir = self.params["model_dir"]
        model_blueprint = self.params["model_blueprint"]
        suffix = self.params["suffix"]

        if start_epoch <= 0 and utils.is_main_training():
            model_creation = model.get_model_creation()
            utils.write_nnet_config(model_blueprint, model_creation, "{0}/config/nnet.config".format(model_dir))

        ## Recover checkpoint | Transform learning | Initialize parameters
        if start_epoch > 0:
            # This train_stage is equal to the number of completed epochs.
            if utils.is_main_training(): logger.info("Recover training from epoch {0}.".format(start_epoch))
            model.load_state_dict(torch.load('{0}/{1}.{2}'.format(model_dir, start_epoch, suffix), 
                                             map_location="cpu"))
        elif os.path.exists(exist_model):
            if utils.is_main_training(): logger.info("Use {0} as the initial model to start transform-training.".format(exist_model))
            model.load_transform_state_dict(torch.load(exist_model, map_location="cpu"))
        else:
            # Just use the raw initial model or initialize it again by some initial functions here
            pass # Now, it means use the raw initial model

        if utils.use_horovod():
            import horovod.torch as hvd

            # Broadcast parameters from rank 0 to all other processes.
            hvd.broadcast_parameters(self.elements["model"].state_dict(), root_rank=0)

            # For optimizer wrapper such as lookahead.
            if getattr(self.elements["optimizer"], "optimizer", None) is not None:
                raise TypeError("Do not support using lookahead with horovod now.")
            else:
                # Broadcast optimizer state.
                hvd.broadcast_optimizer_state(self.elements["optimizer"], root_rank=0)
                self.elements["optimizer"] = hvd.DistributedOptimizer(self.elements["optimizer"], 
                                             named_parameters=self.elements["model"].named_parameters())

        ## Select device
        model = self.select_device()

        # The original model is built in libs.nnet.framework.TopVirtualNnet, and its methods are not
        # directly available after it is wrapped by DistributedDataParallel. So, to call functions of
        # TopVirtualNnet conveniently, self.elements["model_forward"] is set here to name the DistributedDataParallel wrapper.
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            self.elements["model"] = model.module
            self.elements["model_forward"] = model
Example #2
 def run_lr_finder(self,
                   save_file: str,
                   comment=None,
                   init_lr=1e-8,
                   final_lr=10.,
                   num_iters=None,
                   beta=0.98):
     self.init_training()
     log_dir = self.params["model_dir"] + "/log/"  # For tensorboardX
     if comment is not None:
         save_file = comment + "-" + save_file
     save_file = log_dir + save_file
     log_lrs, values_matrix = self.lr_finder_compute(
         self.elements["data"].train_loader,
         self.elements["optimizer"],
         init_lr=init_lr,
         final_lr=final_lr,
         num_iters=num_iters,
         beta=beta,
         log_dir=log_dir,
         comment=comment)
     if utils.is_main_training():
         df = pd.DataFrame(np.vstack([log_lrs, values_matrix]).T,
                           columns=[
                               "log_lr", "train_loss", "train_acc",
                               "valid_loss", "valid_acc"
                           ])
         logger.info("Save lr finder values to {}.".format(save_file))
         df.to_csv(save_file)
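A hypothetical call of the method above; the trainer instance, file name, and argument values are illustrative, not from the source.

# Runs the LR range test and saves the curve to <model_dir>/log/warmR-lr_finder.csv.
trainer.run_lr_finder("lr_finder.csv", comment="warmR",
                      init_lr=1e-8, final_lr=10., num_iters=2000)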
Example #3
    def step(self, training_point=None, valid_metric=None):
        if self.name == "warmR":
            if self.lr_decay_step > 0 and training_point[1]%self.lr_decay_step == 0:
                self.lr_scheduler.step(training_point[0]+training_point[1]/training_point[2])
            elif self.lr_decay_step == 0:
                self.lr_scheduler.step(training_point[0])
        elif self.name == "1cycle":
            self.lr_scheduler.step()
        elif self.name == "reduceP":
            # A reduce point is a sampled point where the valid metrics are computed and the learning rate is adjusted.
            if self.is_reduce_point(training_point):
                # Horovod is not supported here yet.
                if utils.use_ddp():
                    # Multi-GPU case.
                    # In this case, the valid metrics are not computed by every process. They are computed
                    # in the main process only and then broadcast to the other processes.
                    if not self.init:
                        device = utils.get_device_from_optimizer(self.lr_scheduler.optimizer)
                        # Create a placeholder tensor to broadcast with the torch.distributed.broadcast function.
                        # Note: self.metric holds the configured metric name, so a separate attribute is
                        # used for the tensor to avoid overwriting it.
                        self.metric_tensor = torch.randn(2, device=device)
                        # Create a new group to broadcast the metric tensor. It is important.
                        self.group = torch.distributed.new_group(ranks=list(range(torch.distributed.get_world_size())), 
                                                                 backend="nccl")
                        self.init = True
                    if utils.is_main_training():
                        # Gather the new values of the metrics.
                        self.metric_tensor = torch.tensor([valid_metric[0], valid_metric[1]],
                                                          device=self.metric_tensor.device)
                    # Broadcast the metrics from the main process (rank 0) to the others.
                    torch.distributed.broadcast(self.metric_tensor, 0, group=self.group)
                    metric = self.metric_tensor[0] if self.metric == "valid_loss" else self.metric_tensor[1]
                else:
                    # Single-GPU case.
                    metric = valid_metric[0] if self.metric == "valid_loss" else valid_metric[1]

                self.lr_scheduler.step(metric)
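The DDP branch above shares the main process's validation metrics through torch.distributed.broadcast. Here is a minimal standalone sketch of that pattern, assuming an initialized process group; the gloo backend and the helper name are illustrative, while the code above uses nccl with a dedicated group.

import torch
import torch.distributed as dist

def broadcast_valid_metric(valid_metric, device="cpu"):
    # Every rank allocates a same-shaped tensor; rank 0 fills in (valid_loss, valid_acc).
    metric_tensor = torch.zeros(2, device=device)
    if dist.get_rank() == 0:
        metric_tensor = torch.tensor([float(valid_metric[0]), float(valid_metric[1])], device=device)
    # After the broadcast, every rank holds rank 0's values.
    dist.broadcast(metric_tensor, src=0)
    return metric_tensor

# Example under torchrun (illustrative):
#   dist.init_process_group("gloo")
#   metric = broadcast_valid_metric((0.5, 0.9))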
Example #4
    def get_bunch_from_csv(self,
                           trainset_csv: str,
                           valid_csv: str = None,
                           egs_params: dict = None,
                           data_loader_params_dict: dict = None):
        # Avoid mutable default arguments: egs_params is mutated by pop() below,
        # so a shared default dict would leak state across calls.
        egs_params = {} if egs_params is None else egs_params
        data_loader_params_dict = {} if data_loader_params_dict is None else data_loader_params_dict
        Egs = ChunkEgs
        if "egs_type" in egs_params.keys():
            egs_type = egs_params.pop("egs_type")
            if egs_type == "chunk":
                pass
            elif egs_type == "vector":
                Egs = VectorEgs
            else:
                raise TypeError(
                    "Do not support {} egs now. Select one from [chunk, vector]."
                    .format(egs_type))

        trainset = Egs(trainset_csv, **egs_params)
        # For multi-GPU training, only the main process loads the valid set.
        if not utils.is_main_training():
            valid = None
        elif valid_csv != "" and valid_csv is not None:
            valid = Egs(valid_csv)
        else:
            valid = None
        return self(trainset, valid, **data_loader_params_dict)
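A hypothetical call of the method above; the bunch class name, csv paths, and options are illustrative, not from the source.

# Select VectorEgs via egs_params and pass DataLoader options through.
bunch = BaseBunch.get_bunch_from_csv("exp/egs/train.csv",
                                     valid_csv="exp/egs/valid.csv",
                                     egs_params={"egs_type": "vector"},
                                     data_loader_params_dict={"batch_size": 512})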
Example #5
 def lr_finder_compute(self, train_batch):
     model = self.elements["model"]
     if model.use_step:
         model.step(*self.training_point)
     loss, acc = self.train_one_batch(train_batch)
     model.backward_step(*self.training_point)
     valid_loss, valid_acc = 0., 0.  # Defaults so that non-main processes do not hit a NameError below.
     if utils.is_main_training():
         valid_loss, valid_acc = self.compute_validation(self.elements["data"].valid_loader)
     return ["train_loss", "train_acc", "valid_loss", "valid_acc"], [loss, acc, valid_loss, valid_acc]
Example #6
 def lr_finder_compute(self, train_batch):
     model = self.elements["model"]
     if model.use_step:
         model.step(*self.training_point)
     loss, acc = self.train_one_batch(train_batch)
     model.backward_step(*self.training_point)
     valid_loss, valid_acc = 0., 0.  # Defaults so that non-main processes do not hit a NameError below.
     if utils.is_main_training():
         valid_loss, valid_acc = self.compute_validation(
             self.elements["data"].valid_loader)
     # Orthogonality measure: mean pairwise dot product of the L2-normalized class weight vectors.
     weight = model.loss.weight.squeeze(dim=2)
     weight = F.normalize(weight, dim=1)
     orth = 0.
     for i in range(weight.shape[0]):
         for j in range(i + 1, weight.shape[0]):
             orth += torch.dot(weight[i], weight[j]).item()
     orth /= weight.shape[0] * (weight.shape[0] - 1) / 2
     return ["train_loss", "train_acc", "valid_loss", "valid_acc",
             "orth"], [loss, acc, valid_loss, valid_acc, orth]
Example #7
##--------------------------------------------------##
##
######################################################### START #########################################################
##
#### Set seed
utils.set_all_seed(1024)
##
#### Init environment
# It is used for multi-GPU training when more than one gpu-id is given,
# and it does nothing for single-GPU training.
utils.init_multi_gpu_training(args.gpu_id, args.multi_gpu_solution, args.port)
##
#### Set sleep time for a rest
# Use it to run the launcher with a countdown when there is no spare GPU memory right now,
# but you really want to go to bed and you know when the GPU memory will be free.
if args.sleep > 0 and utils.is_main_training():
    logger.info("This launcher will sleep {}s before starting...".format(
        args.sleep))
    time.sleep(args.sleep)
##
#### Auto-config params
# If multi-GPU is used, the learning rate is auto-scaled by multiplying it by the number of processes.
optimizer_params["learn_rate"] = utils.auto_scale_lr(
    optimizer_params["learn_rate"])
# It is used by model.step() as defined in the model blueprint.
if lr_scheduler_params["name"] == "warmR" and model_params["use_step"]:
    model_params["step_params"]["T"] = (lr_scheduler_params["warmR.T_max"],
                                        lr_scheduler_params["warmR.T_mult"])
##
#### Preprocess
if stage <= 2 and endstage >= 0 and utils.is_main_training():
Example #8
if args.sleep > 0: time.sleep(args.sleep)
##
#### Init environment
# It is used for multi-GPU training when more than one gpu-id is given,
# and it does nothing for single-GPU training.
utils.init_multi_gpu_training(args.gpu_id, args.multi_gpu_solution, args.port)
##
#### Auto-config params
# If multi-GPU is used, the learning rate is auto-scaled by multiplying it by the number of processes.
optimizer_params["learn_rate"] = utils.auto_scale_lr(optimizer_params["learn_rate"])
# It is used by model.step() as defined in the model blueprint.
if lr_scheduler_params["name"] == "warmR" and model_params["use_step"]:
    model_params["step_params"]["T"]=(lr_scheduler_params["warmR.T_max"], lr_scheduler_params["warmR.T_mult"])
##
#### Preprocess
if stage <= 2 and endstage >= 0 and utils.is_main_training():
    # Only limited options are given here because passing everything through is not convenient.
    # It is suggested to pre-execute this shell script by yourself for full freedom and then
    # continue to run this launcher.
    kaldi_common.execute_command("sh subtools/pytorch/pipeline/preprocess_to_egs.sh "
                                 "--stage {stage} --endstage {endstage} --valid-split-type {valid_split_type} "
                                 "--nj {nj} --cmn {cmn} --limit-utts {limit_utts} --min-chunk {chunk_size} --overlap {overlap} "
                                 "--sample-type {sample_type} --chunk-num {chunk_num} --scale {scale} --force-clear {force_clear} "
                                 "--valid-num-utts {valid_utts} --valid-chunk-num {valid_chunk_num_every_utt} "
                                 "{traindata} {egs_dir}".format(stage=stage, endstage=endstage, valid_split_type=valid_split_type, 
                                 nj=preprocess_nj, cmn=str(cmn).lower(), limit_utts=limit_utts, chunk_size=chunk_size, overlap=overlap, 
                                 sample_type=sample_type, chunk_num=chunk_num, scale=scale, force_clear=str(force_clear).lower(), 
                                 valid_utts=valid_utts, valid_chunk_num_every_utt=valid_chunk_num_every_utt, traindata=traindata, 
                                 egs_dir=egs_dir))

#### Train model
if stage <= 3 <= endstage:
Example #9
    def run(self):
        """Main function to start a training process.
        """
        try:
            self.init_training()

            if utils.is_main_training():
                self.reporter = Reporter(self)

            start_epoch = self.params["start_epoch"]
            epochs = self.params["epochs"]
            data = self.elements["data"]
            model = self.elements["model"]
            model_forward = self.elements[
                "model_forward"]  # See init_training.
            lr_scheduler = self.elements["lr_scheduler"]
            base_optimizer = self.elements["optimizer"]

            # For lookahead.
            if getattr(base_optimizer, "optimizer", None) is not None:
                base_optimizer = base_optimizer.optimizer
            last_lr = base_optimizer.state_dict()['param_groups'][0]['lr']

            if utils.is_main_training():
                logger.info("Training will run for {0} epochs.".format(epochs))

            for this_epoch in range(start_epoch, epochs):
                # Set random seed w.r.t epoch for distributed training.
                if isinstance(data.train_loader.sampler, torch.utils.data.distributed.DistributedSampler) and \
                    self.params["ddp_random_epoch"]:
                    data.train_loader.sampler.set_epoch(this_epoch)
                for this_iter, batch in enumerate(data.train_loader, 0):
                    self.training_point = (this_epoch, this_iter,
                                           data.num_batch_train
                                           )  # It is important for reporter.

                    if model.use_step:
                        model.step(*self.training_point)

                    loss, acc = self.train_one_batch(batch)

                    model.backward_step(*self.training_point)

                    # For multi-GPU training. Remember that it is not convenient to wrap lr_scheduler,
                    # because there are many strategies with different details. Here, only warmR, ReduceLROnPlateau
                    # and some simple schedulers whose only step() parameter is 'epoch' are supported.
                    lr_scheduler_params = {
                        "training_point": self.training_point
                    }

                    if utils.is_main_training() or lr_scheduler.name == "reduceP":
                        if data.valid_loader and (self.reporter.is_report(self.training_point) or \
                           lr_scheduler.is_reduce_point(self.training_point)):

                            valid_loss, valid_acc = self.compute_validation(
                                data.valid_loader)
                            # real_snapshot keeps the raw numeric values for tensorboard,
                            # while the formatted strings below are for the console report.
                            real_snapshot = {
                                "train_loss": loss,
                                "valid_loss": valid_loss,
                                "train_acc": acc * 100,
                                "valid_acc": valid_acc * 100
                            }
                            snapshot = {
                                "train_loss": "{0:.6f}".format(loss),
                                "valid_loss": "{0:.6f}".format(valid_loss),
                                "train_acc": "{0:.2f}".format(acc * 100),
                                "valid_acc": "{0:.2f}".format(valid_acc * 100),
                                "real": real_snapshot
                            }
                            try:
                                weight = model.loss.weight.squeeze(dim=2)
                                weight = F.normalize(weight, dim=1)
                                orth_snapshot = {"orth_snp": 0.}
                                for i in range(weight.shape[0]):
                                    for j in range(i + 1, weight.shape[0]):
                                        orth_snapshot["orth_snp"] += torch.dot(
                                            weight[i], weight[j]).item()
                                orth_snapshot["orth_snp"] /= weight.shape[
                                    0] * (weight.shape[0] - 1) / 2
                                real_snapshot.update(orth_snapshot)
                                snapshot.update(orth_snapshot)
                                snapshot["real"] = real_snapshot
                            except Exception:
                                # The loss of some models has no such weight matrix; skip the orth metric then.
                                pass
                            # For ReduceLROnPlateau.
                            lr_scheduler_params["valid_metric"] = (valid_loss,
                                                                   valid_acc)
                        else:
                            real_snapshot = {
                                "train_loss": loss,
                                "train_acc": acc * 100
                            }
                            snapshot = {
                                "train_loss": "{0:.6f}".format(loss),
                                "valid_loss": "",
                                "train_acc": "{0:.2f}".format(acc * 100),
                                "valid_acc": "",
                                "real": real_snapshot
                            }

                    if lr_scheduler is not None:
                        # It is not convenient to wrap lr_scheduler (work in progress).
                        if isinstance(lr_scheduler, LRSchedulerWrapper):
                            lr_scheduler.step(**lr_scheduler_params)
                            if lr_scheduler.name == "reduceP" and utils.is_main_training(
                            ):
                                current_lr = base_optimizer.state_dict(
                                )['param_groups'][0]['lr']
                                if current_lr < last_lr:
                                    last_lr = current_lr
                                    self.save_model(from_epoch=False)
                        else:
                            # For some pytorch lr_schedulers; this is not available for all of them.
                            lr_scheduler.step(this_epoch)
                    if utils.is_main_training(): self.reporter.update(snapshot)
                if utils.is_main_training(): self.save_model()
            if utils.is_main_training(): self.reporter.finish()
        except BaseException as e:
            if utils.use_ddp(): utils.cleanup_ddp()
            if not isinstance(e, KeyboardInterrupt):
                traceback.print_exc()
            sys.exit(1)
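Several snippets here index into self.training_point = (this_epoch, this_iter, num_batch_train): the warmR scheduler turns it into a fractional epoch, and the cyclic checkpoint logic in a later example turns it into a global iteration. Here is a minimal sketch of that arithmetic, with names following the tuple layout above.

def fractional_epoch(training_point):
    # As used by the warmR scheduler: epoch plus the completed fraction of the current epoch.
    this_epoch, this_iter, num_batch_train = training_point
    return this_epoch + this_iter / num_batch_train

def global_iter(training_point):
    # As used by the cyclic checkpoint rule: 1-based iteration count across epochs.
    this_epoch, this_iter, num_batch_train = training_point
    return this_epoch * num_batch_train + this_iter + 1

assert fractional_epoch((2, 50, 100)) == 2.5
assert global_iter((0, 0, 100)) == 1
assert global_iter((2, 49, 100)) == 250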
Example #10
    def run(self):
        """Main function to start a training process.
        """
        try:
            self.init_training()

            if utils.is_main_training():
                self.reporter = Reporter(self)

            start_epoch = self.params["start_epoch"]
            epochs = self.params["epochs"]
            data = self.elements["data"]
            model = self.elements["model"]
            model_forward = self.elements[
                "model_forward"]  # See init_training.
            lr_scheduler = self.elements["lr_scheduler"]

            if utils.is_main_training():
                logger.info("Training will run for {0} epochs.".format(epochs))

            for this_epoch in range(start_epoch, epochs):
                for this_iter, batch in enumerate(data.train_loader, 0):
                    self.training_point = (this_epoch, this_iter,
                                           data.num_batch_train
                                           )  # It is important for reporter.

                    if model.use_step:
                        model.step(*self.training_point)

                    loss, acc = self.train_one_batch(batch)

                    # For multi-GPU training. Remember that it is not convenient to wrap lr_scheduler,
                    # because there are many strategies with different details. Here, only warmR, ReduceLROnPlateau
                    # and some simple schedulers whose only step() parameter is 'epoch' are supported.
                    lr_scheduler_params = {
                        "training_point": self.training_point
                    }

                    if utils.is_main_training() or lr_scheduler.name == "reduceP":
                        if data.valid_loader and (self.reporter.is_report(self.training_point) or \
                           lr_scheduler.is_reduce_point(self.training_point)):

                            valid_loss, valid_acc = self.compute_validation(
                                data.valid_loader)
                            snapshot = {
                                "train_loss": "{0:.6f}".format(loss),
                                "valid_loss": "{0:.6f}".format(valid_loss),
                                "train_acc": "{0:.2f}".format(acc * 100),
                                "valid_acc": "{0:.2f}".format(valid_acc * 100)
                            }
                            # For ReduceLROnPlateau.
                            lr_scheduler_params["valid_metric"] = (valid_loss,
                                                                   valid_acc)
                        else:
                            snapshot = {
                                "train_loss": "{0:.6f}".format(loss),
                                "valid_loss": "",
                                "train_acc": "{0:.2f}".format(acc * 100),
                                "valid_acc": ""
                            }

                    if lr_scheduler is not None:
                        # It is not convenient to wrap lr_scheduler (work in progress).
                        if isinstance(lr_scheduler, LRSchedulerWrapper):
                            lr_scheduler.step(**lr_scheduler_params)
                        else:
                            # For some pytorch lr_schedulers; this is not available for all of them.
                            lr_scheduler.step(this_epoch)
                    if utils.is_main_training(): self.reporter.update(snapshot)
                if utils.is_main_training(): self.save_model()
            if utils.is_main_training(): self.reporter.finish()
        except BaseException as e:
            if utils.use_ddp(): utils.cleanup_ddp()
            if not isinstance(e, KeyboardInterrupt):
                traceback.print_exc()
            sys.exit(1)
Example #11
    def wrapper(self,
                trn_loader,
                optimizer,
                init_lr=1e-6,
                final_lr=10.,
                num_iters=None,
                beta=0.98,
                split=[5, -10],
                log_dir=None,
                comment=None):
        if init_lr <= 0:
            raise ValueError(
                "Expected init_lr > 0, but got init_lr = {}.".format(init_lr))
        if final_lr <= init_lr:
            raise ValueError(
                "Expected final_lr > init_lr, but got final_lr {} <= init_lr {}."
                .format(final_lr, init_lr))
        if num_iters is not None and num_iters <= 1:
            raise ValueError(
                "Expected num_iters > 1, but got {}.".format(num_iters))
        if not isinstance(trn_loader, DataLoader):
            raise TypeError("Expected DataLoader, but got {}.".format(
                type(trn_loader).__name__))
        if not isinstance(optimizer, Optimizer):
            raise TypeError("Expected Optimizer, but got {}.".format(
                type(optimizer).__name__))

        # If num_iters is None, then just run one epoch.
        if num_iters is not None:
            epochs = (num_iters - 1) // len(trn_loader) + 1
        else:
            num_iters = len(trn_loader)
            epochs = 1

        logger.info(
            "Run lr finder from init_lr = {} to final_lr = {} with {} iters.".
            format(init_lr, final_lr, num_iters))

        # Init.
        mult = (final_lr / init_lr)**(1 / (num_iters - 1))

        num_batch = 0
        avg_values = 0.
        log_lrs = []

        if utils.is_main_training():
            reporter = LRFinderReporter(num_iters,
                                        log_dir=log_dir,
                                        comment=comment)

        # Start.
        lr = init_lr
        optimizer.param_groups[0]['lr'] = lr

        for this_epoch in range(epochs):
            for batch in trn_loader:
                num_batch += 1

                # function returns (keys, values), where values is a list of floats that is turned into a numpy vector below.
                keys, values = function(self, batch)

                values = np.array(values)

                if not utils.is_main_training():
                    continue

                # Compute the smoothed values. After this, avg_values becomes a numpy vector rather than the initial scalar 0.
                avg_values = beta * avg_values + (1 - beta) * values
                smoothed_values = avg_values / (1 - beta**num_batch)

                snapshot = {"lr": lr}
                for i in range(len(keys)):
                    snapshot[keys[i]] = smoothed_values[i]

                reporter.update(num_batch, snapshot)

                # # Stop if the main value is exploding.
                # if num_batch > 1 and smoothed_values[0] > 4 * best_value:
                #     reporter.finish()
                #     logger.info("Stop lr finder early by default rule.")
                #     return log_lrs[split[0]:split[1]], value_matrix.T[:,split[0]:split[1]]

                # Record the best main value. The main value at index 0 is usually the training loss.
                if num_batch == 1 or smoothed_values[0] < best_value:
                    best_value = smoothed_values[0]

                # Store the values.
                if num_batch == 1:
                    value_matrix = smoothed_values
                else:
                    value_matrix = np.vstack([value_matrix, smoothed_values])

                log_lrs.append(math.log10(lr))

                if num_batch >= num_iters:
                    reporter.finish()
                    return log_lrs[split[0]:split[1]], value_matrix.T[:, split[0]:split[1]]

                # Update the lr for the next step.
                lr *= mult
                optimizer.param_groups[0]['lr'] = lr

        if not utils.is_main_training():
            return None, None
        reporter.finish()
        return log_lrs[split[0]:split[1]], value_matrix.T[:, split[0]:split[1]]
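For reference, the sweep above is geometric (after i steps the learning rate is init_lr * mult**i) and the smoothing is an exponential moving average with bias correction, as in Adam. Here is a self-contained sketch of just this arithmetic; the toy loss values are illustrative.

import math

init_lr, final_lr, num_iters, beta = 1e-8, 10., 100, 0.98
mult = (final_lr / init_lr) ** (1 / (num_iters - 1))

lr, avg, log_lrs, smoothed = init_lr, 0., [], []
for step in range(1, num_iters + 1):
    loss = 1.0 / step                           # toy loss value, just to drive the smoothing
    avg = beta * avg + (1 - beta) * loss        # exponential moving average
    smoothed.append(avg / (1 - beta ** step))   # bias-corrected, as in the wrapper above
    log_lrs.append(math.log10(lr))
    lr *= mult

# The last lr actually used equals final_lr (up to float error).
assert abs(lr / mult - final_lr) < 1e-6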
Example #12
    def __init__(self,
                 trainset,
                 valid=None,
                 use_fast_loader=False,
                 max_prefetch=10,
                 batch_size=512,
                 shuffle=True,
                 num_workers=0,
                 pin_memory=False,
                 drop_last=True):

        num_samples = len(trainset)
        num_gpu = 1
        multi_gpu = False
        if utils.use_horovod():
            # Multi-GPU training.
            import horovod.torch as hvd
            # Partition dataset among workers using DistributedSampler
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                trainset,
                num_replicas=hvd.size(),
                rank=hvd.rank(),
                shuffle=shuffle)
            multi_gpu = True
            num_gpu = hvd.size()
        elif utils.use_ddp():
            # The num_replicas/world_size and rank will be set automatically with DDP.
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                trainset, shuffle=shuffle)
            multi_gpu = True
            num_gpu = dist.get_world_size()
        else:
            train_sampler = None

        if multi_gpu:
            # When a DistributedSampler is used, the shuffle option of the DataLoader must be set to False.
            shuffle = False
            if not utils.is_main_training():
                valid = None

        if use_fast_loader:
            self.train_loader = DataLoaderFast(max_prefetch,
                                               trainset,
                                               batch_size=batch_size,
                                               shuffle=shuffle,
                                               num_workers=num_workers,
                                               pin_memory=pin_memory,
                                               drop_last=drop_last,
                                               sampler=train_sampler)
        else:
            self.train_loader = DataLoader(trainset,
                                           batch_size=batch_size,
                                           shuffle=shuffle,
                                           num_workers=num_workers,
                                           pin_memory=pin_memory,
                                           drop_last=drop_last,
                                           sampler=train_sampler)

        self.num_batch_train = len(self.train_loader)

        if self.num_batch_train <= 0:
            raise ValueError(
                "Expected num_batch of trainset > 0. Here is your egs info: num_gpu={}, num_samples/gpu={}, "
                "batch-size={}, drop_last={}.\nNote: if batch-size > num_samples/gpu and drop_last is true, "
                "there will be 0 batches.".format(num_gpu, len(trainset) / num_gpu, batch_size, drop_last))

        if valid is not None:
            valid_batch_size = min(batch_size,
                                   len(valid))  # To save GPU memory

            if len(valid) <= 0:
                raise ValueError("Expected num_samples of valid > 0.")

            # Do not use DataLoaderFast for the valid set, for it keeps increasing memory when
            # compute_valid_accuracy is True. The real reason has not been found yet.
            self.valid_loader = DataLoader(valid,
                                           batch_size=valid_batch_size,
                                           shuffle=False,
                                           num_workers=num_workers,
                                           pin_memory=pin_memory,
                                           drop_last=False)

            self.num_batch_valid = len(self.valid_loader)
        else:
            self.valid_loader = None
            self.num_batch_valid = 0
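The num_batch_train check above guards against the drop_last pitfall described in its error message. Here is a small runnable demonstration with a toy dataset; the sizes are illustrative.

import torch
from torch.utils.data import DataLoader, TensorDataset

# 100 samples but a batch size of 512: with drop_last=True the single
# incomplete batch is dropped and the loader yields nothing.
trainset = TensorDataset(torch.randn(100, 8))
print(len(DataLoader(trainset, batch_size=512, drop_last=True)))   # 0 -> triggers the ValueError above
print(len(DataLoader(trainset, batch_size=512, drop_last=False)))  # 1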
Example #13
    def run(self):
        """Main function to start a training process.
        """
        try:
            self.init_training()

            if utils.is_main_training():
                self.reporter = Reporter(self)

            start_epoch = self.params["start_epoch"]
            epochs = self.params["epochs"]
            data = self.elements["data"]
            model = self.elements["model"]
            model_forward = self.elements[
                "model_forward"]  # See init_training.
            lr_scheduler = self.elements["lr_scheduler"]
            base_optimizer = self.elements["optimizer"]
            best_valid_acc = 0.0

            # For lookahead.
            if getattr(base_optimizer, "optimizer", None) is not None:
                base_optimizer = base_optimizer.optimizer
            last_lr = base_optimizer.state_dict()['param_groups'][0]['lr']

            if utils.is_main_training():
                logger.info("Training will run for {0} epochs.".format(epochs))

            for this_epoch in range(start_epoch, epochs):
                if isinstance(data.train_loader.sampler,
                              torch.utils.data.distributed.DistributedSampler):
                    data.train_loader.sampler.set_epoch(this_epoch)
                for this_iter, batch in enumerate(data.train_loader, 0):
                    self.training_point = (this_epoch, this_iter,
                                           data.num_batch_train
                                           )  # It is important for reporter.

                    if model.use_step:
                        model.step(*self.training_point)

                    loss, acc = self.train_one_batch(batch)

                    model.backward_step(*self.training_point)

                    # For multi-GPU training. Remember that it is not convenient to wrap lr_scheduler,
                    # because there are many strategies with different details. Here, only warmR, ReduceLROnPlateau
                    # and some simple schedulers whose only step() parameter is 'epoch' are supported.
                    lr_scheduler_params = {
                        "training_point": self.training_point
                    }

                    valid_computed = False
                    if lr_scheduler.name == "reduceP" and lr_scheduler.is_reduce_point(
                            self.training_point):
                        assert data.valid_loader is not None
                        valid_loss, valid_acc = self.compute_validation(
                            data.valid_loader)
                        lr_scheduler_params["valid_metric"] = (valid_loss,
                                                               valid_acc)
                        valid_computed = True

                    if utils.is_main_training():
                        if valid_computed or (data.valid_loader
                                              and self.reporter.is_report(
                                                  self.training_point)):
                            if not valid_computed:
                                valid_loss, valid_acc = self.compute_validation(
                                    data.valid_loader)
                                valid_computed = True

                            # real_snapshot keeps the raw numeric values for tensorboard,
                            # while the formatted strings below are for the console report.
                            real_snapshot = {
                                "train_loss": loss,
                                "valid_loss": valid_loss,
                                "train_acc": acc * 100,
                                "valid_acc": valid_acc * 100
                            }
                            snapshot = {
                                "train_loss": "{0:.6f}".format(loss),
                                "valid_loss": "{0:.6f}".format(valid_loss),
                                "train_acc": "{0:.2f}".format(acc * 100),
                                "valid_acc": "{0:.2f}".format(valid_acc * 100),
                                "real": real_snapshot
                            }
                            # For ReduceLROnPlateau.
                            lr_scheduler_params["valid_metric"] = (valid_loss,
                                                                   valid_acc)

                            if lr_scheduler.name == "warmR":
                                if this_epoch >= epochs - 1 and valid_acc >= best_valid_acc:
                                    best_valid_acc = valid_acc
                                    self.save_model(from_epoch=False)
                        else:
                            real_snapshot = {
                                "train_loss": loss,
                                "train_acc": acc * 100
                            }
                            snapshot = {
                                "train_loss": "{0:.6f}".format(loss),
                                "valid_loss": "",
                                "train_acc": "{0:.2f}".format(acc * 100),
                                "valid_acc": "",
                                "real": real_snapshot
                            }

                    if lr_scheduler is not None:
                        # It is not convenient to wrap lr_scheduler (work in progress).
                        if isinstance(lr_scheduler, LRSchedulerWrapper):
                            lr_scheduler.step(**lr_scheduler_params)
                            if utils.is_main_training():
                                current_lr = base_optimizer.state_dict()['param_groups'][0]['lr']
                                if lr_scheduler.name == "reduceP":
                                    if current_lr < last_lr:
                                        last_lr = current_lr
                                        self.save_model(from_epoch=False)
                                    elif current_lr <= lr_scheduler.min_lr and lr_scheduler.is_reduce_point(
                                            self.training_point):
                                        self.save_model(from_epoch=False)
                                elif lr_scheduler.name == "cyclic" and utils.is_main_training(
                                ):
                                    cyclic_size = lr_scheduler.lr_scheduler.total_size
                                    current_iter = self.training_point[
                                        0] * self.training_point[
                                            2] + self.training_point[1] + 1
                                    if current_iter % cyclic_size == 0 and current_iter != 1:
                                        self.save_model(from_epoch=False)
                        else:
                            # For some pytorch lr_schedulers; this is not available for all of them.
                            lr_scheduler.step(this_epoch)
                    if utils.is_main_training():
                        self.reporter.update(snapshot)
                if utils.is_main_training():
                    # current_lr may not be set above when lr_scheduler is not an LRSchedulerWrapper, so read it here.
                    current_lr = base_optimizer.state_dict()['param_groups'][0]['lr']
                    if epochs >= 20:
                        # For long runs, only save the epoch checkpoints of the last 10 epochs.
                        if this_epoch >= epochs - 10:
                            logger.info("Current learning rate: {}".format(current_lr))
                            self.save_model()
                    else:
                        logger.info("Current learning rate: {}".format(current_lr))
                        self.save_model()
            if utils.is_main_training():
                self.reporter.finish()
        except BaseException as e:
            if utils.use_ddp():
                utils.cleanup_ddp()
            if not isinstance(e, KeyboardInterrupt):
                traceback.print_exc()
            sys.exit(1)
Example #14
    def run(self):
        """Main function to start a training process.
        """
        try:
            self.init_training()

            if utils.is_main_training():
                self.reporter = Reporter(self)

            start_epoch = self.params["start_epoch"]
            epochs = self.params["epochs"]
            data = self.elements["data"]
            model = self.elements["model"]
            model_forward = self.elements[
                "model_forward"]  # See init_training.
            lr_scheduler = self.elements["lr_scheduler"]

            if utils.is_main_training():
                logger.info("Training will run for {0} epochs.".format(epochs))

            for this_epoch in range(start_epoch, epochs):
                for this_iter, batch in enumerate(data.train_loader, 0):
                    self.training_point = (this_epoch, this_iter,
                                           data.num_batch_train
                                           )  # It is important for reporter.

                    if model.use_step:
                        model.step(*self.training_point)

                    if lr_scheduler is not None:
                        # It is not convenient to wrap lr_scheduler (work in progress).
                        if isinstance(lr_scheduler, LRSchedulerWrapper):
                            lr_scheduler.step(self.training_point)
                        else:
                            # For some pytorch lr_schedulers; this is not available for all of them.
                            lr_scheduler.step(this_epoch)

                    loss, acc = self.train_one_batch(batch)

                    # For multi-GPU training.
                    if utils.is_main_training():
                        if data.valid_loader and self.reporter.is_report(
                                self.training_point):
                            valid_loss, valid_acc = self.compute_validation(
                                data.valid_loader)
                            snapshot = {
                                "train_loss": "{0:.6f}".format(loss),
                                "valid_loss": "{0:.6f}".format(valid_loss),
                                "train_acc": "{0:.2f}".format(acc * 100),
                                "valid_acc": "{0:.2f}".format(valid_acc * 100)
                            }
                        else:
                            snapshot = {
                                "train_loss": "{0:.6f}".format(loss),
                                "valid_loss": "",
                                "train_acc": "{0:.2f}".format(acc * 100),
                                "valid_acc": ""
                            }

                    if utils.is_main_training(): self.reporter.update(snapshot)
                if utils.is_main_training(): self.save_model()
            if utils.is_main_training(): self.reporter.finish()
        except BaseException as e:
            if utils.use_ddp(): utils.cleanup_ddp()
            if not isinstance(e, KeyboardInterrupt):
                traceback.print_exc()
            sys.exit(1)