Exemplo n.º 1
0
    def step(self, training_point=None, valid_metric=None):
        """Advance the wrapped lr_scheduler according to ``self.name``.

        Args:
            training_point: indexable ``(epoch, iter, num_batch)``; read by
                "warmR" and handed to ``self.is_reduce_point`` for "reduceP".
            valid_metric: indexable ``(valid_loss, valid_acc)``; only read for
                "reduceP" at a reduce point. With DDP only the main process
                reads it, the other processes receive the values by broadcast.

        Nothing happens for any other value of ``self.name``.
        """
        if self.name == "warmR":
            if self.lr_decay_step > 0 and training_point[1] % self.lr_decay_step == 0:
                # Step with a fractional epoch: epoch + iter / num_batch.
                self.lr_scheduler.step(training_point[0] + training_point[1] / training_point[2])
            elif self.lr_decay_step == 0:
                self.lr_scheduler.step(training_point[0])
        elif self.name == "1cycle":
            self.lr_scheduler.step()
        elif self.name == "reduceP":
            # Sample a point in which the metrics of valid are computed and adjust learning rate at this point.
            if self.is_reduce_point(training_point):
                # Do not support horovod now.
                if utils.use_ddp():
                    # Multi-gpu case.
                    # In this case, we do not compute valid set for all processes but just computing it in main process
                    # and broadcast the metrics to other processes.
                    if not self.init:
                        device = utils.get_device_from_optimizer(self.lr_scheduler.optimizer)
                        # Create a placeholder tensor to broadcast into with torch.distributed.broadcast.
                        # It lives in its own attribute: self.metric holds the metric *name* that is
                        # compared with "valid_loss" below and must not be overwritten by a tensor.
                        self._metric_buffer = torch.randn(2, device=device)
                        # New a group to broadcast the special metric tensor. It is important.
                        self.group = torch.distributed.new_group(ranks=list(range(torch.distributed.get_world_size())),
                                                                 backend="nccl")
                        self.init = True
                    if utils.is_main_training():
                        # Gather the new value of metric.
                        self._metric_buffer = torch.tensor([valid_metric[0], valid_metric[1]],
                                                           device=self._metric_buffer.device)
                    # Broadcast from rank 0 to every process of the group.
                    torch.distributed.broadcast(self._metric_buffer, 0, group=self.group)
                    metric = self._metric_buffer[0] if self.metric == "valid_loss" else self._metric_buffer[1]
                else:
                    # Single-GPU case.
                    metric = valid_metric[0] if self.metric == "valid_loss" else valid_metric[1]

                self.lr_scheduler.step(metric)
