def on_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: nn.HybridBlock,
        trainer: gluon.Trainer,
        best_epoch_info: Dict[str, Any],
        ctx: mx.Context,
    ) -> bool:
        """
        Step the learning-rate scheduler at the end of an epoch.

        The scheduler is stepped with ``epoch_loss``; if it reports that
        training should not continue, ``False`` is returned. Otherwise the
        optimizer learning rate is set from the scheduler and, if that changed
        the learning rate, the parameters of the best epoch so far are loaded
        back into ``training_network``.

        Parameters
        ----------
        epoch_no
            Index of the epoch that just finished (not used here).
        epoch_loss
            Loss of that epoch, passed to the scheduler as its metric.
        training_network
            Network whose parameters may be reset to the best epoch.
        trainer
            Gluon trainer whose optimizer learning rate is updated.
        best_epoch_info
            Must provide ``"epoch_no"`` and ``"params_path"`` of the best epoch
            so far; ``epoch_no == -1`` is treated as "no best epoch yet".
        ctx
            Context into which the best parameters are loaded.

        Returns
        -------
        bool
            ``False`` to stop training early, ``True`` to continue.

        Raises
        ------
        GluonTSUserError
            If the learning rate changes while no best epoch is recorded.
        """
        should_continue = self.lr_scheduler.step(metric_value=epoch_loss)
        if not should_continue:
            # Report through the module logger like the rest of this method
            # instead of printing to stdout.
            logger.info(
                "Early stopping based on learning rate scheduler callback (min_lr was reached)."
            )
            return False

        pre_step_learning_rate = trainer.learning_rate
        trainer.optimizer.set_learning_rate(
            self.lr_scheduler(trainer.optimizer.num_update))

        if trainer.learning_rate != pre_step_learning_rate:
            if best_epoch_info["epoch_no"] == -1:
                raise GluonTSUserError(
                    "Got NaN in first epoch. Try reducing initial learning rate."
                )

            # After a learning-rate change, restart from the best parameters.
            logger.info(f"Loading parameters from best epoch "
                        f"({best_epoch_info['epoch_no']})")
            training_network.load_parameters(best_epoch_info["params_path"],
                                             ctx)

        return True
    def load_averaged_model(self, model: nn.HybridBlock):
        r"""
        Overwrite ``model`` with the averaged parameters, caching its own.

        When validating/evaluating the averaged model halfway through
        training, call ``load_averaged_model`` first to load the averaged model
        over the current one, run the evaluation, and then call
        ``load_cached_model`` to load the current model back.

        Nothing happens while ``self.averaged_model`` is ``None``.

        Parameters
        ----------
        model
            The model that the averaged model is loaded to.
        """
        if self.averaged_model is None:
            return

        # ``collect_params()`` builds a fresh dict by traversing the block and
        # its children, so fetch it once rather than once per parameter.
        params = model.collect_params()

        # Cache the current model so it can be restored afterwards.
        if self.cached_model is None:
            self.cached_model = {
                name: param.list_data()[0].copy()
                for name, param in params.items()
            }
        else:
            # Reuse the existing cache arrays, overwriting them in place.
            for name, param_cached in self.cached_model.items():
                param_cached[:] = params[name].list_data()[0]

        # Load the averaged model.
        for name, param_avg in self.averaged_model.items():
            params[name].set_data(param_avg)
Beispiel #3
0
def clip_dis(d_model: HybridBlock, clip_size):
    """
    Clip every parameter of ``d_model`` element-wise to
    ``[-clip_size, clip_size]``, in place.

    Parameters
    ----------
    d_model
        Block whose parameters (from ``collect_params``) are clipped.
    clip_size
        Non-negative bound; values outside the range are clamped to it.
    """
    # Iterate the parameter dict once. Re-selecting each parameter with
    # ``collect_params(name)`` rebuilt the dict per parameter (O(n^2)) and
    # interpreted the name as a regular expression.
    for param in d_model.collect_params().values():
        clipped_param = mx.nd.clip(param.data(),
                                   a_min=-clip_size,
                                   a_max=clip_size)
        param.set_data(clipped_param)
Beispiel #4
0
 def __init__(
     self,
     bij_blocks: Optional[List[Bijection]] = None,
     *args,
     **kwargs,
 ) -> None:
     """
     Initialise both base classes.

     Parameters
     ----------
     bij_blocks
         Passed to ``ComposedBijection.__init__``.
     *args, **kwargs
         Passed unchanged to ``HybridBlock.__init__``.
     """
     # The two bases take different arguments, so each initialiser is called
     # explicitly (HybridBlock first) instead of via a cooperative super().
     HybridBlock.__init__(self, *args, **kwargs)
     ComposedBijection.__init__(self, bij_blocks)
 def load_cached_model(self, model: nn.HybridBlock):
     r"""
     Copy the parameters stored in ``self.cached_model`` into ``model``.

     Nothing happens while ``self.cached_model`` is ``None``.

     Parameters
     ----------
     model
         The model that the cached model is loaded to.
     """
     if self.cached_model is None:
         return

     # Fetch the parameter dict once instead of rebuilding it per parameter.
     params = model.collect_params()
     for name, param_cached in self.cached_model.items():
         params[name].set_data(param_cached)
Beispiel #6
0
    def on_train_end(
        self,
        training_network: nn.HybridBlock,
        temporary_dir: str,
        ctx: mx.context.Context = None,
    ) -> None:
        """
        Load averaged parameters into ``training_network`` once training ends.

        ``self.avg_strategy.apply`` receives ``temporary_dir`` (presumably the
        directory holding the per-epoch parameter files — confirm with the
        caller) and returns the path of the averaged parameters, which are
        then loaded on ``ctx``.
        """
        logging.info("Computing averaged parameters.")
        path_of_average = self.avg_strategy.apply(temporary_dir)

        logging.info("Loading averaged parameters.")
        training_network.load_parameters(path_of_average, ctx)
