Example #1
    def train_one_batch(self, batch):
        """A normal training core without fetching data from iterator.
        """
        model = self.elements["model"]
        model_forward = self.elements["model_forward"]
        optimizer = self.elements["optimizer"]

        if not model.training:
            model.train()

        if self.params["nan_debug"]:
            device = utils.get_device(self.elements["model"])
            inputs = torch.load("{0}/nan.batch".format(self.params["model_dir"])).to(device)
            targets = torch.load("{0}/nan.targets".format(self.params["model_dir"])).to(device)
            self.elements["model"].load_state_dict(torch.load("{0}/nan.params".format(self.params["model_dir"]), 
                                             map_location="cpu"))
            self.elements["model"].to(device)
        else:
            inputs, targets = batch
        optimizer.zero_grad()

        loss = model.get_loss(model_forward(inputs), targets)
        loss.backward()
        loss = loss.detach()  # Detach from the graph; a bare loss.detach() without assignment is a no-op.

        if self.params["max_change"] > 0:
            # Reference: https://github.com/horovod/horovod/blob/master/horovod/torch/__init__.py:420~423.
            # Synchronize gradients across workers before computing grad_norm when using horovod.
            if utils.use_horovod():
                optimizer.synchronize()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.params["max_change"])

            if math.isnan(grad_norm):
                if self.params["nan_debug"]:
                    raise RuntimeError("[NOT OK] Nan is still found in this debug.")
                torch.save(inputs.cpu(), "{0}/nan.batch".format(self.params["model_dir"]))
                torch.save(targets.cpu(), "{0}/nan.targets".format(self.params["model_dir"]))
                torch.save(self.elements["model"].state_dict(), "{0}/nan.params".format(self.params["model_dir"]))
                raise RuntimeError("NaN detected in iter/epoch: {0}/{1} (the NaN batch and params "
                                   "are saved in {2}).".format(self.training_point[1] + 1,
                                                               self.training_point[0] + 1,
                                                               "{0}/nan.*".format(self.params["model_dir"])))
            else:
                if self.params["nan_debug"]:
                    raise RuntimeError("[OK] There is no nan found for this debug.")
                if utils.use_horovod():
                    with optimizer.skip_synchronize():
                        optimizer.step()
                else:
                    optimizer.step()
        else:
            optimizer.step()

        accuracy = model.get_accuracy(targets) if self.params["compute_accuracy"] else None

        return loss.item(), accuracy
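
The NaN-debug flow above is a save-on-failure / replay-on-restart loop: when the clipped gradient norm comes out as NaN, the offending batch, targets, and model weights are dumped to disk, and a later run with nan_debug enabled reloads exactly that state to reproduce the failure. A minimal self-contained sketch of the same pattern (step_with_nan_guard, save_dir, and max_norm are hypothetical names for illustration, not part of the original trainer):

import math
import os

import torch

def step_with_nan_guard(model, optimizer, criterion, inputs, targets,
                        save_dir, max_norm=5.0):
    """One optimizer step that dumps reproduction state if the grad norm is NaN."""
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()

    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    if math.isnan(float(grad_norm)):
        # Save everything needed to replay this exact failing step later.
        os.makedirs(save_dir, exist_ok=True)
        torch.save(inputs.cpu(), os.path.join(save_dir, "nan.batch"))
        torch.save(targets.cpu(), os.path.join(save_dir, "nan.targets"))
        torch.save(model.state_dict(), os.path.join(save_dir, "nan.params"))
        raise RuntimeError("NaN grad norm; state saved to {0}/nan.*".format(save_dir))

    optimizer.step()
    return loss.item()

Replaying the failure then amounts to loading nan.batch, nan.targets, and nan.params and running the guarded step again, which mirrors the nan_debug branch of train_one_batch.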
Example #2
    def __init__(self, trainer):
        default_params = {
            "report_times_every_epoch": None,
            "report_interval_iters": 100,
            "record_file": "train.csv",
            "use_tensorboard": False
        }
        self.trainer = trainer
        default_params = utils.assign_params_dict(default_params,
                                                  self.trainer.params)

        if default_params["report_times_every_epoch"] is not None:
            self.report_interval_iters = max(
                1, self.trainer.training_point[2] //
                default_params["report_times_every_epoch"])
        else:
            self.report_interval_iters = default_params[
                "report_interval_iters"]

        if not self.trainer.params["debug"] and default_params[
                "use_tensorboard"]:
            # from tensorboardX import SummaryWriter
            from torch.utils.tensorboard import SummaryWriter
            self.board_writer = SummaryWriter("{}/log/tensorboard".format(
                self.trainer.params["model_dir"]))
        else:
            self.board_writer = None

        self.epochs = self.trainer.params["epochs"]

        self.optimizer = self.trainer.elements["optimizer"]

        # Unwrap optimizer wrappers such as Lookahead, which expose the inner
        # optimizer via an "optimizer" attribute (getattr returns None if there
        # is no wrapper).
        if getattr(self.optimizer, "optimizer", None) is not None:
            self.optimizer = self.optimizer.optimizer

        self.device = "[{0}]".format(
            utils.get_device(self.trainer.elements["model"]))

        self.record_value = []

        self.start_write_log = False
        if not self.trainer.params["debug"] and default_params[
                "record_file"] != "" and default_params[
                    "record_file"] is not None:
            self.record_file = "{0}/log/{1}".format(
                self.trainer.params["model_dir"],
                default_params["record_file"])

            # The case of resuming training.
            if self.trainer.params["start_epoch"] > 0:
                # Append to train.csv.
                self.start_write_log = True
            elif os.path.exists(self.record_file):
                # Back up the old record to avoid clobbering the loss log when re-running the same launcher.
                bk_file = "{0}.backup.{1}".format(
                    self.record_file,
                    time.strftime('%Y-%m-%d_%H:%M:%S',
                                  time.localtime(time.time())))
                shutil.move(self.record_file, bk_file)
        else:
            self.record_file = None

        # A widget layout to show progress.
        # Avoid progressbar.Bar(marker="\x1b[32m█\x1b[39m") and
        # progressbar.SimpleProgress(format='%(value_s)s/%(max_value_s)s'),
        # which make the rendered line too long.
        widgets = [
            progressbar.Percentage(format='%(percentage)3.2f%%'), " | ",
            "Epoch:",
            progressbar.Variable('current_epoch',
                                 format='{formatted_value}',
                                 width=0,
                                 precision=0), "/{0}, ".format(self.epochs),
            "Iter:",
            progressbar.Variable('current_iter',
                                 format='{formatted_value}',
                                 width=0,
                                 precision=0),
            "/{0}".format(self.trainer.training_point[2]), " (",
            progressbar.Timer(format='ELA: %(elapsed)s'), ", ",
            progressbar.AdaptiveETA(), ")"
        ]

        # Total number of iterations over the whole training run.
        max_value = self.trainer.params[
            "epochs"] * self.trainer.training_point[2]

        self.bar = progressbar.ProgressBar(max_value=max_value,
                                           widgets=widgets,
                                           redirect_stdout=True)

        # Update the bar in a separate process so rendering does not block training.
        self.queue = Queue()
        self.process = Process(target=self._update, daemon=True)
        self.process.start()
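
The last three lines hand bar updates to a background process fed through a multiprocessing Queue, so rendering never blocks the training loop. A minimal sketch of that producer/consumer pattern (the _update body and the None shutdown sentinel are assumptions about the idea, not the original implementation; it relies on fork-based multiprocessing, the Linux default, so the child inherits the bar object):

from multiprocessing import Process, Queue

import progressbar

class BarReporter:
    def __init__(self, max_value):
        self.bar = progressbar.ProgressBar(max_value=max_value,
                                           redirect_stdout=True)
        self.queue = Queue()
        # daemon=True: the child dies with the parent instead of hanging.
        self.process = Process(target=self._update, daemon=True)
        self.process.start()

    def _update(self):
        # Consumer: drain iteration counts pushed by the training loop.
        while True:
            item = self.queue.get()
            if item is None:  # Shutdown sentinel (hypothetical convention).
                self.bar.finish()
                break
            self.bar.update(item)

    def report(self, current_iter):
        # Producer: called from the training loop; returns immediately.
        self.queue.put(current_iter)

    def finish(self):
        self.queue.put(None)
        self.process.join()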