Exemplo n.º 2
0
    def run(self):
        """Main function to start a training process.

        Calls ``init_training`` and then loops over epochs and batches. For
        every batch it trains on the batch, computes validation metrics when
        the reporter or the lr_scheduler asks for them, steps the lr_scheduler
        and, on the main process, updates the reporter. On the main process a
        model is saved at the end of every epoch and, with "reduceP", whenever
        the learning rate of the first param group has dropped.

        Any exception (KeyboardInterrupt included) is caught here: DDP is
        cleaned up when in use, the traceback is printed unless it was a
        KeyboardInterrupt, and the process exits with status 1.
        """
        try:
            self.init_training()

            if utils.is_main_training():
                self.reporter = Reporter(self)

            start_epoch = self.params["start_epoch"]
            epochs = self.params["epochs"]
            data = self.elements["data"]
            model = self.elements["model"]
            # Not used below; the lookup is kept so a missing entry still fails here.
            model_forward = self.elements[
                "model_forward"]  # See init_training.
            lr_scheduler = self.elements["lr_scheduler"]
            base_optimizer = self.elements["optimizer"]

            # lr_scheduler may be None or a plain PyTorch scheduler (see the
            # stepping code below), so read the wrapper-only members defensively.
            scheduler_name = getattr(lr_scheduler, "name", None)
            is_reduce_point = getattr(lr_scheduler, "is_reduce_point", None)

            # For lookahead.
            if getattr(base_optimizer, "optimizer", None) is not None:
                base_optimizer = base_optimizer.optimizer
            last_lr = base_optimizer.state_dict()['param_groups'][0]['lr']

            if utils.is_main_training():
                logger.info("Training will run for {0} epochs.".format(epochs))

            for this_epoch in range(start_epoch, epochs):
                # Set random seed w.r.t epoch for distributed training.
                if isinstance(data.train_loader.sampler, torch.utils.data.distributed.DistributedSampler) and \
                    self.params["ddp_random_epoch"]:
                    data.train_loader.sampler.set_epoch(this_epoch)
                for this_iter, batch in enumerate(data.train_loader, 0):
                    self.training_point = (this_epoch, this_iter,
                                           data.num_batch_train
                                           )  # It is important for reporter.

                    if model.use_step:
                        model.step(*self.training_point)

                    loss, acc = self.train_one_batch(batch)

                    model.backward_step(*self.training_point)

                    # For multi-GPU training. Remember that it is not convenient to wrap lr_scheduler
                    # for there are many strategies with different details. Here, only warmR, ReduceLROnPlateau
                    # and some simple schedulers whose step() parameter is 'epoch' only are supported.
                    lr_scheduler_params = {
                        "training_point": self.training_point
                    }

                    is_main = utils.is_main_training()

                    if is_main or scheduler_name == "reduceP":
                        # self.reporter only exists on the main process, so the
                        # other processes must decide from the scheduler alone.
                        need_valid = data.valid_loader and (
                            (is_main and self.reporter.is_report(self.training_point))
                            or (is_reduce_point is not None
                                and is_reduce_point(self.training_point)))

                        if need_valid:
                            valid_loss, valid_acc = self.compute_validation(
                                data.valid_loader)
                            # real_snapshot is set for tensorboard to avoid workspace problem
                            real_snapshot = {
                                "train_loss": loss,
                                "valid_loss": valid_loss,
                                "train_acc": acc * 100,
                                "valid_acc": valid_acc * 100
                            }
                            snapshot = {
                                "train_loss": "{0:.6f}".format(loss),
                                "valid_loss": "{0:.6f}".format(valid_loss),
                                "train_acc": "{0:.2f}".format(acc * 100),
                                "valid_acc": "{0:.2f}".format(valid_acc * 100),
                                "real": real_snapshot
                            }
                            try:
                                # Mean pairwise dot product of the L2-normalized rows of the loss weight.
                                weight = model.loss.weight.squeeze(dim=2)
                                weight = F.normalize(weight, dim=1)
                                orth_snapshot = {"orth_snp": 0.}
                                for i in range(weight.shape[0]):
                                    for j in range(i + 1, weight.shape[0]):
                                        orth_snapshot["orth_snp"] += torch.dot(
                                            weight[i], weight[j]).item()
                                orth_snapshot["orth_snp"] /= weight.shape[
                                    0] * (weight.shape[0] - 1) / 2
                                # snapshot["real"] is real_snapshot itself, so it sees this update too.
                                real_snapshot.update(orth_snapshot)
                                snapshot.update(orth_snapshot)
                            except Exception:
                                # Deliberately best-effort: whenever the statistic cannot be
                                # computed (e.g. the loss has no such weight) it is simply omitted.
                                pass
                            # For ReduceLROnPlateau.
                            lr_scheduler_params["valid_metric"] = (valid_loss,
                                                                   valid_acc)
                        else:
                            real_snapshot = {
                                "train_loss": loss,
                                "train_acc": acc * 100
                            }
                            snapshot = {
                                "train_loss": "{0:.6f}".format(loss),
                                "valid_loss": "",
                                "train_acc": "{0:.2f}".format(acc * 100),
                                "valid_acc": "",
                                "real": real_snapshot
                            }

                    if lr_scheduler is not None:
                        # It is not convenient to wrap lr_scheduler (doing).
                        if isinstance(lr_scheduler, LRSchedulerWrapper):
                            lr_scheduler.step(**lr_scheduler_params)
                            if scheduler_name == "reduceP" and is_main:
                                current_lr = base_optimizer.state_dict(
                                )['param_groups'][0]['lr']
                                # Keep a checkpoint each time the learning rate is reduced.
                                if current_lr < last_lr:
                                    last_lr = current_lr
                                    self.save_model(from_epoch=False)
                        else:
                            # For some pytorch lr_schedulers, but it is not available for all.
                            lr_scheduler.step(this_epoch)
                    if is_main:
                        self.reporter.update(snapshot)
                if utils.is_main_training(): self.save_model()
            if utils.is_main_training(): self.reporter.finish()
        except BaseException as e:
            if utils.use_ddp(): utils.cleanup_ddp()
            if not isinstance(e, KeyboardInterrupt):
                traceback.print_exc()
            sys.exit(1)
