def train_net(batch_size,
              num_epochs,
              start_epoch1,
              train_data,
              val_data,
              batch_fn,
              data_source_needs_reset,
              dtype,
              net,
              teacher_net,
              discrim_net,
              trainer,
              lr_scheduler,
              lp_saver,
              log_interval,
              mixup,
              mixup_epoch_tail,
              label_smoothing,
              num_classes,
              grad_clip_value,
              batch_size_scale,
              val_metric,
              train_metric,
              loss_metrics,
              loss_func,
              discrim_loss_func,
              ctx):
    """
    Main procedure for training model.

    Parameters:
    ----------
    batch_size : int
        Training batch size.
    num_epochs : int
        Number of training epochs.
    start_epoch1 : int
        Number of starting epoch (1-based).
    train_data : DataLoader or ImageRecordIter
        Data loader or ImRec-iterator (training subset).
    val_data : DataLoader or ImageRecordIter
        Data loader or ImRec-iterator (validation subset).
    batch_fn : func
        Function for splitting data after extraction from data loader.
    data_source_needs_reset : bool
        Whether to reset data (if test_data is ImageRecordIter).
    dtype : str
        Base data type for tensors.
    net : HybridBlock
        Model.
    teacher_net : HybridBlock or None
        Teacher model.
    discrim_net : HybridBlock or None
        MEALv2 discriminator model.
    trainer : Trainer
        Trainer.
    lr_scheduler : LRScheduler
        Learning rate scheduler.
    lp_saver : TrainLogParamSaver
        Model/trainer state saver.
    log_interval : int
        Batch count period for logging.
    mixup : bool
        Whether to use mixup.
    mixup_epoch_tail : int
        Number of epochs without mixup at the end of training.
    label_smoothing : bool
        Whether to use label-smoothing.
    num_classes : int
        Number of model classes.
    grad_clip_value : float
        Threshold for gradient clipping.
    batch_size_scale : int
        Manual batch-size increasing factor.
    val_metric : EvalMetric
        Metric object instance (validation subset).
    train_metric : EvalMetric
        Metric object instance (training subset).
    loss_metrics : list of EvalMetric
        Metric object instances (loss values).
    loss_func : Loss
        Loss object instance.
    discrim_loss_func : Loss or None
        MEALv2 adversarial loss function.
    ctx : Context
        MXNet context.
    """
    if batch_size_scale != 1:
        # Gradient accumulation: sum gradients over several sub-batches
        # before a single optimizer step.
        for p in net.collect_params().values():
            p.grad_req = "add"

    # Normalize a single context into a list so downstream code can
    # always iterate over devices.
    if isinstance(ctx, mx.Context):
        ctx = [ctx]

    assert isinstance(start_epoch1, int)
    assert start_epoch1 >= 1
    if start_epoch1 > 1:
        # Resuming: report the metric of the restored checkpoint first.
        logging.info("Start training from [Epoch {}]".format(start_epoch1))
        validate(
            metric=val_metric,
            net=net,
            val_data=val_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)
        val_accuracy_msg = report_accuracy(metric=val_metric)
        logging.info("[Epoch {}] validation: {}".format(start_epoch1 - 1, val_accuracy_msg))

    gtic = time.time()
    for epoch in range(start_epoch1 - 1, num_epochs):
        train_epoch(
            epoch=epoch,
            net=net,
            teacher_net=teacher_net,
            discrim_net=discrim_net,
            train_metric=train_metric,
            loss_metrics=loss_metrics,
            train_data=train_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx,
            loss_func=loss_func,
            discrim_loss_func=discrim_loss_func,
            trainer=trainer,
            lr_scheduler=lr_scheduler,
            batch_size=batch_size,
            log_interval=log_interval,
            mixup=mixup,
            mixup_epoch_tail=mixup_epoch_tail,
            label_smoothing=label_smoothing,
            num_classes=num_classes,
            num_epochs=num_epochs,
            grad_clip_value=grad_clip_value,
            batch_size_scale=batch_size_scale)

        validate(
            metric=val_metric,
            net=net,
            val_data=val_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)
        val_accuracy_msg = report_accuracy(metric=val_metric)
        logging.info("[Epoch {}] validation: {}".format(epoch + 1, val_accuracy_msg))

        if lp_saver is not None:
            lp_saver_kwargs = {"net": net, "trainer": trainer}
            # EvalMetric.get() may return a scalar (single metric) or a list
            # (composite metric); coerce both to a list for concatenation.
            val_acc_values = val_metric.get()[1]
            train_acc_values = train_metric.get()[1]
            if not isinstance(val_acc_values, list):
                val_acc_values = [val_acc_values]
            if not isinstance(train_acc_values, list):
                train_acc_values = [train_acc_values]
            lp_saver.epoch_test_end_callback(
                epoch1=(epoch + 1),
                params=(val_acc_values + train_acc_values +
                        [loss_metrics[0].get()[1], trainer.learning_rate]),
                **lp_saver_kwargs)

    logging.info("Total time cost: {:.2f} sec".format(time.time() - gtic))
    if lp_saver is not None:
        opt_metric_name = get_metric_name(val_metric, lp_saver.acc_ind)
        logging.info("Best {}: {:.4f} at {} epoch".format(
            opt_metric_name,
            lp_saver.best_eval_metric_value,
            lp_saver.best_eval_metric_epoch))
