Example #1
    def run(self):
        """Main function to start a training process.
        """
        try:
            self.init_training()

            if utils.is_main_training():
                self.reporter = Reporter(self)

            start_epoch = self.params["start_epoch"]
            epochs = self.params["epochs"]
            data = self.elements["data"]
            model = self.elements["model"]
            model_forward = self.elements["model_forward"]  # See init_training.
            lr_scheduler = self.elements["lr_scheduler"]
            base_optimizer = self.elements["optimizer"]

            # For lookahead.
            if getattr(base_optimizer, "optimizer", None) is not None:
                base_optimizer = base_optimizer.optimizer
            last_lr = base_optimizer.state_dict()['param_groups'][0]['lr']

            if utils.is_main_training():
                logger.info("Training will run for {0} epochs.".format(epochs))

            for this_epoch in range(start_epoch, epochs):
                # Re-seed the DistributedSampler each epoch so shuffling differs across epochs.
                if isinstance(data.train_loader.sampler, torch.utils.data.distributed.DistributedSampler) \
                        and self.params["ddp_random_epoch"]:
                    data.train_loader.sampler.set_epoch(this_epoch)
                for this_iter, batch in enumerate(data.train_loader, 0):
                    # The reporter and the lr scheduler both key off this (epoch, iter, num_batches) triple.
                    self.training_point = (this_epoch, this_iter, data.num_batch_train)

                    if model.use_step:
                        model.step(*self.training_point)

                    loss, acc = self.train_one_batch(batch)

                    model.backward_step(*self.training_point)

                    # For multi-GPU training. Wrapping lr_scheduler is awkward because the strategies
                    # differ in their details; only warmR, ReduceLROnPlateau and simple schedulers
                    # whose step() takes just an 'epoch' argument are supported here.
                    lr_scheduler_params = {
                        "training_point": self.training_point
                    }

                    if utils.is_main_training() or lr_scheduler.name == "reduceP":
                        # Only the main process owns self.reporter, so guard that access on other ranks.
                        if data.valid_loader and (
                                (utils.is_main_training() and self.reporter.is_report(self.training_point))
                                or lr_scheduler.is_reduce_point(self.training_point)):

                            valid_loss, valid_acc = self.compute_validation(
                                data.valid_loader)
                            # real_snapshot keeps the raw (unformatted) values; it is passed to TensorBoard
                            # to avoid the workspace problem.
                            real_snapshot = {
                                "train_loss": loss,
                                "valid_loss": valid_loss,
                                "train_acc": acc * 100,
                                "valid_acc": valid_acc * 100
                            }
                            snapshot = {
                                "train_loss": "{0:.6f}".format(loss),
                                "valid_loss": "{0:.6f}".format(valid_loss),
                                "train_acc": "{0:.2f}".format(acc * 100),
                                "valid_acc": "{0:.2f}".format(valid_acc * 100),
                                "real": real_snapshot
                            }
                            # Optionally track how far the normalized loss weights are from mutual orthogonality.
                            try:
                                weight = F.normalize(model.loss.weight.squeeze(dim=2), dim=1)
                                num_classes = weight.shape[0]
                                orth_snapshot = {"orth_snp": 0.}
                                for i in range(num_classes):
                                    for j in range(i + 1, num_classes):
                                        orth_snapshot["orth_snp"] += torch.dot(weight[i], weight[j]).item()
                                # Average over the number of (i, j) pairs.
                                orth_snapshot["orth_snp"] /= num_classes * (num_classes - 1) / 2
                                real_snapshot.update(orth_snapshot)
                                snapshot.update(orth_snapshot)
                                snapshot["real"] = real_snapshot
                            except Exception:
                                # The loss has no compatible weight matrix; skip the orthogonality metric.
                                pass
                            # For ReduceLROnPlateau.
                            lr_scheduler_params["valid_metric"] = (valid_loss,
                                                                   valid_acc)
                        else:
                            real_snapshot = {
                                "train_loss": loss,
                                "train_acc": acc * 100
                            }
                            snapshot = {
                                "train_loss": "{0:.6f}".format(loss),
                                "valid_loss": "",
                                "train_acc": "{0:.2f}".format(acc * 100),
                                "valid_acc": "",
                                "real": real_snapshot
                            }

                    if lr_scheduler is not None:
                        # Wrapping every lr_scheduler is not convenient (work in progress).
                        if isinstance(lr_scheduler, LRSchedulerWrapper):
                            lr_scheduler.step(**lr_scheduler_params)
                            if lr_scheduler.name == "reduceP" and utils.is_main_training():
                                current_lr = base_optimizer.state_dict()['param_groups'][0]['lr']
                                # Save a checkpoint every time ReduceLROnPlateau lowers the learning rate.
                                if current_lr < last_lr:
                                    last_lr = current_lr
                                    self.save_model(from_epoch=False)
                        else:
                            # For plain PyTorch lr_schedulers; this epoch-based call does not fit all of them.
                            lr_scheduler.step(this_epoch)
                    if utils.is_main_training(): self.reporter.update(snapshot)
                if utils.is_main_training(): self.save_model()
            if utils.is_main_training(): self.reporter.finish()
        except BaseException as e:
            if utils.use_ddp(): utils.cleanup_ddp()
            if not isinstance(e, KeyboardInterrupt):
                traceback.print_exc()
            sys.exit(1)
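
Note: the sampler.set_epoch(this_epoch) call near the top of the epoch loop is what makes DDP shuffling differ from epoch to epoch, since DistributedSampler derives its shuffle from seed + epoch. A minimal standalone sketch of that pattern with plain torch.utils.data objects; the dataset and loop below are illustrative, not taken from the example.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy dataset standing in for data.train_loader's dataset in the example above.
dataset = TensorDataset(torch.arange(16).float().unsqueeze(1))

# DistributedSampler normally reads rank/world_size from the initialized process group;
# they are passed explicitly here only so the sketch runs without init_process_group().
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for epoch in range(2):
    # Without set_epoch, every epoch would replay the same shuffled order.
    sampler.set_epoch(epoch)
    for (batch,) in loader:
        pass  # train_one_batch(batch) would run here.
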
Example #2
    def run(self):
        """Main function to start a training process.
        """
        try:
            self.init_training()

            if utils.is_main_training():
                self.reporter = Reporter(self)

            start_epoch = self.params["start_epoch"]
            epochs = self.params["epochs"]
            data = self.elements["data"]
            model = self.elements["model"]
            model_forward = self.elements["model_forward"]  # See init_training.
            lr_scheduler = self.elements["lr_scheduler"]
            base_optimizer = self.elements["optimizer"]
            best_valid_acc = 0.0

            # For lookahead.
            if getattr(base_optimizer, "optimizer", None) is not None:
                base_optimizer = base_optimizer.optimizer
            last_lr = base_optimizer.state_dict()['param_groups'][0]['lr']

            if utils.is_main_training():
                logger.info("Training will run for {0} epochs.".format(epochs))

            for this_epoch in range(start_epoch, epochs):
                if isinstance(data.train_loader.sampler,
                              torch.utils.data.distributed.DistributedSampler):
                    data.train_loader.sampler.set_epoch(this_epoch)
                for this_iter, batch in enumerate(data.train_loader, 0):
                    # The reporter and the lr scheduler both key off this (epoch, iter, num_batches) triple.
                    self.training_point = (this_epoch, this_iter, data.num_batch_train)

                    if model.use_step:
                        model.step(*self.training_point)

                    loss, acc = self.train_one_batch(batch)

                    model.backward_step(*self.training_point)

                    # For multi-GPU training. Wrapping lr_scheduler is awkward because the strategies
                    # differ in their details; only warmR, ReduceLROnPlateau and simple schedulers
                    # whose step() takes just an 'epoch' argument are supported here.
                    lr_scheduler_params = {
                        "training_point": self.training_point
                    }

                    valid_computed = False
                    if lr_scheduler.name == "reduceP" \
                            and lr_scheduler.is_reduce_point(self.training_point):
                        # ReduceLROnPlateau needs the validation metric on every rank at reduce points.
                        assert data.valid_loader is not None
                        valid_loss, valid_acc = self.compute_validation(data.valid_loader)
                        lr_scheduler_params["valid_metric"] = (valid_loss, valid_acc)
                        valid_computed = True

                    if utils.is_main_training():
                        if valid_computed or (data.valid_loader and
                                              self.reporter.is_report(self.training_point)):
                            if not valid_computed:
                                valid_loss, valid_acc = self.compute_validation(data.valid_loader)
                                valid_computed = True

                            # real_snapshot keeps the raw (unformatted) values; it is passed to TensorBoard
                            # to avoid the workspace problem.
                            real_snapshot = {
                                "train_loss": loss,
                                "valid_loss": valid_loss,
                                "train_acc": acc * 100,
                                "valid_acc": valid_acc * 100
                            }
                            snapshot = {
                                "train_loss": "{0:.6f}".format(loss),
                                "valid_loss": "{0:.6f}".format(valid_loss),
                                "train_acc": "{0:.2f}".format(acc * 100),
                                "valid_acc": "{0:.2f}".format(valid_acc * 100),
                                "real": real_snapshot
                            }
                            # For ReduceLROnPlateau.
                            lr_scheduler_params["valid_metric"] = (valid_loss,
                                                                   valid_acc)

                            if lr_scheduler.name == "warmR":
                                if this_epoch >= epochs - 1 and valid_acc >= best_valid_acc:
                                    best_valid_acc = valid_acc
                                    self.save_model(from_epoch=False)
                        else:
                            real_snapshot = {
                                "train_loss": loss,
                                "train_acc": acc * 100
                            }
                            snapshot = {
                                "train_loss": "{0:.6f}".format(loss),
                                "valid_loss": "",
                                "train_acc": "{0:.2f}".format(acc * 100),
                                "valid_acc": "",
                                "real": real_snapshot
                            }

                    if lr_scheduler is not None:
                        # Wrapping every lr_scheduler is not convenient (work in progress).
                        if isinstance(lr_scheduler, LRSchedulerWrapper):
                            lr_scheduler.step(**lr_scheduler_params)
                            if utils.is_main_training():
                                current_lr = base_optimizer.state_dict()['param_groups'][0]['lr']
                                if lr_scheduler.name == "reduceP":
                                    # Save whenever the lr is reduced, and also at reduce points
                                    # once the minimum lr has been reached.
                                    if current_lr < last_lr:
                                        last_lr = current_lr
                                        self.save_model(from_epoch=False)
                                    elif current_lr <= lr_scheduler.min_lr and \
                                            lr_scheduler.is_reduce_point(self.training_point):
                                        self.save_model(from_epoch=False)
                                elif lr_scheduler.name == "cyclic":
                                    # Save a checkpoint at the end of every full cycle.
                                    cyclic_size = lr_scheduler.lr_scheduler.total_size
                                    current_iter = (self.training_point[0] * self.training_point[2]
                                                    + self.training_point[1] + 1)
                                    if current_iter % cyclic_size == 0 and current_iter != 1:
                                        self.save_model(from_epoch=False)
                        else:
                            # For plain PyTorch lr_schedulers; this epoch-based call does not fit all of them.
                            lr_scheduler.step(this_epoch)
                    if utils.is_main_training():
                        self.reporter.update(snapshot)
                if utils.is_main_training():
                    # For long runs, only keep per-epoch checkpoints during the last 10 epochs.
                    if epochs < 20 or this_epoch >= epochs - 10:
                        print(current_lr)
                        self.save_model()
            if utils.is_main_training():
                self.reporter.finish()
        except BaseException as e:
            if utils.use_ddp():
                utils.cleanup_ddp()
            if not isinstance(e, KeyboardInterrupt):
                traceback.print_exc()
            sys.exit(1)
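
Note: both variants above unwrap a lookahead-style optimizer (which keeps the real optimizer in an .optimizer attribute) before reading the current learning rate from its first param group. A minimal sketch of that lookup with a plain torch.optim optimizer; the Lookahead class here is a stand-in for illustration, not the project's implementation.

import torch


class Lookahead:
    # Stand-in wrapper: only mimics the .optimizer attribute the examples rely on.
    def __init__(self, optimizer):
        self.optimizer = optimizer


params = [torch.nn.Parameter(torch.zeros(3))]
wrapped = Lookahead(torch.optim.SGD(params, lr=0.01))

# Same unwrapping logic as in the examples: fall back to the inner optimizer if present.
base_optimizer = wrapped
if getattr(base_optimizer, "optimizer", None) is not None:
    base_optimizer = base_optimizer.optimizer

# Either form reads the current lr; param_groups avoids building the whole state dict.
lr_from_state = base_optimizer.state_dict()['param_groups'][0]['lr']
lr_direct = base_optimizer.param_groups[0]['lr']
assert lr_from_state == lr_direct == 0.01
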
Example #3
    def run(self):
        """Main function to start a training process.
        """
        try:
            self.init_training()

            if utils.is_main_training():
                self.reporter = Reporter(self)

            start_epoch = self.params["start_epoch"]
            epochs = self.params["epochs"]
            data = self.elements["data"]
            model = self.elements["model"]
            model_forward = self.elements["model_forward"]  # See init_training.
            lr_scheduler = self.elements["lr_scheduler"]

            if utils.is_main_training():
                logger.info("Training will run for {0} epochs.".format(epochs))

            for this_epoch in range(start_epoch, epochs):
                for this_iter, batch in enumerate(data.train_loader, 0):
                    # The reporter and the lr scheduler both key off this (epoch, iter, num_batches) triple.
                    self.training_point = (this_epoch, this_iter, data.num_batch_train)

                    if model.use_step:
                        model.step(*self.training_point)

                    loss, acc = self.train_one_batch(batch)

                    # For multi-GPU training. Wrapping lr_scheduler is awkward because the strategies
                    # differ in their details; only warmR, ReduceLROnPlateau and simple schedulers
                    # whose step() takes just an 'epoch' argument are supported here.
                    lr_scheduler_params = {
                        "training_point": self.training_point
                    }

                    if utils.is_main_training() or lr_scheduler.name == "reduceP":
                        # Only the main process owns self.reporter, so guard that access on other ranks.
                        if data.valid_loader and (
                                (utils.is_main_training() and self.reporter.is_report(self.training_point))
                                or lr_scheduler.is_reduce_point(self.training_point)):

                            valid_loss, valid_acc = self.compute_validation(
                                data.valid_loader)
                            snapshot = {
                                "train_loss": "{0:.6f}".format(loss),
                                "valid_loss": "{0:.6f}".format(valid_loss),
                                "train_acc": "{0:.2f}".format(acc * 100),
                                "valid_acc": "{0:.2f}".format(valid_acc * 100)
                            }
                            # For ReduceLROnPlateau.
                            lr_scheduler_params["valid_metric"] = (valid_loss,
                                                                   valid_acc)
                        else:
                            snapshot = {
                                "train_loss": "{0:.6f}".format(loss),
                                "valid_loss": "",
                                "train_acc": "{0:.2f}".format(acc * 100),
                                "valid_acc": ""
                            }

                    if lr_scheduler is not None:
                        # Wrapping every lr_scheduler is not convenient (work in progress).
                        if isinstance(lr_scheduler, LRSchedulerWrapper):
                            lr_scheduler.step(**lr_scheduler_params)
                        else:
                            # For plain PyTorch lr_schedulers; this epoch-based call does not fit all of them.
                            lr_scheduler.step(this_epoch)
                    if utils.is_main_training(): self.reporter.update(snapshot)
                if utils.is_main_training(): self.save_model()
            if utils.is_main_training(): self.reporter.finish()
        except BaseException as e:
            if utils.use_ddp(): utils.cleanup_ddp()
            if not isinstance(e, KeyboardInterrupt):
                traceback.print_exc()
            sys.exit(1)
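
Note: the reduceP branch feeds (valid_loss, valid_acc) to the scheduler as valid_metric. If that mode maps onto torch.optim.lr_scheduler.ReduceLROnPlateau, the underlying call boils down to step(metric). A minimal sketch of the native scheduler on its own; the optimizer and loss values are illustrative.

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

params = [torch.nn.Parameter(torch.zeros(3))]
optimizer = torch.optim.SGD(params, lr=0.1)

# Halve the lr once the monitored loss has failed to improve for more than patience checks.
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2)

for valid_loss in [1.0, 1.0, 1.0, 1.0, 1.0]:
    # Unlike epoch-based schedulers, step() takes the monitored validation metric.
    scheduler.step(valid_loss)
    print(optimizer.param_groups[0]["lr"])  # Drops to 0.05 on the fourth check.
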
Example #4
    def run(self):
        """Main function to start a training process.
        """
        try:
            self.init_training()

            if utils.is_main_training():
                self.reporter = Reporter(self)

            start_epoch = self.params["start_epoch"]
            epochs = self.params["epochs"]
            data = self.elements["data"]
            model = self.elements["model"]
            model_forward = self.elements["model_forward"]  # See init_training.
            lr_scheduler = self.elements["lr_scheduler"]

            if utils.is_main_training():
                logger.info("Training will run for {0} epochs.".format(epochs))

            for this_epoch in range(start_epoch, epochs):
                for this_iter, batch in enumerate(data.train_loader, 0):
                    # The reporter and the lr scheduler both key off this (epoch, iter, num_batches) triple.
                    self.training_point = (this_epoch, this_iter, data.num_batch_train)

                    if model.use_step:
                        model.step(*self.training_point)

                    if lr_scheduler is not None:
                        # Wrapping every lr_scheduler is not convenient (work in progress).
                        if isinstance(lr_scheduler, LRSchedulerWrapper):
                            lr_scheduler.step(self.training_point)
                        else:
                            # For plain PyTorch lr_schedulers; this epoch-based call does not fit all of them.
                            lr_scheduler.step(this_epoch)

                    loss, acc = self.train_one_batch(batch)

                    # For multi-GPU training.
                    if utils.is_main_training():
                        if data.valid_loader and self.reporter.is_report(
                                self.training_point):
                            valid_loss, valid_acc = self.compute_validation(
                                data.valid_loader)
                            snapshot = {
                                "train_loss": "{0:.6f}".format(loss),
                                "valid_loss": "{0:.6f}".format(valid_loss),
                                "train_acc": "{0:.2f}".format(acc * 100),
                                "valid_acc": "{0:.2f}".format(valid_acc * 100)
                            }
                        else:
                            snapshot = {
                                "train_loss": "{0:.6f}".format(loss),
                                "valid_loss": "",
                                "train_acc": "{0:.2f}".format(acc * 100),
                                "valid_acc": ""
                            }

                    if utils.is_main_training(): self.reporter.update(snapshot)
                if utils.is_main_training(): self.save_model()
            if utils.is_main_training(): self.reporter.finish()
        except BaseException as e:
            if utils.use_ddp(): utils.cleanup_ddp()
            if not isinstance(e, KeyboardInterrupt):
                traceback.print_exc()
            sys.exit(1)
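
Note: every variant branches on isinstance(lr_scheduler, LRSchedulerWrapper) because, as the comments say, the strategies disagree about what step() needs: a training point (and possibly a validation metric) for warmR/reduceP-style schedules, a bare epoch for simple PyTorch ones. The sketch below is a simplified, hypothetical wrapper illustrating that dispatch idea; it is not the project's LRSchedulerWrapper.

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau


class SimpleSchedulerWrapper:
    # Hypothetical wrapper: routes step() to whatever the wrapped scheduler expects.
    def __init__(self, lr_scheduler, name):
        self.lr_scheduler = lr_scheduler
        self.name = name

    def step(self, training_point=None, valid_metric=None):
        if self.name == "reduceP":
            # ReduceLROnPlateau wants the monitored validation value, not a step index.
            valid_loss, _valid_acc = valid_metric
            self.lr_scheduler.step(valid_loss)
        else:
            # Iteration-driven schedules could instead derive a step count from
            # training_point: epoch * num_batches + iteration.
            self.lr_scheduler.step()


params = [torch.nn.Parameter(torch.zeros(3))]
optimizer = torch.optim.SGD(params, lr=0.1)
wrapper = SimpleSchedulerWrapper(ReduceLROnPlateau(optimizer, mode="min"), name="reduceP")
wrapper.step(training_point=(0, 10, 100), valid_metric=(0.7, 0.8))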