Exemplo n.º 3
0
    def __init__(self,
                 trainset,
                 valid=None,
                 use_fast_loader=False,
                 max_prefetch=10,
                 batch_size=512,
                 valid_batch_size=512,
                 shuffle=True,
                 num_workers=0,
                 pin_memory=False,
                 drop_last=True):
        """Build the training loader and, optionally, the validation loader.

        Sets ``train_loader``, ``num_batch_train``, ``valid_loader`` and
        ``num_batch_valid`` on the instance.

        Args:
            trainset: training dataset; must support ``len()``.
            valid: validation dataset, or None for no validation loader.
            use_fast_loader: use ``DataLoaderFast`` (with ``max_prefetch``)
                instead of ``DataLoader`` for the training set.
            max_prefetch: first argument given to ``DataLoaderFast``.
            batch_size: batch size of the training loader.
            valid_batch_size: upper bound of the validation batch size; it is
                capped at ``len(valid)``.
            shuffle: shuffle the training data. With horovod or DDP it is
                passed to the DistributedSampler instead of the loader.
            num_workers, pin_memory: forwarded to both loaders.
            drop_last: forwarded to the training loader only.

        Raises:
            ValueError: if the training loader has no batch, or if ``valid``
                is given but empty.
        """
        num_samples = len(trainset)

        # Choose the sampler that matches the distributed backend (if any).
        train_sampler = None
        num_gpu = 1
        if utils.use_horovod():
            # Multi-GPU training.
            import horovod.torch as hvd
            # Partition dataset among workers using DistributedSampler
            num_gpu = hvd.size()
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                trainset,
                num_replicas=num_gpu,
                rank=hvd.rank(),
                shuffle=shuffle)
        elif utils.use_ddp():
            # The num_replicas/world_size and rank will be set automatically with DDP.
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                trainset, shuffle=shuffle)
            num_gpu = dist.get_world_size()

        if train_sampler is not None:
            # If use DistributedSampler, the shuffle of DataLoader should be set False.
            shuffle = False

        train_loader_kwargs = {
            "batch_size": batch_size,
            "shuffle": shuffle,
            "num_workers": num_workers,
            "pin_memory": pin_memory,
            "drop_last": drop_last,
            "sampler": train_sampler,
        }
        if use_fast_loader:
            self.train_loader = DataLoaderFast(max_prefetch, trainset,
                                               **train_loader_kwargs)
        else:
            self.train_loader = DataLoader(trainset, **train_loader_kwargs)

        self.num_batch_train = len(self.train_loader)

        if self.num_batch_train <= 0:
            message = (
                "Expected num_batch of trainset > 0. There are your egs info: num_gpu={}, num_samples/gpu={}, "
                "batch-size={}, drop_last={}.\nNote: If batch-size > num_samples/gpu and drop_last is true, then it "
                "will get 0 batch.")
            raise ValueError(
                message.format(num_gpu, num_samples / num_gpu, batch_size,
                               drop_last))

        if valid is None:
            self.valid_loader = None
            self.num_batch_valid = 0
            return

        num_valid_samples = len(valid)
        if num_valid_samples <= 0:
            raise ValueError("Expected num_samples of valid > 0.")

        # Do not use DataLoaderFast for valid for it increases the memory all the time when
        # compute_valid_accuracy is True. But the real reason has not been found.
        self.valid_loader = DataLoader(
            valid,
            batch_size=min(valid_batch_size, num_valid_samples),  # To save GPU memory
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=False)
        self.num_batch_valid = len(self.valid_loader)