def train_net(batch_size,
              num_epochs,
              start_epoch1,
              train_data,
              val_data,
              batch_fn,
              data_source_needs_reset,
              dtype,
              net,
              trainer,
              lr_scheduler,
              lp_saver,
              log_interval,
              mixup,
              mixup_epoch_tail,
              label_smoothing,
              num_classes,
              grad_clip_value,
              batch_size_scale,
              val_metric,
              train_metric,
              ctx):
    """
    Main procedure for training model.

    Parameters:
    ----------
    batch_size : int
        Training batch size.
    num_epochs : int
        Number of training epochs.
    start_epoch1 : int
        Number of starting epoch (1-based).
    train_data : DataLoader or ImageRecordIter
        Data loader or ImRec-iterator (training subset).
    val_data : DataLoader or ImageRecordIter
        Data loader or ImRec-iterator (validation subset).
    batch_fn : func
        Function for splitting data after extraction from data loader.
    data_source_needs_reset : bool
        Whether to reset data (if test_data is ImageRecordIter).
    dtype : str
        Base data type for tensors.
    net : HybridBlock
        Model.
    trainer : Trainer
        Trainer.
    lr_scheduler : LRScheduler
        Learning rate scheduler.
    lp_saver : TrainLogParamSaver
        Model/trainer state saver.
    log_interval : int
        Batch count period for logging.
    mixup : bool
        Whether to use mixup.
    mixup_epoch_tail : int
        Number of epochs without mixup at the end of training.
    label_smoothing : bool
        Whether to use label-smoothing.
    num_classes : int
        Number of model classes.
    grad_clip_value : float
        Threshold for gradient clipping.
    batch_size_scale : int
        Manual batch-size increasing factor.
    val_metric : EvalMetric
        Metric object instance (validation subset).
    train_metric : EvalMetric
        Metric object instance (training subset).
    ctx : Context
        MXNet context.
    """
    if batch_size_scale != 1:
        # Gradient accumulation: sum gradients over several sub-batches
        # before a single optimizer step.
        for p in net.collect_params().values():
            p.grad_req = "add"

    # Normalize a single context into a list so downstream code can
    # always iterate over devices.
    if isinstance(ctx, mx.Context):
        ctx = [ctx]

    # Mixup and label smoothing both feed dense (soft) labels to the loss.
    loss_func = gluon.loss.SoftmaxCrossEntropyLoss(
        sparse_label=(not (mixup or label_smoothing)))

    assert isinstance(start_epoch1, int)
    assert start_epoch1 >= 1
    if start_epoch1 > 1:
        # Resuming: report the metric of the restored checkpoint first.
        logging.info("Start training from [Epoch {}]".format(start_epoch1))
        validate(
            metric=val_metric,
            net=net,
            val_data=val_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)
        val_accuracy_msg = report_accuracy(metric=val_metric)
        logging.info("[Epoch {}] validation: {}".format(start_epoch1 - 1, val_accuracy_msg))

    gtic = time.time()
    for epoch in range(start_epoch1 - 1, num_epochs):
        train_loss = train_epoch(
            epoch=epoch,
            net=net,
            train_metric=train_metric,
            train_data=train_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx,
            loss_func=loss_func,
            trainer=trainer,
            lr_scheduler=lr_scheduler,
            batch_size=batch_size,
            log_interval=log_interval,
            mixup=mixup,
            mixup_epoch_tail=mixup_epoch_tail,
            label_smoothing=label_smoothing,
            num_classes=num_classes,
            num_epochs=num_epochs,
            grad_clip_value=grad_clip_value,
            batch_size_scale=batch_size_scale)

        validate(
            metric=val_metric,
            net=net,
            val_data=val_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)
        val_accuracy_msg = report_accuracy(metric=val_metric)
        logging.info("[Epoch {}] validation: {}".format(epoch + 1, val_accuracy_msg))

        if lp_saver is not None:
            lp_saver_kwargs = {"net": net, "trainer": trainer}
            # EvalMetric.get() may return a scalar (single metric) or a list
            # (composite metric); coerce both to a list for concatenation.
            val_acc_values = val_metric.get()[1]
            train_acc_values = train_metric.get()[1]
            if not isinstance(val_acc_values, list):
                val_acc_values = [val_acc_values]
            if not isinstance(train_acc_values, list):
                train_acc_values = [train_acc_values]
            lp_saver.epoch_test_end_callback(
                epoch1=(epoch + 1),
                params=(val_acc_values + train_acc_values +
                        [train_loss, trainer.learning_rate]),
                **lp_saver_kwargs)

    logging.info("Total time cost: {:.2f} sec".format(time.time() - gtic))
    if lp_saver is not None:
        opt_metric_name = get_metric_name(val_metric, lp_saver.acc_ind)
        logging.info("Best {}: {:.4f} at {} epoch".format(
            opt_metric_name,
            lp_saver.best_eval_metric_value,
            lp_saver.best_eval_metric_epoch))