Example #1
    def __call__(
        self,
        net: nn.HybridBlock,
        input_names: List[str],
        train_iter: TrainDataLoader,
    ) -> None:  # TODO: we may want to return some training information here
        self.halt = False

        with tempfile.TemporaryDirectory(
                prefix="gluonts-trainer-temp-") as gluonts_temp:

            def base_path() -> str:
                return os.path.join(
                    gluonts_temp,
                    "{}_{}".format(STATE_ARTIFACT_FILE_NAME, uuid.uuid4()),
                )

            logging.info("Start model training")

            net.initialize(ctx=self.ctx, init=self.init)

            with HybridContext(
                    net=net,
                    hybridize=self.hybridize,
                    static_alloc=True,
                    static_shape=True,
            ):
                batch_size = train_iter.batch_size
                epoch_loss = mx.metric.Loss()

                best_epoch_info = BestEpochInfo(
                    params_path="%s-%s.params" % (base_path(), "init"),
                    epoch_no=-1,
                    metric_value=np.inf,
                )

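                # reduce the learning rate by decay_factor whenever the tracked metric
                # has not improved for patience epochs, but never below min_lr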
                lr_scheduler = lrs.MetricAttentiveScheduler(
                    objective="min",
                    patience=self.patience,
                    decay_factor=self.learning_rate_decay_factor,
                    min_lr=self.minimum_learning_rate,
                )

                optimizer = mx.optimizer.Adam(
                    learning_rate=self.learning_rate,
                    lr_scheduler=lr_scheduler,
                    wd=self.weight_decay,
                    clip_gradient=self.clip_gradient,
                )

                trainer = mx.gluon.Trainer(
                    net.collect_params(),
                    optimizer=optimizer,
                    kvstore="device",  # FIXME: initialize properly
                )

                for epoch_no in range(self.epochs):
                    if self.halt:
                        logging.info(
                            f"Epoch[{epoch_no}] Interrupting training")
                        break

                    curr_lr = trainer.learning_rate
                    logging.info(
                        f"Epoch[{epoch_no}] Learning rate is {curr_lr}")

                    # mark epoch start time
                    tic = time.time()

                    epoch_loss.reset()

                    with tqdm(train_iter) as it:
                        for batch_no, data_entry in enumerate(it, start=1):
                            if self.halt:
                                break

                            inputs = [data_entry[k] for k in input_names]

                            with mx.autograd.record():
                                output = net(*inputs)

                                # the network can return several outputs, the first one always being the loss;
                                # with multiple outputs, forward returns a list for hybridized networks
                                # and a tuple otherwise
                                # we may wrap network outputs in the future to avoid this type check
                                if isinstance(output, (list, tuple)):
                                    loss = output[0]
                                else:
                                    loss = output

                            loss.backward()
                            trainer.step(batch_size)

                            epoch_loss.update(None, preds=loss)
                            it.set_postfix(
                                ordered_dict={
                                    "avg_epoch_loss": loss_value(epoch_loss)
                                },
                                refresh=False,
                            )
                            # log the number of parameters in the network on the first pass
                            if batch_no == 1 and epoch_no == 0:
                                net_name = type(net).__name__
                                num_model_param = self.count_model_params(net)
                                logging.info(
                                    f"Number of parameters in {net_name}: {num_model_param}"
                                )

                    # mark epoch end time and log the elapsed time of the current epoch
                    toc = time.time()
                    logging.info(
                        "Epoch[%d] Elapsed time %.3f seconds",
                        epoch_no,
                        (toc - tic),
                    )

                    # check and log epoch loss
                    check_loss_finite(loss_value(epoch_loss))
                    logging.info(
                        "Epoch[%d] Evaluation metric '%s'=%f",
                        epoch_no,
                        "epoch_loss",
                        loss_value(epoch_loss),
                    )

                    lr_scheduler.step(loss_value(epoch_loss))

                    if loss_value(epoch_loss) < best_epoch_info.metric_value:
                        best_epoch_info = BestEpochInfo(
                            params_path="%s-%04d.params" %
                            (base_path(), epoch_no),
                            epoch_no=epoch_no,
                            metric_value=loss_value(epoch_loss),
                        )
                        net.save_parameters(
                            best_epoch_info.params_path
                        )  # TODO: handle possible exception

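                    # if the scheduler decayed the learning rate during this epoch,
                    # roll the network back to the best checkpoint seen so far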
                    if trainer.learning_rate != curr_lr:
                        logging.info(f"Loading parameters from best epoch "
                                     f"({best_epoch_info.epoch_no})")
                        net.load_parameters(best_epoch_info.params_path,
                                            self.ctx)

                logging.info(f"Loading parameters from best epoch "
                             f"({best_epoch_info.epoch_no})")
                net.load_parameters(best_epoch_info.params_path, self.ctx)

                logging.info(f"Final loss: {best_epoch_info.metric_value} "
                             f"(occurred at epoch {best_epoch_info.epoch_no})")

                # save net parameters
                net.save_parameters(best_epoch_info.params_path)

                logging.info("End model training")
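
This example accumulates the per-batch loss in an mx.metric.Loss and reads it back through a loss_value helper that is not shown on this page. A plausible minimal implementation, assuming the helper simply returns the metric's running average:

import mxnet as mx

def loss_value(loss_metric: mx.metric.Loss) -> float:
    # get_name_value() returns [(name, running_average)], so the epoch loss
    # is the second element of the first pair
    return loss_metric.get_name_value()[0][1]
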
Example #2
    def __call__(
        self,
        net: nn.HybridBlock,
        input_names: List[str],
        train_iter: TrainDataLoader,
        validation_iter: Optional[ValidationDataLoader] = None,
    ) -> None:  # TODO: we may want to return some training information here
        is_validation_available = validation_iter is not None
        self.halt = False

        with tempfile.TemporaryDirectory(
                prefix="gluonts-trainer-temp-") as gluonts_temp:

            def base_path() -> str:
                return os.path.join(
                    gluonts_temp,
                    "{}_{}".format(STATE_ARTIFACT_FILE_NAME, uuid.uuid4()),
                )

            logger.info("Start model training")

            net.initialize(ctx=self.ctx, init=self.init)

            with HybridContext(
                    net=net,
                    hybridize=self.hybridize,
                    static_alloc=True,
                    static_shape=True,
            ):
                batch_size = train_iter.batch_size

                best_epoch_info = {
                    "params_path": "%s-%s.params" % (base_path(), "init"),
                    "epoch_no": -1,
                    "score": np.Inf,
                }

                lr_scheduler = lrs.MetricAttentiveScheduler(
                    objective="min",
                    patience=self.patience,
                    decay_factor=self.learning_rate_decay_factor,
                    min_lr=self.minimum_learning_rate,
                )

                optimizer = mx.optimizer.Adam(
                    learning_rate=self.learning_rate,
                    lr_scheduler=lr_scheduler,
                    wd=self.weight_decay,
                    clip_gradient=self.clip_gradient,
                )

                trainer = mx.gluon.Trainer(
                    net.collect_params(),
                    optimizer=optimizer,
                    kvstore="device",  # FIXME: initialize properly
                )

                first_forward = True

                def loop(epoch_no,
                         batch_iter,
                         is_training: bool = True) -> mx.metric.Loss:
                    nonlocal first_forward
                    tic = time.time()

                    epoch_loss = mx.metric.Loss()

                    # use averaged model for validation
                    if not is_training and isinstance(
                            self.avg_strategy, IterationAveragingStrategy):
                        self.avg_strategy.load_averaged_model(net)

                    with tqdm(batch_iter) as it:
                        for batch_no, data_entry in enumerate(it, start=1):
                            if self.halt:
                                break

                            inputs = [data_entry[k] for k in input_names]

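                            # run a single forward pass outside autograd first, so that
                            # deferred parameter initialization completes before the
                            # optional post-initialization callback is invoked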
                            if first_forward:
                                first_forward = False
                                _ = net(*inputs)
                                if self.post_initialize_cb:
                                    self.post_initialize_cb(net)

                            with mx.autograd.record():
                                output = net(*inputs)

                                # the network can return several outputs, the first one always being the loss;
                                # with multiple outputs, forward returns a list for hybridized networks
                                # and a tuple otherwise
                                # we may wrap network outputs in the future to avoid this type check
                                if isinstance(output, (list, tuple)):
                                    loss = output[0]
                                else:
                                    loss = output

                            if is_training:
                                loss.backward()
                                trainer.step(batch_size)

                                # iteration averaging in training
                                if isinstance(
                                        self.avg_strategy,
                                        IterationAveragingStrategy,
                                ):
                                    self.avg_strategy.apply(net)

                            epoch_loss.update(None, preds=loss)
                            lv = loss_value(epoch_loss)

                            if not np.isfinite(lv):
                                logger.warning("Epoch[%d] gave nan loss",
                                               epoch_no)
                                return epoch_loss

                            it.set_postfix(
                                ordered_dict={
                                    "epoch": f"{epoch_no + 1}/{self.epochs}",
                                    ("" if is_training else "validation_")
                                    + "avg_epoch_loss": lv,
                                },
                                refresh=False,
                            )
                            # log the number of parameters in the network on the first pass
                            if batch_no == 1 and epoch_no == 0:
                                net_name = type(net).__name__
                                num_model_param = self.count_model_params(net)
                                logger.info(
                                    f"Number of parameters in {net_name}: {num_model_param}"
                                )
                    # mark epoch end time and log the elapsed time of the current epoch
                    toc = time.time()
                    logger.info(
                        "Epoch[%d] Elapsed time %.3f seconds",
                        epoch_no,
                        (toc - tic),
                    )

                    logger.info(
                        "Epoch[%d] Evaluation metric '%s'=%f",
                        epoch_no,
                        ("" if is_training else "validation_") + "epoch_loss",
                        lv,
                    )

                    if not is_training and isinstance(
                            self.avg_strategy, IterationAveragingStrategy):
                        # bring back the cached model
                        self.avg_strategy.load_cached_model(net)

                    return epoch_loss

                for epoch_no in range(self.epochs):
                    if self.halt:
                        logger.info(f"Epoch[{epoch_no}] Interrupting training")
                        break

                    curr_lr = trainer.learning_rate
                    logger.info(
                        f"Epoch[{epoch_no}] Learning rate is {curr_lr}")

                    epoch_loss = loop(epoch_no, train_iter)
                    if is_validation_available:
                        epoch_loss = loop(epoch_no,
                                          validation_iter,
                                          is_training=False)

                    # update average trigger
                    if isinstance(self.avg_strategy,
                                  IterationAveragingStrategy):
                        self.avg_strategy.update_average_trigger(
                            metric=loss_value(epoch_loss), epoch=epoch_no + 1)
                        # once triggered, update the average immediately
                        self.avg_strategy.apply(net)

                    should_continue = lr_scheduler.step(loss_value(epoch_loss))
                    if isinstance(self.avg_strategy,
                                  IterationAveragingStrategy):
                        logger.info(
                            "Overriding early stopping for iteration-based averaging strategies."
                        )
                        should_continue = True
                    if not should_continue:
                        logger.info("Stopping training")
                        break

                    # save model and epoch info
                    bp = base_path()
                    epoch_info = {
                        "params_path": f"{bp}-0000.params",
                        "epoch_no": epoch_no,
                        "score": loss_value(epoch_loss),
                    }

                    net.save_parameters(
                        epoch_info["params_path"]
                    )  # TODO: handle possible exception

                    save_epoch_info(bp, epoch_info)

                    # update best epoch info - needed for the learning rate scheduler
                    if loss_value(epoch_loss) < best_epoch_info["score"]:
                        best_epoch_info = epoch_info.copy()

                    if trainer.learning_rate != curr_lr:
                        if best_epoch_info["epoch_no"] == -1:
                            raise GluonTSUserError(
                                "Got NaN in first epoch. Try reducing initial learning rate."
                            )

                        logger.info(f"Loading parameters from best epoch "
                                    f"({best_epoch_info['epoch_no']})")
                        net.load_parameters(best_epoch_info["params_path"],
                                            self.ctx)

                if isinstance(self.avg_strategy, AveragingStrategy):
                    logging.info("Computing averaged parameters.")
                    averaged_params_path = self.avg_strategy.apply(
                        gluonts_temp)

                    logging.info("Loading averaged parameters.")
                    net.load_parameters(averaged_params_path, self.ctx)

                if isinstance(self.avg_strategy, IterationAveragingStrategy):
                    logging.info("Loading averaged parameters.")
                    self.avg_strategy.load_averaged_model(net)

                logger.info("End model training")