Exemplo n.º 4
0
    def run(self):
        """Main function to start a training process.

        Calls ``init_training`` and then loops over epochs and batches. For
        every batch it trains on the batch, computes validation metrics when
        the reporter or the lr_scheduler asks for them, steps the lr_scheduler
        and, on the main process, updates the reporter. On the main process a
        model is saved at the end of every epoch.

        Any exception (KeyboardInterrupt included) is caught here: DDP is
        cleaned up when in use, the traceback is printed unless it was a
        KeyboardInterrupt, and the process exits with status 1.
        """
        try:
            self.init_training()

            if utils.is_main_training():
                self.reporter = Reporter(self)

            start_epoch = self.params["start_epoch"]
            epochs = self.params["epochs"]
            data = self.elements["data"]
            model = self.elements["model"]
            # Not used below; the lookup is kept so a missing entry still fails here.
            model_forward = self.elements[
                "model_forward"]  # See init_training.
            lr_scheduler = self.elements["lr_scheduler"]

            # lr_scheduler may be None or a plain PyTorch scheduler (see the
            # stepping code below), so read the wrapper-only members defensively.
            scheduler_name = getattr(lr_scheduler, "name", None)
            is_reduce_point = getattr(lr_scheduler, "is_reduce_point", None)

            if utils.is_main_training():
                logger.info("Training will run for {0} epochs.".format(epochs))

            for this_epoch in range(start_epoch, epochs):
                for this_iter, batch in enumerate(data.train_loader, 0):
                    self.training_point = (this_epoch, this_iter,
                                           data.num_batch_train
                                           )  # It is important for reporter.

                    if model.use_step:
                        model.step(*self.training_point)

                    loss, acc = self.train_one_batch(batch)

                    # For multi-GPU training. Remember that it is not convenient to wrap lr_scheduler
                    # for there are many strategies with different details. Here, only warmR, ReduceLROnPlateau
                    # and some simple schedulers whose step() parameter is 'epoch' only are supported.
                    lr_scheduler_params = {
                        "training_point": self.training_point
                    }

                    is_main = utils.is_main_training()

                    if is_main or scheduler_name == "reduceP":
                        # self.reporter only exists on the main process, so the
                        # other processes must decide from the scheduler alone.
                        need_valid = data.valid_loader and (
                            (is_main and self.reporter.is_report(self.training_point))
                            or (is_reduce_point is not None
                                and is_reduce_point(self.training_point)))

                        if need_valid:
                            valid_loss, valid_acc = self.compute_validation(
                                data.valid_loader)
                            snapshot = {
                                "train_loss": "{0:.6f}".format(loss),
                                "valid_loss": "{0:.6f}".format(valid_loss),
                                "train_acc": "{0:.2f}".format(acc * 100),
                                "valid_acc": "{0:.2f}".format(valid_acc * 100)
                            }
                            # For ReduceLROnPlateau.
                            lr_scheduler_params["valid_metric"] = (valid_loss,
                                                                   valid_acc)
                        else:
                            snapshot = {
                                "train_loss": "{0:.6f}".format(loss),
                                "valid_loss": "",
                                "train_acc": "{0:.2f}".format(acc * 100),
                                "valid_acc": ""
                            }

                    if lr_scheduler is not None:
                        # It is not convenient to wrap lr_scheduler (doing).
                        if isinstance(lr_scheduler, LRSchedulerWrapper):
                            lr_scheduler.step(**lr_scheduler_params)
                        else:
                            # For some pytorch lr_schedulers, but it is not available for all.
                            lr_scheduler.step(this_epoch)
                    if is_main:
                        self.reporter.update(snapshot)
                if utils.is_main_training(): self.save_model()
            if utils.is_main_training(): self.reporter.finish()
        except BaseException as e:
            if utils.use_ddp(): utils.cleanup_ddp()
            if not isinstance(e, KeyboardInterrupt):
                traceback.print_exc()
            sys.exit(1)