Beispiel #7
0
 def on_train_batch_end(self, network: nn.HybridBlock,
                        time_elapsed: float) -> None:
     """
     Count the batch and checkpoint ``network`` at the next time milestone.

     When ``time_elapsed`` exceeds ``self.milestones[self.seq]``, the
     parameters are saved to ``model_<seq>.params`` under ``self.directory``.
     The file, the elapsed time and the batch count are then recorded, and
     ``self.seq`` moves to the next milestone. At most one milestone is
     handled per call.
     """
     self.batch_count += 1

     if self.seq >= len(self.milestones):
         return  # every milestone has been used already
     if not time_elapsed > self.milestones[self.seq]:
         return  # current milestone not reached yet

     checkpoint = self.directory / f"model_{self.seq}.params"
     network.save_parameters(checkpoint.absolute().as_posix())
     self.saved_parameters.append(checkpoint)
     self.training_times.append(time_elapsed)
     self.num_gradient_updates.append(self.batch_count)
     self.seq += 1
Beispiel #8
0
 def count_model_params(self, net: nn.HybridBlock) -> int:
     """
     Return the total number of scalar values in ``net``'s parameters.

     Parameters
     ----------
     net
         Network whose parameters (from ``collect_params``) are counted.

     Returns
     -------
     int
         Sum over all parameters of the product of their shape dimensions.
     """
     # ``np.prod`` returns a NumPy scalar, and ``1.0`` (a float) for a 0-d
     # shape; convert each term so the result really is an ``int``.
     return sum(
         int(np.prod(param.shape))
         for param in net.collect_params().values()
     )
Beispiel #9
0
    def __call__(
        self,
        net: nn.HybridBlock,
        train_iter: DataLoader,
        validation_iter: Optional[DataLoader] = None,
    ) -> None:
        """
        Initialise ``net`` and run ``self._train_loop`` on it.

        The training loop runs inside a ``HybridContext`` that uses
        ``self.hybridize`` together with static allocation and static shapes.
        """
        logger.info("Start model training")
        net.initialize(ctx=self.ctx, init=self.init)

        hybrid_context = HybridContext(
            net=net,
            hybridize=self.hybridize,
            static_alloc=True,
            static_shape=True,
        )
        with hybrid_context:
            self._train_loop(net, train_iter, validation_iter)
Beispiel #10
0
 def update_average(self, model: nn.HybridBlock):
     r"""
     Fold the current parameters of ``model`` into the running average.

     ``self.average_counter`` is incremented on every call. The first call
     (while ``self.averaged_model`` is ``None``) copies the parameters. Later
     calls update each averaged array in place as
     ``avg += alpha * (current - avg)``, with
     ``alpha = (eta + 1) / (eta + average_counter)``.

     Parameters
     ----------
     model
         The model to update the average.
     """
     self.average_counter += 1
     # Fetch the parameter dict once instead of rebuilding it per parameter.
     params = model.collect_params()

     if self.averaged_model is None:
         self.averaged_model = {
             name: param.list_data()[0].copy()
             for name, param in params.items()
         }
         return

     alpha = (self.eta + 1.0) / (self.eta + self.average_counter)
     # moving average, updated in place
     for name, param_avg in self.averaged_model.items():
         param_avg[:] += alpha * (params[name].list_data()[0] - param_avg)
