def on_epoch_end(
    self,
    epoch_no: int,
    epoch_loss: float,
    training_network: nn.HybridBlock,
    trainer: gluon.Trainer,
    best_epoch_info: Dict[str, Any],
    ctx: mx.Context,
) -> bool:
    should_continue = self.lr_scheduler.step(metric_value=epoch_loss)
    if not should_continue:
        print(
            "Early stopping based on learning rate scheduler callback (min_lr was reached)."
        )
        return False

    pre_step_learning_rate = trainer.learning_rate
    trainer.optimizer.set_learning_rate(
        self.lr_scheduler(trainer.optimizer.num_update)
    )

    if trainer.learning_rate != pre_step_learning_rate:
        if best_epoch_info["epoch_no"] == -1:
            raise GluonTSUserError(
                "Got NaN in first epoch. Try reducing initial learning rate."
            )

        logger.info(
            f"Loading parameters from best epoch "
            f"({best_epoch_info['epoch_no']})"
        )
        training_network.load_parameters(
            best_epoch_info["params_path"], ctx
        )

    return True
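
# Illustrative sketch only: the hook above assumes that
# `self.lr_scheduler.step(metric_value=...)` returns False once the learning
# rate has been decayed down to `min_lr`, and that the scheduler is callable
# with `num_update` to obtain the current rate. A minimal patience-based
# scheduler with that contract could look roughly like the class below; the
# name `SimplePatienceScheduler` and its attributes are assumptions made for
# illustration, not the library's `MetricAttentiveScheduler`.
class SimplePatienceScheduler:
    def __init__(self, base_lr, decay_factor, patience, min_lr):
        self.lr = base_lr
        self.decay_factor = decay_factor  # multiplicative decay, e.g. 0.5
        self.patience = patience          # epochs tolerated without improvement
        self.min_lr = min_lr
        self.best = float("inf")
        self.bad_epochs = 0

    def __call__(self, num_update: int) -> float:
        # MXNet LRScheduler protocol: return the rate for the current update.
        return self.lr

    def step(self, metric_value: float) -> bool:
        # Track the best metric; decay the rate after `patience` bad epochs.
        if metric_value < self.best:
            self.best = metric_value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        if self.bad_epochs > self.patience:
            if self.lr <= self.min_lr:
                return False  # min_lr already reached: ask the caller to stop
            self.lr = max(self.lr * self.decay_factor, self.min_lr)
            self.bad_epochs = 0
        return True
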
def on_train_end(
    self,
    training_network: nn.HybridBlock,
    temporary_dir: str,
    ctx: mx.context.Context = None,
) -> None:
    logging.info("Computing averaged parameters.")
    averaged_params_path = self.avg_strategy.apply(temporary_dir)

    logging.info("Loading averaged parameters.")
    training_network.load_parameters(averaged_params_path, ctx)
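
# Illustrative sketch only: the two hooks above are meant to be driven from a
# training loop. The driver below is a hypothetical stand-in (its argument
# names and the `train_one_epoch` callable are assumptions for illustration),
# showing where `on_epoch_end` and `on_train_end` would be invoked.
def run_with_callback(
    callback, net, trainer, ctx, epochs, train_one_epoch, best_epoch_info, temporary_dir
):
    for epoch_no in range(epochs):
        epoch_loss = train_one_epoch(epoch_no)  # assumed to return a float
        if not callback.on_epoch_end(
            epoch_no, epoch_loss, net, trainer, best_epoch_info, ctx
        ):
            break  # the callback requested early stopping
    callback.on_train_end(net, temporary_dir, ctx)
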
def __call__(
    self,
    net: nn.HybridBlock,
    input_names: List[str],
    train_iter: TrainDataLoader,
) -> None:  # TODO: we may want to return some training information here
    self.halt = False

    with tempfile.TemporaryDirectory(
        prefix="gluonts-trainer-temp-"
    ) as gluonts_temp:

        def base_path() -> str:
            return os.path.join(
                gluonts_temp,
                "{}_{}".format(STATE_ARTIFACT_FILE_NAME, uuid.uuid4()),
            )

        logging.info("Start model training")

        net.initialize(ctx=self.ctx, init=self.init)

        with HybridContext(
            net=net,
            hybridize=self.hybridize,
            static_alloc=True,
            static_shape=True,
        ):
            batch_size = train_iter.batch_size
            epoch_loss = mx.metric.Loss()

            best_epoch_info = BestEpochInfo(
                params_path="%s-%s.params" % (base_path(), "init"),
                epoch_no=-1,
                metric_value=np.Inf,
            )

            lr_scheduler = lrs.MetricAttentiveScheduler(
                objective="min",
                patience=self.patience,
                decay_factor=self.learning_rate_decay_factor,
                min_lr=self.minimum_learning_rate,
            )

            optimizer = mx.optimizer.Adam(
                learning_rate=self.learning_rate,
                lr_scheduler=lr_scheduler,
                wd=self.weight_decay,
                clip_gradient=self.clip_gradient,
            )

            trainer = mx.gluon.Trainer(
                net.collect_params(),
                optimizer=optimizer,
                kvstore="device",  # FIXME: initialize properly
            )

            for epoch_no in range(self.epochs):
                if self.halt:
                    logging.info(f"Epoch[{epoch_no}] Interrupting training")
                    break

                curr_lr = trainer.learning_rate
                logging.info(f"Epoch[{epoch_no}] Learning rate is {curr_lr}")

                # mark epoch start time
                tic = time.time()

                epoch_loss.reset()

                with tqdm(train_iter) as it:
                    for batch_no, data_entry in enumerate(it, start=1):
                        if self.halt:
                            break

                        inputs = [data_entry[k] for k in input_names]

                        with mx.autograd.record():
                            output = net(*inputs)

                            # the network can return several outputs, the
                            # first one always being the loss; with multiple
                            # outputs, the forward pass returns a list in the
                            # hybridized case and a tuple otherwise
                            # we may wrap network outputs in the future to
                            # avoid this type check
                            if isinstance(output, (list, tuple)):
                                loss = output[0]
                            else:
                                loss = output

                        loss.backward()
                        trainer.step(batch_size)

                        epoch_loss.update(None, preds=loss)
                        it.set_postfix(
                            ordered_dict={
                                "avg_epoch_loss": loss_value(epoch_loss)
                            },
                            refresh=False,
                        )

                        # print out parameters of the network at the first pass
                        if batch_no == 1 and epoch_no == 0:
                            net_name = type(net).__name__
                            num_model_param = self.count_model_params(net)
                            logging.info(
                                f"Number of parameters in {net_name}: {num_model_param}"
                            )

                # mark epoch end time and log time cost of current epoch
                toc = time.time()
                logging.info(
                    "Epoch[%d] Elapsed time %.3f seconds",
                    epoch_no,
                    (toc - tic),
                )

                # check and log epoch loss
                check_loss_finite(loss_value(epoch_loss))
                logging.info(
                    "Epoch[%d] Evaluation metric '%s'=%f",
                    epoch_no,
                    "epoch_loss",
                    loss_value(epoch_loss),
                )

                lr_scheduler.step(loss_value(epoch_loss))

                if loss_value(epoch_loss) < best_epoch_info.metric_value:
                    best_epoch_info = BestEpochInfo(
                        params_path="%s-%04d.params" % (base_path(), epoch_no),
                        epoch_no=epoch_no,
                        metric_value=loss_value(epoch_loss),
                    )
                    net.save_parameters(
                        best_epoch_info.params_path
                    )  # TODO: handle possible exception

                if trainer.learning_rate != curr_lr:
                    logging.info(
                        f"Loading parameters from best epoch "
                        f"({best_epoch_info.epoch_no})"
                    )
                    net.load_parameters(
                        best_epoch_info.params_path, self.ctx
                    )

            logging.info(
                f"Loading parameters from best epoch "
                f"({best_epoch_info.epoch_no})"
            )
            net.load_parameters(best_epoch_info.params_path, self.ctx)

            logging.info(
                f"Final loss: {best_epoch_info.metric_value} "
                f"(occurred at epoch {best_epoch_info.epoch_no})"
            )

            # save net parameters
            net.save_parameters(best_epoch_info.params_path)

    logging.getLogger().info("End model training")
def __call__(
    self,
    net: nn.HybridBlock,
    input_names: List[str],
    train_iter: TrainDataLoader,
    validation_iter: Optional[ValidationDataLoader] = None,
) -> None:  # TODO: we may want to return some training information here
    is_validation_available = validation_iter is not None
    self.halt = False

    with tempfile.TemporaryDirectory(
        prefix="gluonts-trainer-temp-"
    ) as gluonts_temp:

        def base_path() -> str:
            return os.path.join(
                gluonts_temp,
                "{}_{}".format(STATE_ARTIFACT_FILE_NAME, uuid.uuid4()),
            )

        logger.info("Start model training")

        net.initialize(ctx=self.ctx, init=self.init)

        with HybridContext(
            net=net,
            hybridize=self.hybridize,
            static_alloc=True,
            static_shape=True,
        ):
            batch_size = train_iter.batch_size

            best_epoch_info = {
                "params_path": "%s-%s.params" % (base_path(), "init"),
                "epoch_no": -1,
                "score": np.Inf,
            }

            lr_scheduler = lrs.MetricAttentiveScheduler(
                objective="min",
                patience=self.patience,
                decay_factor=self.learning_rate_decay_factor,
                min_lr=self.minimum_learning_rate,
            )

            optimizer = mx.optimizer.Adam(
                learning_rate=self.learning_rate,
                lr_scheduler=lr_scheduler,
                wd=self.weight_decay,
                clip_gradient=self.clip_gradient,
            )

            trainer = mx.gluon.Trainer(
                net.collect_params(),
                optimizer=optimizer,
                kvstore="device",  # FIXME: initialize properly
            )

            first_forward = True

            def loop(
                epoch_no, batch_iter, is_training: bool = True
            ) -> mx.metric.Loss:
                nonlocal first_forward
                tic = time.time()

                epoch_loss = mx.metric.Loss()

                # use averaged model for validation
                if not is_training and isinstance(
                    self.avg_strategy, IterationAveragingStrategy
                ):
                    self.avg_strategy.load_averaged_model(net)

                with tqdm(batch_iter) as it:
                    for batch_no, data_entry in enumerate(it, start=1):
                        if self.halt:
                            break

                        inputs = [data_entry[k] for k in input_names]

                        if first_forward:
                            first_forward = False
                            _ = net(*inputs)
                            if self.post_initialize_cb:
                                self.post_initialize_cb(net)

                        with mx.autograd.record():
                            output = net(*inputs)

                            # the network can return several outputs, the
                            # first one always being the loss; with multiple
                            # outputs, the forward pass returns a list in the
                            # hybridized case and a tuple otherwise
                            # we may wrap network outputs in the future to
                            # avoid this type check
                            if isinstance(output, (list, tuple)):
                                loss = output[0]
                            else:
                                loss = output

                        if is_training:
                            loss.backward()
                            trainer.step(batch_size)

                            # iteration averaging in training
                            if isinstance(
                                self.avg_strategy,
                                IterationAveragingStrategy,
                            ):
                                self.avg_strategy.apply(net)

                        epoch_loss.update(None, preds=loss)
                        lv = loss_value(epoch_loss)

                        if not np.isfinite(lv):
                            logger.warning(
                                "Epoch[%d] gave nan loss", epoch_no
                            )
                            return epoch_loss

                        it.set_postfix(
                            ordered_dict={
                                "epoch": f"{epoch_no + 1}/{self.epochs}",
                                ("" if is_training else "validation_")
                                + "avg_epoch_loss": lv,
                            },
                            refresh=False,
                        )

                        # print out parameters of the network at the first pass
                        if batch_no == 1 and epoch_no == 0:
                            net_name = type(net).__name__
                            num_model_param = self.count_model_params(net)
                            logger.info(
                                f"Number of parameters in {net_name}: {num_model_param}"
                            )

                # mark epoch end time and log time cost of current epoch
                toc = time.time()
                logger.info(
                    "Epoch[%d] Elapsed time %.3f seconds",
                    epoch_no,
                    (toc - tic),
                )

                logger.info(
                    "Epoch[%d] Evaluation metric '%s'=%f",
                    epoch_no,
                    ("" if is_training else "validation_") + "epoch_loss",
                    lv,
                )

                if not is_training and isinstance(
                    self.avg_strategy, IterationAveragingStrategy
                ):
                    # bring back the cached model
                    self.avg_strategy.load_cached_model(net)

                return epoch_loss

            for epoch_no in range(self.epochs):
                if self.halt:
                    logger.info(f"Epoch[{epoch_no}] Interrupting training")
                    break

                curr_lr = trainer.learning_rate
                logger.info(f"Epoch[{epoch_no}] Learning rate is {curr_lr}")

                epoch_loss = loop(epoch_no, train_iter)
                if is_validation_available:
                    epoch_loss = loop(
                        epoch_no, validation_iter, is_training=False
                    )

                # update average trigger
                if isinstance(self.avg_strategy, IterationAveragingStrategy):
                    self.avg_strategy.update_average_trigger(
                        metric=loss_value(epoch_loss), epoch=epoch_no + 1
                    )
                    # once triggered, update the average immediately
                    self.avg_strategy.apply(net)

                should_continue = lr_scheduler.step(loss_value(epoch_loss))

                if isinstance(self.avg_strategy, IterationAveragingStrategy):
                    logging.info(
                        "Overriding early stopping for iteration-based averaging strategies."
                    )
                    should_continue = True

                if not should_continue:
                    logger.info("Stopping training")
                    break

                # save model and epoch info
                bp = base_path()
                epoch_info = {
                    "params_path": f"{bp}-0000.params",
                    "epoch_no": epoch_no,
                    "score": loss_value(epoch_loss),
                }

                net.save_parameters(
                    epoch_info["params_path"]
                )  # TODO: handle possible exception

                save_epoch_info(bp, epoch_info)

                # update best epoch info - needed for the learning rate scheduler
                if loss_value(epoch_loss) < best_epoch_info["score"]:
                    best_epoch_info = epoch_info.copy()

                if trainer.learning_rate != curr_lr:
                    if best_epoch_info["epoch_no"] == -1:
                        raise GluonTSUserError(
                            "Got NaN in first epoch. Try reducing initial learning rate."
                        )

                    logger.info(
                        f"Loading parameters from best epoch "
                        f"({best_epoch_info['epoch_no']})"
                    )
                    net.load_parameters(
                        best_epoch_info["params_path"], self.ctx
                    )

            if isinstance(self.avg_strategy, AveragingStrategy):
                logging.info("Computing averaged parameters.")
                averaged_params_path = self.avg_strategy.apply(gluonts_temp)

                logging.info("Loading averaged parameters.")
                net.load_parameters(averaged_params_path, self.ctx)

            if isinstance(self.avg_strategy, IterationAveragingStrategy):
                logging.info("Loading averaged parameters.")
                self.avg_strategy.load_averaged_model(net)

    logger.info("End model training")
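
# Illustrative sketch only: `self.avg_strategy.apply(gluonts_temp)` above is
# expected to write a parameter file averaging several checkpoints saved in
# that directory and to return its path. A naive uniform average over a list
# of `.params` files could be computed roughly as below; the helper name and
# the uniform weighting are assumptions, not the library's actual strategy.
import mxnet as mx

def average_checkpoints(param_paths, output_path):
    # Each file saved via `net.save_parameters` loads as a dict name -> NDArray.
    checkpoints = [mx.nd.load(p) for p in param_paths]
    averaged = {}
    for name in checkpoints[0]:
        acc = checkpoints[0][name].copy()
        for ckpt in checkpoints[1:]:
            acc = acc + ckpt[name]
        averaged[name] = acc / len(checkpoints)
    mx.nd.save(output_path, averaged)
    return output_path
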
def __call__(
    self,
    net: nn.HybridBlock,
    train_iter: DataLoader,
    validation_iter: Optional[DataLoader] = None,
) -> None:  # TODO: we may want to return some training information here
    """
    Train a network, given an iterable over training (and optionally
    validation) batches.

    Parameters
    ----------
    net
        Network to be trained. This is a Gluon HybridBlock, assumed to
        produce a tensor of loss values as output.
    train_iter
        An iterable over batches to be used for training. Batches are
        assumed to be dictionaries, whose values are MXNet arrays that
        correspond to the network inputs.
    validation_iter
        Similar to `train_iter` but the batches produced here are used to
        compute validation metrics.
    """
    is_validation_available = validation_iter is not None

    with tempfile.TemporaryDirectory(
        prefix="gluonts-trainer-temp-"
    ) as gluonts_temp:

        def base_path() -> str:
            return os.path.join(
                gluonts_temp,
                "{}_{}".format(STATE_ARTIFACT_FILE_NAME, uuid.uuid4()),
            )

        logger.info("Start model training")

        net.initialize(ctx=self.ctx, init=self.init)

        with HybridContext(
            net=net,
            hybridize=self.hybridize,
            static_alloc=True,
            static_shape=True,
        ):
            best_epoch_info = {
                "params_path": "%s-%s.params" % (base_path(), "init"),
                "epoch_no": -1,
                "score": np.Inf,
            }

            lr_scheduler = lrs.MetricAttentiveScheduler(
                objective="min",
                patience=self.patience,
                decay_factor=self.learning_rate_decay_factor,
                min_lr=self.minimum_learning_rate,
            )

            optimizer = mx.optimizer.Adam(
                learning_rate=self.learning_rate,
                lr_scheduler=lr_scheduler,
                wd=self.weight_decay,
                clip_gradient=self.clip_gradient,
            )

            trainer = mx.gluon.Trainer(
                net.collect_params(),
                optimizer=optimizer,
                kvstore="device",  # FIXME: initialize properly
            )

            first_forward = True

            def loop(
                epoch_no,
                batch_iter,
                num_batches_to_use: Optional[int] = None,
                is_training: bool = True,
            ) -> mx.metric.Loss:
                nonlocal first_forward
                tic = time.time()

                epoch_loss = mx.metric.Loss()

                # use averaged model for validation
                if not is_training and isinstance(
                    self.avg_strategy, IterationAveragingStrategy
                ):
                    self.avg_strategy.load_averaged_model(net)

                batch_iter = itertools.islice(batch_iter, num_batches_to_use)

                with tqdm(batch_iter, total=num_batches_to_use) as it:
                    for batch_no, batch in enumerate(it, start=1):
                        # `batch` here is expected to be a dictionary whose
                        # fields correspond 1-to-1 with the network inputs;
                        # see below how `batch.values()` is fed into the
                        # network
                        if first_forward:
                            first_forward = False
                            _ = net(*batch.values())
                            if self.post_initialize_cb:
                                self.post_initialize_cb(net)

                        with mx.autograd.record():
                            # we set the mode explicitly since, by default,
                            # mxnet assumes predict mode, in which dropout
                            # layers are disabled; training mode must be set
                            # explicitly
                            mode = (
                                autograd.train_mode
                                if is_training
                                else autograd.predict_mode
                            )
                            with mode():
                                output = net(*batch.values())

                            # the network can return several outputs, the
                            # first one always being the loss; with multiple
                            # outputs, the forward pass returns a list in the
                            # hybridized case and a tuple otherwise
                            # we may wrap network outputs in the future to
                            # avoid this type check
                            if isinstance(output, (list, tuple)):
                                loss = output[0]
                            else:
                                loss = output

                            batch_size = loss.shape[0]

                        if not np.isfinite(ndarray.sum(loss).asscalar()):
                            logger.warning(
                                "Batch [%d] of Epoch[%d] gave NaN loss and it will be ignored",
                                batch_no,
                                epoch_no,
                            )
                        else:
                            if is_training:
                                loss.backward()
                                trainer.step(batch_size)

                                # iteration averaging in training
                                if isinstance(
                                    self.avg_strategy,
                                    IterationAveragingStrategy,
                                ):
                                    self.avg_strategy.apply(net)

                            epoch_loss.update(None, preds=loss)

                        lv = loss_value(epoch_loss)
                        it.set_postfix(
                            ordered_dict={
                                "epoch": f"{epoch_no + 1}/{self.epochs}",
                                ("" if is_training else "validation_")
                                + "avg_epoch_loss": lv,
                            },
                            refresh=False,
                        )

                        # print out parameters of the network at the first pass
                        if batch_no == 1 and epoch_no == 0:
                            net_name = type(net).__name__
                            num_model_param = self.count_model_params(net)
                            logger.info(
                                f"Number of parameters in {net_name}: {num_model_param}"
                            )

                # mark epoch end time and log time cost of current epoch
                toc = time.time()
                logger.info(
                    "Epoch[%d] Elapsed time %.3f seconds",
                    epoch_no,
                    (toc - tic),
                )

                logger.info(
                    "Epoch[%d] Evaluation metric '%s'=%f",
                    epoch_no,
                    ("" if is_training else "validation_") + "epoch_loss",
                    lv,
                )

                if not is_training and isinstance(
                    self.avg_strategy, IterationAveragingStrategy
                ):
                    # bring back the cached model
                    self.avg_strategy.load_cached_model(net)

                return epoch_loss

            for epoch_no in range(self.epochs):
                curr_lr = trainer.learning_rate
                logger.info(f"Epoch[{epoch_no}] Learning rate is {curr_lr}")

                epoch_loss = loop(
                    epoch_no,
                    train_iter,
                    num_batches_to_use=self.num_batches_per_epoch,
                )
                if is_validation_available:
                    epoch_loss = loop(
                        epoch_no, validation_iter, is_training=False
                    )

                # update average trigger
                if isinstance(self.avg_strategy, IterationAveragingStrategy):
                    self.avg_strategy.update_average_trigger(
                        metric=loss_value(epoch_loss), epoch=epoch_no + 1
                    )
                    # once triggered, update the average immediately
                    self.avg_strategy.apply(net)

                should_continue = lr_scheduler.step(loss_value(epoch_loss))

                if isinstance(self.avg_strategy, IterationAveragingStrategy):
                    logging.info(
                        "Overriding early stopping for iteration-based averaging strategies."
                    )
                    should_continue = True

                if not should_continue:
                    logger.info("Stopping training")
                    break

                # save model and epoch info
                bp = base_path()
                epoch_info = {
                    "params_path": f"{bp}-0000.params",
                    "epoch_no": epoch_no,
                    "score": loss_value(epoch_loss),
                }

                net.save_parameters(
                    epoch_info["params_path"]
                )  # TODO: handle possible exception

                save_epoch_info(bp, epoch_info)

                # update best epoch info - needed for the learning rate scheduler
                if loss_value(epoch_loss) < best_epoch_info["score"]:
                    best_epoch_info = epoch_info.copy()

                if trainer.learning_rate != curr_lr:
                    if best_epoch_info["epoch_no"] == -1:
                        raise GluonTSUserError(
                            "Got NaN in first epoch. Try reducing initial learning rate."
                        )

                    logger.info(
                        f"Loading parameters from best epoch "
                        f"({best_epoch_info['epoch_no']})"
                    )
                    net.load_parameters(
                        best_epoch_info["params_path"], self.ctx
                    )

            if isinstance(self.avg_strategy, AveragingStrategy):
                logging.info("Computing averaged parameters.")
                averaged_params_path = self.avg_strategy.apply(gluonts_temp)

                logging.info("Loading averaged parameters.")
                net.load_parameters(averaged_params_path, self.ctx)

            if isinstance(self.avg_strategy, IterationAveragingStrategy):
                logging.info("Loading averaged parameters.")
                self.avg_strategy.load_averaged_model(net)

    logger.info("End model training")
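
# Illustrative usage sketch only (not part of the code above). Assuming the
# method belongs to the GluonTS MXNet `Trainer` class and that each batch is a
# dict of mx.nd arrays matching the network's inputs, training could be driven
# roughly as below; `build_network` and `build_loaders` are hypothetical
# helpers, and the import path is assumed for a recent GluonTS version.
from gluonts.mx.trainer import Trainer  # assumed import path

def run_training(build_network, build_loaders):
    net = build_network()                       # an nn.HybridBlock whose forward returns the loss
    train_loader, val_loader = build_loaders()  # iterables yielding dict batches
    trainer = Trainer(
        epochs=10,
        learning_rate=1e-3,
        patience=5,
        num_batches_per_epoch=50,
    )
    # validation_iter is optional; omit it to train without validation metrics.
    trainer(net=net, train_iter=train_loader, validation_iter=val_loader)
    return net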