Exemplo n.º 5
0
    def run(self):
        """Main function to start a training process.

        Calls ``init_training`` and then loops over epochs and batches. For
        every batch it trains on the batch, computes validation metrics at
        "reduceP" reduce points (on every process) and at report points (on
        the main process), steps the lr_scheduler and, on the main process,
        updates the reporter.

        Models are saved on the main process only:
          * "warmR": during the last epoch, whenever validation accuracy is at
            least as good as the best seen so far;
          * "reduceP": whenever the learning rate of the first param group has
            dropped, and at reduce points once it is at or below
            ``lr_scheduler.min_lr``;
          * "cyclic": each time the global iteration count is a multiple of
            ``lr_scheduler.lr_scheduler.total_size``;
          * at the end of an epoch: every epoch when ``epochs < 20``, otherwise
            only for the last 10 epochs.

        Any exception (KeyboardInterrupt included) is caught here: DDP is
        cleaned up when in use, the traceback is printed unless it was a
        KeyboardInterrupt, and the process exits with status 1.
        """
        try:
            self.init_training()

            if utils.is_main_training():
                self.reporter = Reporter(self)

            start_epoch = self.params["start_epoch"]
            epochs = self.params["epochs"]
            data = self.elements["data"]
            model = self.elements["model"]
            # Not used below; the lookup is kept so a missing entry still fails here.
            model_forward = self.elements[
                "model_forward"]  # See init_training.
            lr_scheduler = self.elements["lr_scheduler"]
            base_optimizer = self.elements["optimizer"]
            best_valid_acc = 0.0

            # lr_scheduler may be None or a plain PyTorch scheduler (see the
            # stepping code below), so read the wrapper-only name defensively.
            scheduler_name = getattr(lr_scheduler, "name", None)

            # For lookahead.
            if getattr(base_optimizer, "optimizer", None) is not None:
                base_optimizer = base_optimizer.optimizer
            last_lr = base_optimizer.state_dict()['param_groups'][0]['lr']

            if utils.is_main_training():
                logger.info("Training will run for {0} epochs.".format(epochs))

            for this_epoch in range(start_epoch, epochs):
                if isinstance(data.train_loader.sampler,
                              torch.utils.data.distributed.DistributedSampler):
                    data.train_loader.sampler.set_epoch(this_epoch)
                for this_iter, batch in enumerate(data.train_loader, 0):
                    self.training_point = (this_epoch, this_iter,
                                           data.num_batch_train
                                           )  # It is important for reporter.

                    if model.use_step:
                        model.step(*self.training_point)

                    loss, acc = self.train_one_batch(batch)

                    model.backward_step(*self.training_point)

                    # For multi-GPU training. Remember that it is not convenient to wrap lr_scheduler
                    # for there are many strategies with different details. Here, only warmR, ReduceLROnPlateau
                    # and some simple schedulers whose step() parameter is 'epoch' only are supported.
                    lr_scheduler_params = {
                        "training_point": self.training_point
                    }

                    is_main = utils.is_main_training()

                    # ReduceLROnPlateau needs the validation metrics at its reduce
                    # points; this part runs on every process, not only the main one.
                    valid_computed = False
                    if scheduler_name == "reduceP" and lr_scheduler.is_reduce_point(
                            self.training_point):
                        if data.valid_loader is None:
                            # Explicit raise instead of assert, which is stripped under -O.
                            raise ValueError(
                                "The reduceP lr_scheduler needs a valid_loader "
                                "to compute its metric, but it is None.")
                        valid_loss, valid_acc = self.compute_validation(
                            data.valid_loader)
                        lr_scheduler_params["valid_metric"] = (valid_loss,
                                                               valid_acc)
                        valid_computed = True

                    if is_main:
                        if valid_computed or (data.valid_loader
                                              and self.reporter.is_report(
                                                  self.training_point)):
                            if not valid_computed:
                                # Report point that is not a reduce point.
                                valid_loss, valid_acc = self.compute_validation(
                                    data.valid_loader)

                            # real_snapshot is set for tensorboard to avoid workspace problem
                            real_snapshot = {
                                "train_loss": loss,
                                "valid_loss": valid_loss,
                                "train_acc": acc * 100,
                                "valid_acc": valid_acc * 100
                            }
                            snapshot = {
                                "train_loss": "{0:.6f}".format(loss),
                                "valid_loss": "{0:.6f}".format(valid_loss),
                                "train_acc": "{0:.2f}".format(acc * 100),
                                "valid_acc": "{0:.2f}".format(valid_acc * 100),
                                "real": real_snapshot
                            }
                            # For ReduceLROnPlateau.
                            lr_scheduler_params["valid_metric"] = (valid_loss,
                                                                   valid_acc)

                            if scheduler_name == "warmR":
                                # During the last epoch keep the model with the best validation accuracy.
                                if this_epoch >= epochs - 1 and valid_acc >= best_valid_acc:
                                    best_valid_acc = valid_acc
                                    self.save_model(from_epoch=False)
                        else:
                            real_snapshot = {
                                "train_loss": loss,
                                "train_acc": acc * 100
                            }
                            snapshot = {
                                "train_loss": "{0:.6f}".format(loss),
                                "valid_loss": "",
                                "train_acc": "{0:.2f}".format(acc * 100),
                                "valid_acc": "",
                                "real": real_snapshot
                            }

                    if lr_scheduler is not None:
                        # It is not convenient to wrap lr_scheduler (doing).
                        if isinstance(lr_scheduler, LRSchedulerWrapper):
                            lr_scheduler.step(**lr_scheduler_params)
                            if is_main:
                                current_lr = base_optimizer.state_dict(
                                )['param_groups'][0]['lr']
                                if scheduler_name == "reduceP":
                                    if current_lr < last_lr:
                                        # The learning rate has just been reduced.
                                        last_lr = current_lr
                                        self.save_model(from_epoch=False)
                                    elif current_lr <= lr_scheduler.min_lr and lr_scheduler.is_reduce_point(
                                            self.training_point):
                                        self.save_model(from_epoch=False)
                                elif scheduler_name == "cyclic":
                                    cyclic_size = lr_scheduler.lr_scheduler.total_size
                                    # 1-based count of iterations done since epoch 0.
                                    current_iter = self.training_point[
                                        0] * self.training_point[
                                            2] + self.training_point[1] + 1
                                    if current_iter % cyclic_size == 0 and current_iter != 1:
                                        self.save_model(from_epoch=False)
                        else:
                            # For some pytorch lr_schedulers, but it is not available for all.
                            lr_scheduler.step(this_epoch)
                    if is_main:
                        self.reporter.update(snapshot)
                if utils.is_main_training():
                    # Long runs (>= 20 epochs) only keep the checkpoints of the last 10 epochs.
                    if epochs < 20 or this_epoch >= epochs - 10:
                        # Read the learning rate from the optimizer itself: a local set inside
                        # the batch loop is unbound when the scheduler is not a wrapper or the
                        # epoch had no batch.
                        print(base_optimizer.state_dict()['param_groups'][0]['lr'])
                        self.save_model()
            if utils.is_main_training():
                self.reporter.finish()
        except BaseException as e:
            if utils.use_ddp():
                utils.cleanup_ddp()
            if not isinstance(e, KeyboardInterrupt):
                traceback.print_exc()
            sys.exit(1)