Beispiel #11
0
    def __call__(
        self,
        net: nn.HybridBlock,
        input_names: List[str],
        train_iter: TrainDataLoader,
    ) -> None:  # TODO: we may want to return some training information here
        """
        Train ``net`` and leave it with the parameters of its best epoch.

        Each epoch is one pass over ``train_iter``. Whenever an epoch's average
        training loss improves on the best so far, the parameters are saved to
        a temporary directory. These best parameters are reloaded whenever the
        learning-rate scheduler changes the learning rate, and once more at
        the end of training if at least one epoch was recorded. Training stops
        early when ``self.halt`` becomes true.

        Parameters
        ----------
        net
            Network to train. Its output, or the first element of its output
            when it returns a list/tuple, is used as the loss.
        input_names
            Keys read from each batch, in the order they are passed to
            ``net``.
        train_iter
            Training batches; its ``batch_size`` is passed to
            ``trainer.step``.
        """
        self.halt = False

        with tempfile.TemporaryDirectory(
                prefix="gluonts-trainer-temp-") as gluonts_temp:

            def base_path() -> str:
                # Fresh, unique file prefix inside the temporary directory.
                return os.path.join(
                    gluonts_temp,
                    "{}_{}".format(STATE_ARTIFACT_FILE_NAME, uuid.uuid4()),
                )

            logging.info("Start model training")

            net.initialize(ctx=self.ctx, init=self.init)

            with HybridContext(
                    net=net,
                    hybridize=self.hybridize,
                    static_alloc=True,
                    static_shape=True,
            ):
                batch_size = train_iter.batch_size
                epoch_loss = mx.metric.Loss()

                # Placeholder until the first epoch is recorded; nothing is
                # ever written to this "init" path.
                best_epoch_info = BestEpochInfo(
                    params_path="%s-%s.params" % (base_path(), "init"),
                    epoch_no=-1,
                    # ``np.Inf`` was removed in NumPy 2.0; ``np.inf`` works in
                    # every version.
                    metric_value=np.inf,
                )

                lr_scheduler = lrs.MetricAttentiveScheduler(
                    objective="min",
                    patience=self.patience,
                    decay_factor=self.learning_rate_decay_factor,
                    min_lr=self.minimum_learning_rate,
                )

                optimizer = mx.optimizer.Adam(
                    learning_rate=self.learning_rate,
                    lr_scheduler=lr_scheduler,
                    wd=self.weight_decay,
                    clip_gradient=self.clip_gradient,
                )

                trainer = mx.gluon.Trainer(
                    net.collect_params(),
                    optimizer=optimizer,
                    kvstore="device",  # FIXME: initialize properly
                )

                for epoch_no in range(self.epochs):
                    if self.halt:
                        logging.info(
                            f"Epoch[{epoch_no}] Interrupting training")
                        break

                    curr_lr = trainer.learning_rate
                    logging.info(
                        f"Epoch[{epoch_no}] Learning rate is {curr_lr}")

                    # mark epoch start time
                    tic = time.time()

                    epoch_loss.reset()

                    with tqdm(train_iter) as it:
                        for batch_no, data_entry in enumerate(it, start=1):
                            if self.halt:
                                break

                            inputs = [data_entry[k] for k in input_names]

                            with mx.autograd.record():
                                output = net(*inputs)

                                # network can return several outputs, the first
                                # always being the loss; with multiple outputs
                                # the forward returns a list when hybridized and
                                # a tuple otherwise. We may wrap network outputs
                                # in the future to avoid this type check.
                                if isinstance(output, (list, tuple)):
                                    loss = output[0]
                                else:
                                    loss = output

                            loss.backward()
                            trainer.step(batch_size)

                            epoch_loss.update(None, preds=loss)
                            it.set_postfix(
                                ordered_dict={
                                    "avg_epoch_loss": loss_value(epoch_loss)
                                },
                                refresh=False,
                            )
                            # print out parameters of the network at the first pass
                            if batch_no == 1 and epoch_no == 0:
                                net_name = type(net).__name__
                                num_model_param = self.count_model_params(net)
                                logging.info(
                                    f"Number of parameters in {net_name}: {num_model_param}"
                                )

                    # mark epoch end time and log time cost of current epoch
                    toc = time.time()
                    logging.info(
                        "Epoch[%d] Elapsed time %.3f seconds",
                        epoch_no,
                        (toc - tic),
                    )

                    # check and log epoch loss
                    current_loss = loss_value(epoch_loss)
                    check_loss_finite(current_loss)
                    logging.info(
                        "Epoch[%d] Evaluation metric '%s'=%f",
                        epoch_no,
                        "epoch_loss",
                        current_loss,
                    )

                    lr_scheduler.step(current_loss)

                    if current_loss < best_epoch_info.metric_value:
                        best_epoch_info = BestEpochInfo(
                            params_path="%s-%04d.params" %
                            (base_path(), epoch_no),
                            epoch_no=epoch_no,
                            metric_value=current_loss,
                        )
                        net.save_parameters(
                            best_epoch_info.params_path
                        )  # TODO: handle possible exception

                    # After a learning-rate change, restart from the best epoch.
                    if trainer.learning_rate != curr_lr:
                        logging.info(f"Loading parameters from best epoch "
                                     f"({best_epoch_info.epoch_no})")
                        net.load_parameters(best_epoch_info.params_path,
                                            self.ctx)

                # The "init" path is never written, so only reload when at
                # least one epoch was recorded (it is not when ``self.epochs``
                # is 0 or training was halted before the first epoch).
                if best_epoch_info.epoch_no >= 0:
                    logging.info(f"Loading parameters from best epoch "
                                 f"({best_epoch_info.epoch_no})")
                    net.load_parameters(best_epoch_info.params_path, self.ctx)

                    logging.info(
                        f"Final loss: {best_epoch_info.metric_value} "
                        f"(occurred at epoch {best_epoch_info.epoch_no})")
                else:
                    logging.warning(
                        "No epoch was completed; keeping current parameters")

                # Note: no final save here. The temporary directory (the only
                # place it could be written to) is deleted right after.

                logging.getLogger().info("End model training")
Beispiel #12
0
    def __call__(
        self,
        net: nn.HybridBlock,
        input_names: List[str],
        train_iter: TrainDataLoader,
        validation_iter: Optional[ValidationDataLoader] = None,
    ) -> None:  # TODO: we may want to return some training information here
        """
        Train ``net``, optionally validating after every training epoch.

        The per-epoch loss that drives the learning-rate scheduler and the
        averaging strategy is the validation loss when ``validation_iter`` is
        given, otherwise the training loss. The parameters of each completed
        epoch are saved (together with ``save_epoch_info``) to a temporary
        directory. When the learning rate changes, the parameters of the best
        epoch so far are reloaded. At the end, ``self.avg_strategy`` decides
        which averaged parameters ``net`` is left with.

        Parameters
        ----------
        net
            Network to train. Its output, or the first element of its output
            when it returns a list/tuple, is used as the loss.
        input_names
            Keys read from each batch, in the order they are passed to
            ``net``.
        train_iter
            Training batches; its ``batch_size`` is passed to
            ``trainer.step``.
        validation_iter
            Optional validation batches.

        Raises
        ------
        GluonTSUserError
            If the learning rate changes before any epoch was recorded as the
            best one.
        """
        is_validation_available = validation_iter is not None
        self.halt = False

        with tempfile.TemporaryDirectory(
                prefix="gluonts-trainer-temp-") as gluonts_temp:

            def base_path() -> str:
                # Fresh, unique file prefix inside the temporary directory.
                return os.path.join(
                    gluonts_temp,
                    "{}_{}".format(STATE_ARTIFACT_FILE_NAME, uuid.uuid4()),
                )

            logger.info("Start model training")

            net.initialize(ctx=self.ctx, init=self.init)

            with HybridContext(
                    net=net,
                    hybridize=self.hybridize,
                    static_alloc=True,
                    static_shape=True,
            ):
                batch_size = train_iter.batch_size

                best_epoch_info = {
                    "params_path": "%s-%s.params" % (base_path(), "init"),
                    "epoch_no": -1,
                    # ``np.Inf`` was removed in NumPy 2.0.
                    "score": np.inf,
                }

                lr_scheduler = lrs.MetricAttentiveScheduler(
                    objective="min",
                    patience=self.patience,
                    decay_factor=self.learning_rate_decay_factor,
                    min_lr=self.minimum_learning_rate,
                )

                optimizer = mx.optimizer.Adam(
                    learning_rate=self.learning_rate,
                    lr_scheduler=lr_scheduler,
                    wd=self.weight_decay,
                    clip_gradient=self.clip_gradient,
                )

                trainer = mx.gluon.Trainer(
                    net.collect_params(),
                    optimizer=optimizer,
                    kvstore="device",  # FIXME: initialize properly
                )

                first_forward = True

                def run_epoch(epoch_no, batch_iter,
                              is_training: bool) -> mx.metric.Loss:
                    # One pass over ``batch_iter``; returns the accumulated
                    # loss metric (early, if the running loss is not finite).
                    nonlocal first_forward
                    tic = time.time()

                    epoch_loss = mx.metric.Loss()
                    prefix = "" if is_training else "validation_"

                    with tqdm(batch_iter) as it:
                        for batch_no, data_entry in enumerate(it, start=1):
                            if self.halt:
                                break

                            inputs = [data_entry[k] for k in input_names]

                            if first_forward:
                                first_forward = False
                                _ = net(*inputs)
                                if self.post_initialize_cb:
                                    self.post_initialize_cb(net)

                            # ``record()`` defaults to training mode, which
                            # would enable dropout and update batch-norm
                            # statistics while validating; follow
                            # ``is_training`` instead.
                            with mx.autograd.record(train_mode=is_training):
                                output = net(*inputs)

                                # network can return several outputs, the first
                                # always being the loss; with multiple outputs
                                # the forward returns a list when hybridized and
                                # a tuple otherwise. We may wrap network outputs
                                # in the future to avoid this type check.
                                if isinstance(output, (list, tuple)):
                                    loss = output[0]
                                else:
                                    loss = output

                            if is_training:
                                loss.backward()
                                trainer.step(batch_size)

                                # iteration averaging in training
                                if isinstance(
                                        self.avg_strategy,
                                        IterationAveragingStrategy,
                                ):
                                    self.avg_strategy.apply(net)

                            epoch_loss.update(None, preds=loss)
                            lv = loss_value(epoch_loss)

                            if not np.isfinite(lv):
                                logger.warning("Epoch[%d] gave nan loss",
                                               epoch_no)
                                return epoch_loss

                            it.set_postfix(
                                ordered_dict={
                                    "epoch":
                                    f"{epoch_no + 1}/{self.epochs}",
                                    prefix + "avg_epoch_loss":
                                    lv,
                                },
                                refresh=False,
                            )
                            # print out parameters of the network at the first pass
                            if batch_no == 1 and epoch_no == 0:
                                net_name = type(net).__name__
                                num_model_param = self.count_model_params(net)
                                logger.info(
                                    f"Number of parameters in {net_name}: {num_model_param}"
                                )

                    # mark epoch end time and log time cost of current epoch
                    toc = time.time()
                    logger.info(
                        "Epoch[%d] Elapsed time %.3f seconds",
                        epoch_no,
                        (toc - tic),
                    )

                    # Recomputed here so it is defined even when no batch ran
                    # (empty iterator or immediate halt).
                    lv = loss_value(epoch_loss)
                    logger.info(
                        "Epoch[%d] Evaluation metric '%s'=%f",
                        epoch_no,
                        prefix + "epoch_loss",
                        lv,
                    )

                    return epoch_loss

                def loop(epoch_no,
                         batch_iter,
                         is_training: bool = True) -> mx.metric.Loss:
                    # Validation uses the averaged model when iteration
                    # averaging is active.
                    use_averaged_model = not is_training and isinstance(
                        self.avg_strategy, IterationAveragingStrategy)
                    if use_averaged_model:
                        self.avg_strategy.load_averaged_model(net)
                    try:
                        return run_epoch(epoch_no, batch_iter, is_training)
                    finally:
                        # Bring back the cached model on every exit path,
                        # including the early return on a non-finite loss.
                        if use_averaged_model:
                            self.avg_strategy.load_cached_model(net)

                for epoch_no in range(self.epochs):
                    if self.halt:
                        logger.info(f"Epoch[{epoch_no}] Interrupting training")
                        break

                    curr_lr = trainer.learning_rate
                    logger.info(
                        f"Epoch[{epoch_no}] Learning rate is {curr_lr}")

                    epoch_loss = loop(epoch_no, train_iter)
                    if is_validation_available:
                        epoch_loss = loop(epoch_no,
                                          validation_iter,
                                          is_training=False)
                    epoch_loss_value = loss_value(epoch_loss)

                    # update average trigger
                    if isinstance(self.avg_strategy,
                                  IterationAveragingStrategy):
                        self.avg_strategy.update_average_trigger(
                            metric=epoch_loss_value, epoch=epoch_no + 1)
                        # once triggered, update the average immediately
                        self.avg_strategy.apply(net)

                    should_continue = lr_scheduler.step(epoch_loss_value)
                    if isinstance(self.avg_strategy,
                                  IterationAveragingStrategy):
                        logger.info(
                            "Overriding early stopping for iteration-based averaging strategies."
                        )
                        should_continue = True
                    if not should_continue:
                        logger.info("Stopping training")
                        break

                    # save model and epoch info
                    bp = base_path()
                    epoch_info = {
                        "params_path": f"{bp}-0000.params",
                        "epoch_no": epoch_no,
                        "score": epoch_loss_value,
                    }

                    net.save_parameters(epoch_info["params_path"]
                                        )  # TODO: handle possible exception

                    save_epoch_info(bp, epoch_info)

                    # update best epoch info - needed for the learning rate scheduler
                    if epoch_loss_value < best_epoch_info["score"]:
                        best_epoch_info = epoch_info.copy()

                    if trainer.learning_rate != curr_lr:
                        if best_epoch_info["epoch_no"] == -1:
                            raise GluonTSUserError(
                                "Got NaN in first epoch. Try reducing initial learning rate."
                            )

                        logger.info(f"Loading parameters from best epoch "
                                    f"({best_epoch_info['epoch_no']})")
                        net.load_parameters(best_epoch_info["params_path"],
                                            self.ctx)

                if isinstance(self.avg_strategy, AveragingStrategy):
                    logger.info("Computing averaged parameters.")
                    averaged_params_path = self.avg_strategy.apply(
                        gluonts_temp)

                    logger.info("Loading averaged parameters.")
                    net.load_parameters(averaged_params_path, self.ctx)

                if isinstance(self.avg_strategy, IterationAveragingStrategy):
                    logger.info("Loading averaged parameters.")
                    self.avg_strategy.load_averaged_model(net)

                logger.info("End model training")