Exemplo n.º 6
0
    def run(self):
        """Main function to start a training process.

        Calls ``init_training`` and then loops over epochs and batches. For
        every batch the lr_scheduler (if any) is stepped first, then the batch
        is trained. On the main process a snapshot of the train metrics (plus
        validation metrics at report points) is sent to the reporter, and a
        model is saved at the end of every epoch.

        Any exception (KeyboardInterrupt included) is caught here: DDP is
        cleaned up when in use, the traceback is printed unless it was a
        KeyboardInterrupt, and the process exits with status 1.
        """
        try:
            self.init_training()

            if utils.is_main_training():
                self.reporter = Reporter(self)

            first_epoch = self.params["start_epoch"]
            total_epochs = self.params["epochs"]
            data = self.elements["data"]
            model = self.elements["model"]
            # Looked up but not used below. See init_training.
            model_forward = self.elements["model_forward"]
            lr_scheduler = self.elements["lr_scheduler"]

            if utils.is_main_training():
                logger.info(
                    "Training will run for {0} epochs.".format(total_epochs))

            for epoch_idx in range(first_epoch, total_epochs):
                for batch_idx, batch in enumerate(data.train_loader, 0):
                    # It is important for reporter.
                    self.training_point = (epoch_idx, batch_idx,
                                           data.num_batch_train)

                    if model.use_step:
                        model.step(*self.training_point)

                    if lr_scheduler is not None:
                        # It is not convenient to wrap lr_scheduler (doing).
                        # The wrapper takes the whole training point; for some pytorch
                        # lr_schedulers the epoch is given, but it is not available for all.
                        if isinstance(lr_scheduler, LRSchedulerWrapper):
                            step_arg = self.training_point
                        else:
                            step_arg = epoch_idx
                        lr_scheduler.step(step_arg)

                    loss, acc = self.train_one_batch(batch)

                    # For multi-GPU training: only the main process reports.
                    if utils.is_main_training():
                        # The validation fields stay empty unless this is a report point.
                        valid_loss_text = ""
                        valid_acc_text = ""
                        if data.valid_loader and self.reporter.is_report(
                                self.training_point):
                            valid_loss, valid_acc = self.compute_validation(
                                data.valid_loader)
                            valid_loss_text = "{0:.6f}".format(valid_loss)
                            valid_acc_text = "{0:.2f}".format(valid_acc * 100)

                        self.reporter.update({
                            "train_loss": "{0:.6f}".format(loss),
                            "valid_loss": valid_loss_text,
                            "train_acc": "{0:.2f}".format(acc * 100),
                            "valid_acc": valid_acc_text
                        })

                if utils.is_main_training():
                    self.save_model()

            if utils.is_main_training():
                self.reporter.finish()
        except BaseException as e:
            if utils.use_ddp():
                utils.cleanup_ddp()
            if not isinstance(e, KeyboardInterrupt):
                traceback.print_exc()
            sys.exit(1)