Beispiel #13
0
 def on_network_initialization_end(self, network: nn.HybridBlock) -> None:
     """Store in ``self.num_parameters`` the element count of all parameters."""
     total = 0
     for param in network.collect_params().values():
         # product of the shape dimensions = number of elements
         total += np.prod(param.shape)
     self.num_parameters = total
Beispiel #14
0
    def _train_loop(  # pylint: disable=too-many-statements
        self,
        net: nn.HybridBlock,
        train_iter: DataLoader,
        validation_iter: Optional[DataLoader],
    ) -> None:
        """
        Train ``net`` until more than ``self.training_time`` seconds of
        training time have been spent.

        Training passes over ``train_iter`` alternate with validation passes
        over ``validation_iter`` (when given). A training pass ends when:

        - the time budget is exceeded,
        - the next entry of ``self.validation_milestones`` is passed, or
        - ``train_iter`` runs out of batches.

        Validation passes, the first forward pass and the
        ``on_train_batch_end`` callbacks do not count towards the budget.
        Training also stops, with a warning, if a training pass yields no
        batches.
        """
        optimizer = mx.optimizer.Adam(
            learning_rate=self.learning_rate,
            wd=self.weight_decay,
            clip_gradient=self.clip_gradient,
        )

        trainer = mx.gluon.Trainer(
            net.collect_params(),
            optimizer=optimizer,
            kvstore="device",
        )

        first_forward = True
        # Training seconds spent so far, committed at the end of each pass.
        time_elapsed = 0
        # Index of the next validation milestone to trigger.
        validation_idx = 0

        def loop(
            batch_iter: DataLoader,
            num_batches_to_use: Optional[int] = None,
            is_training: bool = True,
        ) -> int:
            # Run one pass over ``batch_iter`` and return the number of
            # batches drawn from it.
            nonlocal first_forward, time_elapsed, validation_idx

            tic = time.time()
            # Seconds spent on initialisation/callbacks in this pass; they are
            # excluded from the training-time budget.
            subtic = 0
            # Training time including this pass; unchanged if no batch runs.
            total_time_elapsed = time_elapsed
            num_batches = 0
            lv = float("nan")

            epoch_loss = mx.metric.Loss()
            batch_iter = itertools.islice(batch_iter, num_batches_to_use)

            it = tqdm(batch_iter, total=num_batches_to_use)
            for batch_no, batch in enumerate(it, start=1):
                num_batches = batch_no
                # `batch` here is expected to be a dictionary whose fields
                # should correspond 1-to-1 with the network inputs
                # see below how `batch.values()` is fed into the network
                if first_forward:
                    tictic = time.time()
                    first_forward = False
                    _ = net(*batch.values())
                    self.callbacks.on_network_initialization_end(net)
                    subtic += time.time() - tictic

                with mx.autograd.record():  # type: ignore
                    # we set the mode explicitly as by default mxnet assumes
                    # predict mode and hence dropout layers are not used if
                    # the mode is not explicitly set to training
                    mode = (autograd.train_mode
                            if is_training else autograd.predict_mode)
                    with mode():
                        output = net(*batch.values())

                    # network can return several outputs, the first being
                    # always the loss; with multiple outputs the forward
                    # returns a list when hybridized and a tuple otherwise.
                    # We may wrap network outputs in the future to avoid this
                    # type check.
                    if isinstance(output, (list, tuple)):
                        loss = output[0]
                    else:
                        loss = output

                    batch_size = loss.shape[0]

                # pylint: disable=no-member
                if not np.isfinite(
                        ndarray.sum(loss).asscalar()):  # type: ignore
                    logger.warning(
                        "Batch [%d] gave NaN loss and it will be ignored",
                        batch_no,
                    )
                else:
                    if is_training:
                        loss.backward()
                        trainer.step(batch_size)
                    epoch_loss.update(None, preds=loss)

                if is_training:
                    total_time_elapsed = (time_elapsed + time.time() - tic -
                                          subtic)

                    orig_lr = trainer.learning_rate
                    tictic = time.time()
                    self.callbacks.on_train_batch_end(net, total_time_elapsed)
                    subtic += time.time() - tictic
                    if trainer.learning_rate != orig_lr:
                        logger.info(
                            "Trainer learning rate set to %f",
                            trainer.learning_rate,
                        )

                lv = _loss_value(epoch_loss)
                it.set_postfix(
                    ordered_dict={
                        ("" if is_training else "validation_") + "avg_epoch_loss":
                        lv,
                    },
                    refresh=False,
                )

                # Check if should finish
                if is_training:
                    if total_time_elapsed > self.training_time:
                        break
                    if len(self.validation_milestones) > validation_idx and (
                            total_time_elapsed
                            > self.validation_milestones[validation_idx]):
                        validation_idx += 1
                        break
                # If validating, call the callback with the loss
                # NOTE(review): this runs once per validation *batch* with the
                # running average, despite the callback's name — confirm.
                else:
                    self.callbacks.on_validation_epoch_end(lv)

            if is_training:
                # Commit the training time on every exit path. Previously it
                # was only committed on a break, so a pass that exhausted
                # ``batch_iter`` lost its time and the budget could be
                # overrun indefinitely.
                time_elapsed = total_time_elapsed

            # mark epoch end time and log time cost of current epoch
            toc = time.time()
            logger.info("Elapsed time %.3f seconds", toc - tic)
            logger.info(
                "Evaluation metric '%s'=%f",
                ("" if is_training else "validation_") + "epoch_loss",
                lv,
            )

            return num_batches

        self.callbacks.on_train_start(trainer)
        while True:
            num_train_batches = loop(train_iter)
            if validation_iter is not None:
                loop(validation_iter, is_training=False)
            if time_elapsed > self.training_time:
                break
            if num_train_batches == 0:
                # An empty training iterator never advances ``time_elapsed``;
                # without this check the loop would never terminate.
                logger.warning(
                    "Training data yielded no batches; stopping training")
                break

        logger.info("End model training")
Beispiel #15
0
    def __call__(
        self,
        net: nn.HybridBlock,
        train_iter: DataLoader,
        validation_iter: Optional[DataLoader] = None,
    ) -> None:  # TODO: we may want to return some training information here
        """
        Train a network, given an iterable over training (and optionally
        validation) batches.

        Parameters
        ----------
        net
            Network to be trained. This a Gluon HybridBlock, assumed to produce
            a tensor of loss values as output.
        train_iter
            An iterable over batches to be used for training. Batches are
            assumed to be dictionaries, whose values are MXNet arrays that
            correspond to the network inputs.
        validation_iter
            Similar to `train_iter` but the batches produced here are used to
            compute validation metrics.
        """
        is_validation_available = validation_iter is not None

        logger.info("Start model training")
        net.initialize(ctx=self.ctx, init=self.init)

        with tempfile.TemporaryDirectory(
                prefix="gluonts-trainer-temp-") as gluonts_temp, HybridContext(
                    net=net,
                    hybridize=self.hybridize,
                    static_alloc=True,
                    static_shape=True,
                ):

            def base_path() -> str:
                return os.path.join(
                    gluonts_temp,
                    "{}_{}".format(STATE_ARTIFACT_FILE_NAME, uuid.uuid4()),
                )

            best_epoch_info = {
                "params_path": "%s-%s.params" % (base_path(), "init"),
                "epoch_no": -1,
                "score": np.Inf,
            }

            optimizer = mx.optimizer.Adam(
                learning_rate=self.learning_rate,
                wd=self.weight_decay,
                clip_gradient=self.clip_gradient,
            )

            trainer = mx.gluon.Trainer(
                net.collect_params(),
                optimizer=optimizer,
                kvstore="device",  # FIXME: initialize properly
            )

            first_forward = True

            def loop(  # todo call run epoch
                epoch_no,
                batch_iter,
                num_batches_to_use: Optional[int] = None,
                is_training: bool = True,
            ) -> mx.metric.Loss:
                nonlocal first_forward
                tic = time.time()

                epoch_loss = mx.metric.Loss()

                if is_training:
                    # We should not call this method if we haven't compiled the
                    # network yet. Instead, this callback is called after
                    # network initialization.
                    if not first_forward:
                        self.callbacks.on_train_epoch_start(
                            training_network=net)
                else:
                    self.callbacks.on_validation_epoch_start(
                        training_network=net)

                batch_iter = itertools.islice(batch_iter, num_batches_to_use)

                it = tqdm(batch_iter, total=num_batches_to_use)
                for batch_no, batch in enumerate(it, start=1):
                    # `batch` here is expected to be a dictionary whose fields
                    # should correspond 1-to-1 with the network inputs
                    # see below how `batch.values()` is fed into the network
                    if self.halt:
                        break

                    if first_forward:
                        first_forward = False
                        _ = net(*batch.values())

                        self.callbacks.on_network_initializing_end(
                            training_network=net)

                        # Call the batch start callback as the model was not
                        # compiled before
                        self.callbacks.on_train_epoch_start(
                            training_network=net)

                    with mx.autograd.record():
                        # we set the mode explicitly as by default mxnet assumes
                        # predict mode and hence dropout layers are not used if
                        # the mode is not explicitly set to training
                        mode = (autograd.train_mode
                                if is_training else autograd.predict_mode)
                        with mode():
                            output = net(*batch.values())

                        # network can returns several outputs, the first being
                        # always the loss when having multiple outputs, the
                        # forward returns a list in the case of hybrid and a
                        # tuple otherwise we may wrap network outputs in the
                        # future to avoid this type check
                        if isinstance(output, (list, tuple)):
                            loss = output[0]
                        else:
                            loss = output

                        batch_size = loss.shape[0]

                    if not np.isfinite(ndarray.sum(loss).asscalar()):
                        logger.warning(
                            "Batch [%d] of Epoch[%d] gave NaN loss and it will be ignored",
                            batch_no,
                            epoch_no,
                        )
                    else:
                        if is_training:
                            loss.backward()
                            trainer.step(batch_size)

                            self.callbacks.on_train_batch_end(
                                training_network=net)
                        else:
                            self.callbacks.on_validation_batch_end(
                                training_network=net)

                        epoch_loss.update(None, preds=loss)

                    lv = loss_value(epoch_loss)
                    it.set_postfix(
                        ordered_dict={
                            "epoch":
                            f"{epoch_no + 1}/{self.epochs}",
                            ("" if is_training else "validation_") + "avg_epoch_loss":
                            lv,
                        },
                        refresh=False,
                    )
                    # print out parameters of the network at the first pass
                    if batch_no == 1 and epoch_no == 0:
                        net_name = type(net).__name__
                        num_model_param = self.count_model_params(net)
                        logger.info(
                            f"Number of parameters in {net_name}: {num_model_param}"
                        )
                it.close()

                # mark epoch end time and log time cost of current epoch
                toc = time.time()
                logger.info(
                    "Epoch[%d] Elapsed time %.3f seconds",
                    epoch_no,
                    (toc - tic),
                )

                logger.info(
                    "Epoch[%d] Evaluation metric '%s'=%f",
                    epoch_no,
                    ("" if is_training else "validation_") + "epoch_loss",
                    lv,
                )

                return epoch_loss

            self.callbacks.on_train_start(max_epochs=self.epochs)

            for epoch_no in range(self.epochs):
                if self.halt:
                    logger.info(f"Epoch[{epoch_no}] Interrupting training")
                    break

                curr_lr = trainer.learning_rate
                logger.info(f"Epoch[{epoch_no}] Learning rate is {curr_lr}")

                epoch_loss = loop(
                    epoch_no,
                    train_iter,
                    num_batches_to_use=self.num_batches_per_epoch,
                )

                should_continue = self.callbacks.on_train_epoch_end(
                    epoch_no=epoch_no,
                    epoch_loss=loss_value(epoch_loss),
                    training_network=net,
                    trainer=trainer,
                )

                if is_validation_available:
                    epoch_loss = loop(epoch_no,
                                      validation_iter,
                                      is_training=False)

                    should_continue = (should_continue and
                                       self.callbacks.on_validation_epoch_end(
                                           epoch_no=epoch_no,
                                           epoch_loss=loss_value(epoch_loss),
                                           training_network=net,
                                           trainer=trainer,
                                       ))

                # save model and epoch info
                bp = base_path()
                epoch_info = {
                    "params_path": f"{bp}-0000.params",
                    "epoch_no": epoch_no,
                    "score": loss_value(epoch_loss),
                }

                net.save_parameters(epoch_info["params_path"]
                                    )  # TODO: handle possible exception

                save_epoch_info(bp, epoch_info)

                # update best epoch info
                if loss_value(epoch_loss) < best_epoch_info["score"]:
                    best_epoch_info = epoch_info.copy()

                should_continue = (should_continue
                                   and self.callbacks.on_epoch_end(
                                       epoch_no=epoch_no,
                                       epoch_loss=loss_value(epoch_loss),
                                       training_network=net,
                                       trainer=trainer,
                                       best_epoch_info=best_epoch_info,
                                       ctx=self.ctx,
                                   ))

                if not should_continue:
                    logger.info("Stopping training")
                    break

            self.callbacks.on_train_end(
                training_network=net,
                temporary_dir=gluonts_temp,
                ctx=self.ctx,
            )

            logger.info("End model training")
Beispiel #16
0
    def __call__(
        self,
        net: nn.HybridBlock,
        train_iter: DataLoader,
        validation_iter: Optional[DataLoader] = None,
    ) -> None:  # TODO: we may want to return some training information here
        """
        Train a network, given an iterable over training (and optionally validation) batches.

        Every epoch makes one pass over ``train_iter`` (capped at
        ``self.num_batches_per_epoch`` batches) and, if ``validation_iter`` is
        given, one pass over it. The epoch's final loss (the validation loss
        when available, the training loss otherwise) is fed to a
        ``MetricAttentiveScheduler``; whenever the learning rate changes, the
        parameters of the best epoch seen so far are reloaded. After the last
        epoch, the configured averaging strategy (if any) is applied to ``net``.

        Parameters
        ----------
        net
            Network to be trained. This a Gluon HybridBlock, assumed to produce a tensor
            of loss values as output.
        train_iter
            An iterable over batches to be used for training. Batches are assumed to be
            dictionaries, whose values are MXNet arrays that correspond to the network
            inputs.
        validation_iter
            Similar to `train_iter` but the batches produced here are used to compute
            validation metrics.

        Raises
        ------
        GluonTSUserError
            If the learning rate changes before any epoch produced a loss
            below infinity (e.g. the first epoch gave NaN).
        """
        is_validation_available = validation_iter is not None

        with tempfile.TemporaryDirectory(
            prefix="gluonts-trainer-temp-"
        ) as gluonts_temp:

            def base_path() -> str:
                # Fresh, unique path prefix inside the temporary directory for
                # per-epoch parameter and epoch-info files.
                return os.path.join(
                    gluonts_temp,
                    "{}_{}".format(STATE_ARTIFACT_FILE_NAME, uuid.uuid4()),
                )

            logger.info("Start model training")

            net.initialize(ctx=self.ctx, init=self.init)

            with HybridContext(
                net=net,
                hybridize=self.hybridize,
                static_alloc=True,
                static_shape=True,
            ):
                # `np.inf` instead of the `np.Inf` alias removed in NumPy 2.0.
                # The "init" params path is not written by this method; it is
                # a placeholder until the first finite-loss epoch is recorded.
                best_epoch_info = {
                    "params_path": "%s-%s.params" % (base_path(), "init"),
                    "epoch_no": -1,
                    "score": np.inf,
                }

                lr_scheduler = lrs.MetricAttentiveScheduler(
                    objective="min",
                    patience=self.patience,
                    decay_factor=self.learning_rate_decay_factor,
                    min_lr=self.minimum_learning_rate,
                )

                optimizer = mx.optimizer.Adam(
                    learning_rate=self.learning_rate,
                    lr_scheduler=lr_scheduler,
                    wd=self.weight_decay,
                    clip_gradient=self.clip_gradient,
                )

                trainer = mx.gluon.Trainer(
                    net.collect_params(),
                    optimizer=optimizer,
                    kvstore="device",  # FIXME: initialize properly
                )

                uses_iteration_averaging = isinstance(
                    self.avg_strategy, IterationAveragingStrategy
                )

                first_forward = True

                def loop(
                    epoch_no,
                    batch_iter,
                    num_batches_to_use: Optional[int] = None,
                    is_training: bool = True,
                ) -> mx.metric.Loss:
                    """
                    Run one pass over ``batch_iter`` and return the accumulated loss.

                    At most ``num_batches_to_use`` batches are consumed (all of
                    them when ``None``). In training mode gradients are applied
                    via ``trainer``; batches whose summed loss is not finite are
                    skipped and excluded from the returned metric.
                    """
                    nonlocal first_forward
                    tic = time.time()

                    epoch_loss = mx.metric.Loss()

                    # Validation uses the averaged model. The current weights
                    # are restored in the `finally` clause so that an exception
                    # mid-validation cannot leave the averaged weights in place.
                    use_averaged_model = (
                        not is_training and uses_iteration_averaging
                    )
                    if use_averaged_model:
                        self.avg_strategy.load_averaged_model(net)

                    try:
                        batch_iter = itertools.islice(
                            batch_iter, num_batches_to_use
                        )

                        with tqdm(batch_iter, total=num_batches_to_use) as it:
                            for batch_no, batch in enumerate(it, start=1):
                                # `batch` is expected to be a dictionary whose
                                # fields correspond 1-to-1 with the network
                                # inputs (see how `batch.values()` is fed in).
                                if first_forward:
                                    # One forward pass before the first update
                                    # (presumably to complete Gluon's deferred
                                    # parameter initialization), then the
                                    # post-initialization hook, if any.
                                    first_forward = False
                                    _ = net(*batch.values())
                                    if self.post_initialize_cb:
                                        self.post_initialize_cb(net)

                                with mx.autograd.record():
                                    # Set the mode explicitly: MXNet assumes
                                    # predict mode by default, so dropout layers
                                    # would be inactive during training.
                                    mode = (
                                        autograd.train_mode
                                        if is_training
                                        else autograd.predict_mode
                                    )
                                    with mode():
                                        output = net(*batch.values())

                                    # The network may return several outputs
                                    # (a list when hybridized, a tuple
                                    # otherwise); the loss always comes first.
                                    if isinstance(output, (list, tuple)):
                                        loss = output[0]
                                    else:
                                        loss = output

                                    batch_size = loss.shape[0]

                                if not np.isfinite(
                                    ndarray.sum(loss).asscalar()
                                ):
                                    logger.warning(
                                        "Batch [%d] of Epoch[%d] gave NaN loss and it will be ignored",
                                        batch_no,
                                        epoch_no,
                                    )
                                else:
                                    if is_training:
                                        loss.backward()
                                        trainer.step(batch_size)

                                        # iteration averaging in training
                                        if uses_iteration_averaging:
                                            self.avg_strategy.apply(net)

                                    epoch_loss.update(None, preds=loss)

                                it.set_postfix(
                                    ordered_dict={
                                        "epoch": f"{epoch_no + 1}/{self.epochs}",
                                        ("" if is_training else "validation_")
                                        + "avg_epoch_loss": loss_value(
                                            epoch_loss
                                        ),
                                    },
                                    refresh=False,
                                )
                                # print out parameters of the network at the
                                # first pass
                                if batch_no == 1 and epoch_no == 0:
                                    net_name = type(net).__name__
                                    num_model_param = self.count_model_params(
                                        net
                                    )
                                    logger.info(
                                        f"Number of parameters in {net_name}: {num_model_param}"
                                    )
                    finally:
                        if use_averaged_model:
                            # bring back the cached model
                            self.avg_strategy.load_cached_model(net)

                    # Computed after the loop so it is defined even when no
                    # batch was processed (previously an UnboundLocalError).
                    lv = loss_value(epoch_loss)

                    # mark epoch end time and log time cost of current epoch
                    toc = time.time()
                    logger.info(
                        "Epoch[%d] Elapsed time %.3f seconds",
                        epoch_no,
                        (toc - tic),
                    )

                    logger.info(
                        "Epoch[%d] Evaluation metric '%s'=%f",
                        epoch_no,
                        ("" if is_training else "validation_") + "epoch_loss",
                        lv,
                    )

                    return epoch_loss

                for epoch_no in range(self.epochs):
                    curr_lr = trainer.learning_rate
                    logger.info(
                        f"Epoch[{epoch_no}] Learning rate is {curr_lr}"
                    )

                    epoch_loss = loop(
                        epoch_no,
                        train_iter,
                        num_batches_to_use=self.num_batches_per_epoch,
                    )
                    if is_validation_available:
                        # The validation loss supersedes the training loss as
                        # the metric used for scheduling and model selection.
                        epoch_loss = loop(
                            epoch_no, validation_iter, is_training=False
                        )
                    epoch_score = loss_value(epoch_loss)

                    # update average trigger
                    if uses_iteration_averaging:
                        self.avg_strategy.update_average_trigger(
                            metric=epoch_score, epoch=epoch_no + 1
                        )
                        # once triggered, update the average immediately
                        self.avg_strategy.apply(net)

                    # The scheduler is stepped every epoch (it is also the
                    # optimizer's lr_scheduler), even when its early-stopping
                    # verdict is overridden below.
                    should_continue = lr_scheduler.step(epoch_score)
                    if uses_iteration_averaging:
                        logger.info(
                            "Overriding early stopping for iteration-based averaging strategies."
                        )
                        should_continue = True
                    if not should_continue:
                        logger.info("Stopping training")
                        break

                    # save model and epoch info
                    bp = base_path()
                    epoch_info = {
                        "params_path": f"{bp}-0000.params",
                        "epoch_no": epoch_no,
                        "score": epoch_score,
                    }

                    net.save_parameters(
                        epoch_info["params_path"]
                    )  # TODO: handle possible exception

                    save_epoch_info(bp, epoch_info)

                    # update best epoch info - needed for the learning rate
                    # scheduler. A NaN score never compares below the current
                    # best, so NaN epochs are never selected.
                    if epoch_score < best_epoch_info["score"]:
                        best_epoch_info = epoch_info.copy()

                    # If the scheduler changed the learning rate, continue from
                    # the best parameters seen so far.
                    if trainer.learning_rate != curr_lr:
                        if best_epoch_info["epoch_no"] == -1:
                            raise GluonTSUserError(
                                "Got NaN in first epoch. Try reducing initial learning rate."
                            )

                        logger.info(
                            f"Loading parameters from best epoch "
                            f"({best_epoch_info['epoch_no']})"
                        )
                        net.load_parameters(
                            best_epoch_info["params_path"], self.ctx
                        )

                if isinstance(self.avg_strategy, AveragingStrategy):
                    logger.info("Computing averaged parameters.")
                    averaged_params_path = self.avg_strategy.apply(
                        gluonts_temp
                    )

                    logger.info("Loading averaged parameters.")
                    net.load_parameters(averaged_params_path, self.ctx)

                if uses_iteration_averaging:
                    logger.info("Loading averaged parameters.")
                    self.avg_strategy.load_averaged_model(net)

                logger.info("End model training")