def test(net, test_data, batch_fn, data_source_needs_reset, metric, dtype, ctx,
         input_image_size, in_channels, calc_weight_count=False, calc_flops=False,
         calc_flops_only=True, extended_log=False):
    """
    Main test routine: log accuracy and model statistics (no return value).
    """
    if not calc_flops_only:
        tic = time.time()
        validate(metric=metric, net=net, val_data=test_data, batch_fn=batch_fn,
                 data_source_needs_reset=data_source_needs_reset, dtype=dtype, ctx=ctx)
        accuracy_msg = report_accuracy(metric=metric, extended_log=extended_log)
        logging.info("Test: {}".format(accuracy_msg))
        logging.info("Time cost: {:.4f} sec".format(time.time() - tic))

    if calc_weight_count:
        weight_count = calc_net_weight_count(net)
        if not calc_flops:
            logging.info("Model: {} trainable parameters".format(weight_count))
    if calc_flops:
        num_flops, num_macs, num_params = measure_model(net, in_channels, input_image_size, ctx[0])
        assert (not calc_weight_count) or (weight_count == num_params)
        stat_msg = "Params: {params} ({params_m:.2f}M), FLOPs: {flops} ({flops_m:.2f}M)," \
                   " FLOPs/2: {flops2} ({flops2_m:.2f}M), MACs: {macs} ({macs_m:.2f}M)"
        logging.info(stat_msg.format(
            params=num_params, params_m=num_params / 1e6,
            flops=num_flops, flops_m=num_flops / 1e6,
            flops2=num_flops / 2, flops2_m=num_flops / 2 / 1e6,
            macs=num_macs, macs_m=num_macs / 1e6))
def train_net(batch_size, num_epochs, start_epoch1, train_data, val_data, batch_fn,
              data_source_needs_reset, dtype, net, teacher_net, discrim_net, trainer,
              lr_scheduler, lp_saver, log_interval, mixup, mixup_epoch_tail, label_smoothing,
              num_classes, grad_clip_value, batch_size_scale, val_metric, train_metric,
              loss_metrics, loss_func, discrim_loss_func, ctx):
    """
    Main procedure for training model.

    Parameters:
    ----------
    batch_size : int
        Training batch size.
    num_epochs : int
        Number of training epochs.
    start_epoch1 : int
        Number of starting epoch (1-based).
    train_data : DataLoader or ImageRecordIter
        Data loader or ImRec-iterator (training subset).
    val_data : DataLoader or ImageRecordIter
        Data loader or ImRec-iterator (validation subset).
    batch_fn : func
        Function for splitting data after extraction from data loader.
    data_source_needs_reset : bool
        Whether to reset data (if train_data/val_data is ImageRecordIter).
    dtype : str
        Base data type for tensors.
    net : HybridBlock
        Model.
    teacher_net : HybridBlock or None
        Teacher model.
    discrim_net : HybridBlock or None
        MEALv2 discriminator model.
    trainer : Trainer
        Trainer.
    lr_scheduler : LRScheduler
        Learning rate scheduler.
    lp_saver : TrainLogParamSaver
        Model/trainer state saver.
    log_interval : int
        Batch count period for logging.
    mixup : bool
        Whether to use mixup.
    mixup_epoch_tail : int
        Number of epochs without mixup at the end of training.
    label_smoothing : bool
        Whether to use label-smoothing.
    num_classes : int
        Number of model classes.
    grad_clip_value : float
        Threshold for gradient clipping.
    batch_size_scale : int
        Manual batch-size increasing factor.
    val_metric : EvalMetric
        Metric object instance (validation subset).
    train_metric : EvalMetric
        Metric object instance (training subset).
    loss_metrics : list of EvalMetric
        Metric object instances (loss values).
    loss_func : Loss
        Loss object instance.
    discrim_loss_func : Loss or None
        MEALv2 adversarial loss function.
    ctx : Context
        MXNet context.
    """
    if batch_size_scale != 1:
        for p in net.collect_params().values():
            p.grad_req = "add"

    if isinstance(ctx, mx.Context):
        ctx = [ctx]

    # loss_func = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=(not (mixup or label_smoothing)))

    assert (type(start_epoch1) == int)
    assert (start_epoch1 >= 1)
    if start_epoch1 > 1:
        logging.info("Start training from [Epoch {}]".format(start_epoch1))
        validate(metric=val_metric, net=net, val_data=val_data, batch_fn=batch_fn,
                 data_source_needs_reset=data_source_needs_reset, dtype=dtype, ctx=ctx)
        val_accuracy_msg = report_accuracy(metric=val_metric)
        logging.info("[Epoch {}] validation: {}".format(start_epoch1 - 1, val_accuracy_msg))

    gtic = time.time()
    for epoch in range(start_epoch1 - 1, num_epochs):
        train_epoch(
            epoch=epoch, net=net, teacher_net=teacher_net, discrim_net=discrim_net,
            train_metric=train_metric, loss_metrics=loss_metrics, train_data=train_data,
            batch_fn=batch_fn, data_source_needs_reset=data_source_needs_reset,
            dtype=dtype, ctx=ctx, loss_func=loss_func, discrim_loss_func=discrim_loss_func,
            trainer=trainer, lr_scheduler=lr_scheduler, batch_size=batch_size,
            log_interval=log_interval, mixup=mixup, mixup_epoch_tail=mixup_epoch_tail,
            label_smoothing=label_smoothing, num_classes=num_classes, num_epochs=num_epochs,
            grad_clip_value=grad_clip_value, batch_size_scale=batch_size_scale)

        validate(metric=val_metric, net=net, val_data=val_data, batch_fn=batch_fn,
                 data_source_needs_reset=data_source_needs_reset, dtype=dtype, ctx=ctx)
        val_accuracy_msg = report_accuracy(metric=val_metric)
        logging.info("[Epoch {}] validation: {}".format(epoch + 1, val_accuracy_msg))

        if lp_saver is not None:
            lp_saver_kwargs = {"net": net, "trainer": trainer}
            val_acc_values = val_metric.get()[1]
            train_acc_values = train_metric.get()[1]
            val_acc_values = val_acc_values if type(val_acc_values) == list else [val_acc_values]
            train_acc_values = train_acc_values if type(train_acc_values) == list else [train_acc_values]
            lp_saver.epoch_test_end_callback(
                epoch1=(epoch + 1),
                params=(val_acc_values + train_acc_values +
                        [loss_metrics[0].get()[1], trainer.learning_rate]),
                **lp_saver_kwargs)

    logging.info("Total time cost: {:.2f} sec".format(time.time() - gtic))
    if lp_saver is not None:
        opt_metric_name = get_metric_name(val_metric, lp_saver.acc_ind)
        logging.info("Best {}: {:.4f} at {} epoch".format(
            opt_metric_name, lp_saver.best_eval_metric_value, lp_saver.best_eval_metric_epoch))
def train_epoch(epoch, net, teacher_net, discrim_net, train_metric, loss_metrics, train_data,
                batch_fn, data_source_needs_reset, dtype, ctx, loss_func, discrim_loss_func,
                trainer, lr_scheduler, batch_size, log_interval, mixup, mixup_epoch_tail,
                label_smoothing, num_classes, num_epochs, grad_clip_value, batch_size_scale):
    """
    Train model on particular epoch.

    Parameters:
    ----------
    epoch : int
        Epoch number.
    net : HybridBlock
        Model.
    teacher_net : HybridBlock or None
        Teacher model.
    discrim_net : HybridBlock or None
        MEALv2 discriminator model.
    train_metric : EvalMetric
        Metric object instance.
    loss_metrics : list of EvalMetric
        Metric object instances (loss values).
    train_data : DataLoader or ImageRecordIter
        Data loader or ImRec-iterator.
    batch_fn : func
        Function for splitting data after extraction from data loader.
    data_source_needs_reset : bool
        Whether to reset data (if train_data is ImageRecordIter).
    dtype : str
        Base data type for tensors.
    ctx : Context
        MXNet context.
    loss_func : Loss
        Loss function.
    discrim_loss_func : Loss or None
        MEALv2 adversarial loss function.
    trainer : Trainer
        Trainer.
    lr_scheduler : LRScheduler
        Learning rate scheduler.
    batch_size : int
        Training batch size.
    log_interval : int
        Batch count period for logging.
    mixup : bool
        Whether to use mixup.
    mixup_epoch_tail : int
        Number of epochs without mixup at the end of training.
    label_smoothing : bool
        Whether to use label-smoothing.
    num_classes : int
        Number of model classes.
    num_epochs : int
        Number of training epochs.
    grad_clip_value : float
        Threshold for gradient clipping.
    batch_size_scale : int
        Manual batch-size increasing factor.
    """
    labels_list_inds = None
    batch_size_extend_count = 0
    tic = time.time()
    if data_source_needs_reset:
        train_data.reset()
    train_metric.reset()
    for m in loss_metrics:
        m.reset()

    i = 0
    btic = time.time()
    for i, batch in enumerate(train_data):
        data_list, labels_list = batch_fn(batch, ctx)

        labels_one_hot = False
        if teacher_net is not None:
            labels_list = [teacher_net(x.astype(dtype, copy=False)).softmax(axis=-1).mean(axis=1)
                           for x in data_list]
            labels_list_inds = [y.argmax(axis=-1) for y in labels_list]
            labels_one_hot = True
        if label_smoothing and (teacher_net is None):
            eta = 0.1
            on_value = 1 - eta + eta / num_classes
            off_value = eta / num_classes
            if not labels_one_hot:
                labels_list_inds = labels_list
                labels_list = [y.one_hot(depth=num_classes, on_value=on_value, off_value=off_value)
                               for y in labels_list]
                labels_one_hot = True
        if mixup:
            if not labels_one_hot:
                labels_list_inds = labels_list
                labels_list = [y.one_hot(depth=num_classes) for y in labels_list]
                labels_one_hot = True
            if epoch < num_epochs - mixup_epoch_tail:
                alpha = 1
                lam = np.random.beta(alpha, alpha)
                data_list = [lam * x + (1 - lam) * x[::-1] for x in data_list]
                labels_list = [lam * y + (1 - lam) * y[::-1] for y in labels_list]

        with ag.record():
            outputs_list = [net(x.astype(dtype, copy=False)) for x in data_list]
            loss_list = [loss_func(yhat, y.astype(dtype, copy=False))
                         for yhat, y in zip(outputs_list, labels_list)]
            if discrim_net is not None:
                d_pred_list = [discrim_net(yhat.astype(dtype, copy=False).softmax())
                               for yhat in outputs_list]
                d_label_list = [discrim_net(y.astype(dtype, copy=False)) for y in labels_list]
                d_loss_list = [discrim_loss_func(yhat, y)
                               for yhat, y in zip(d_pred_list, d_label_list)]
                loss_list = [z + dz for z, dz in zip(loss_list, d_loss_list)]
        for loss in loss_list:
            loss.backward()
        lr_scheduler.update(i, epoch)

        if grad_clip_value is not None:
            grads = [v.grad(ctx[0]) for v in net.collect_params().values() if v._grad is not None]
            gluon.utils.clip_global_norm(grads, max_norm=grad_clip_value)

        if batch_size_scale == 1:
            trainer.step(batch_size)
        else:
            if (i + 1) % batch_size_scale == 0:
                batch_size_extend_count = 0
                trainer.step(batch_size * batch_size_scale)
                for p in net.collect_params().values():
                    p.zero_grad()
            else:
                batch_size_extend_count += 1

        train_metric.update(
            labels=(labels_list if not labels_one_hot else labels_list_inds),
            preds=outputs_list)
        loss_metrics[0].update(labels=None, preds=loss_list)
        if (discrim_net is not None) and (len(loss_metrics) > 1):
            loss_metrics[1].update(labels=None, preds=d_loss_list)

        if log_interval and not (i + 1) % log_interval:
            speed = batch_size * log_interval / (time.time() - btic)
            btic = time.time()
            train_accuracy_msg = report_accuracy(metric=train_metric)
            loss_accuracy_msg = report_accuracy(metric=loss_metrics[0])
            if (discrim_net is not None) and (len(loss_metrics) > 1):
                dloss_accuracy_msg = report_accuracy(metric=loss_metrics[1])
                logging.info(
                    "Epoch[{}] Batch [{}]\tSpeed: {:.2f} samples/sec\t{}\t{}\t{}\tlr={:.5f}".format(
                        epoch + 1, i, speed, train_accuracy_msg, loss_accuracy_msg,
                        dloss_accuracy_msg, trainer.learning_rate))
            else:
                logging.info(
                    "Epoch[{}] Batch [{}]\tSpeed: {:.2f} samples/sec\t{}\t{}\tlr={:.5f}".format(
                        epoch + 1, i, speed, train_accuracy_msg, loss_accuracy_msg,
                        trainer.learning_rate))

    if (batch_size_scale != 1) and (batch_size_extend_count > 0):
        trainer.step(batch_size * batch_size_extend_count)
        for p in net.collect_params().values():
            p.zero_grad()

    throughput = int(batch_size * (i + 1) / (time.time() - tic))
    logging.info("[Epoch {}] speed: {:.2f} samples/sec\ttime cost: {:.2f} sec".format(
        epoch + 1, throughput, time.time() - tic))

    train_accuracy_msg = report_accuracy(metric=train_metric)
    loss_accuracy_msg = report_accuracy(metric=loss_metrics[0])
    if (discrim_net is not None) and (len(loss_metrics) > 1):
        dloss_accuracy_msg = report_accuracy(metric=loss_metrics[1])
        logging.info("[Epoch {}] training: {}\t{}\t{}".format(
            epoch + 1, train_accuracy_msg, loss_accuracy_msg, dloss_accuracy_msg))
    else:
        logging.info("[Epoch {}] training: {}\t{}".format(
            epoch + 1, train_accuracy_msg, loss_accuracy_msg))
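
# Hedged sketch (assumption-labeled, not from the original code) of how the teacher targets
# above are built: assuming the teacher returns logits with an extra ensemble axis of shape
# (batch, n_teachers, classes), they are softmax-normalized over the class axis, averaged
# over the ensemble axis, and the argmax of the resulting soft target is kept as the label
# used by train_metric.
def _teacher_soft_targets_sketch(teacher_logits):
    """NumPy-only illustration of .softmax(axis=-1).mean(axis=1) followed by argmax."""
    import numpy as np
    e = np.exp(teacher_logits - teacher_logits.max(axis=-1, keepdims=True))
    probs = e / e.sum(axis=-1, keepdims=True)  # softmax over the class axis
    soft_targets = probs.mean(axis=1)          # average over the ensemble axis
    hard_inds = soft_targets.argmax(axis=-1)   # indices used for the accuracy metric
    return soft_targets, hard_inds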
def calc_model_accuracy(net, test_data, batch_fn, data_source_needs_reset, metric, dtype, ctx,
                        input_image_size, in_channels, calc_weight_count=False, calc_flops=False,
                        calc_flops_only=True, extended_log=False):
    """
    Main test routine.

    Parameters:
    ----------
    net : HybridBlock
        Model.
    test_data : DataLoader or ImageRecordIter
        Data loader or ImRec-iterator.
    batch_fn : func
        Function for splitting data after extraction from data loader.
    data_source_needs_reset : bool
        Whether to reset data (if test_data is ImageRecordIter).
    metric : EvalMetric
        Metric object instance.
    dtype : str
        Base data type for tensors.
    ctx : Context
        MXNet context.
    input_image_size : tuple of 2 ints
        Spatial size of the expected input image.
    in_channels : int
        Number of input channels.
    calc_weight_count : bool, default False
        Whether to calculate count of weights.
    calc_flops : bool, default False
        Whether to calculate FLOPs.
    calc_flops_only : bool, default True
        Whether to only calculate FLOPs without testing.
    extended_log : bool, default False
        Whether to log more precise accuracy values.

    Returns:
    -------
    list of floats
        Accuracy values.
    """
    if not calc_flops_only:
        tic = time.time()
        validate(metric=metric, net=net, val_data=test_data, batch_fn=batch_fn,
                 data_source_needs_reset=data_source_needs_reset, dtype=dtype, ctx=ctx)
        accuracy_msg = report_accuracy(metric=metric, extended_log=extended_log)
        logging.info("Test: {}".format(accuracy_msg))
        logging.info("Time cost: {:.4f} sec".format(time.time() - tic))
        acc_values = metric.get()[1]
        acc_values = acc_values if type(acc_values) == list else [acc_values]
    else:
        acc_values = []

    if calc_weight_count:
        weight_count = calc_net_weight_count(net)
        if not calc_flops:
            logging.info("Model: {} trainable parameters".format(weight_count))
    if calc_flops:
        num_flops, num_macs, num_params = measure_model(net, in_channels, input_image_size, ctx[0])
        assert (not calc_weight_count) or (weight_count == num_params)
        stat_msg = "Params: {params} ({params_m:.2f}M), FLOPs: {flops} ({flops_m:.2f}M)," \
                   " FLOPs/2: {flops2} ({flops2_m:.2f}M), MACs: {macs} ({macs_m:.2f}M)"
        logging.info(stat_msg.format(
            params=num_params, params_m=num_params / 1e6,
            flops=num_flops, flops_m=num_flops / 1e6,
            flops2=num_flops / 2, flops2_m=num_flops / 2 / 1e6,
            macs=num_macs, macs_m=num_macs / 1e6))

    return acc_values
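
# Hedged usage sketch for calc_model_accuracy: the metric class, image size, and the
# val_loader/batch_fn arguments below are assumptions chosen for illustration, not values
# taken from this module. With calc_flops_only=False the validation pass actually runs and
# the accuracy values are returned; calc_weight_count=True also logs the parameter count.
def _example_accuracy_call(net, val_loader, batch_fn):
    """Illustrative wrapper around calc_model_accuracy (hypothetical helper)."""
    return calc_model_accuracy(
        net=net,
        test_data=val_loader,
        batch_fn=batch_fn,
        data_source_needs_reset=False,
        metric=mx.metric.Accuracy(),
        dtype="float32",
        ctx=[mx.cpu()],
        input_image_size=(224, 224),
        in_channels=3,
        calc_weight_count=True,
        calc_flops=False,
        calc_flops_only=False)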
def train_net(batch_size, num_epochs, start_epoch1, train_data, val_data, batch_fn,
              data_source_needs_reset, dtype, net, trainer, lr_scheduler, lp_saver,
              log_interval, mixup, mixup_epoch_tail, label_smoothing, num_classes,
              grad_clip_value, batch_size_scale, val_metric, train_metric, ctx):
    """
    Main procedure for training model.
    """
    if batch_size_scale != 1:
        for p in net.collect_params().values():
            p.grad_req = "add"

    if isinstance(ctx, mx.Context):
        ctx = [ctx]

    loss_func = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=(not (mixup or label_smoothing)))

    assert (type(start_epoch1) == int)
    assert (start_epoch1 >= 1)
    if start_epoch1 > 1:
        logging.info("Start training from [Epoch {}]".format(start_epoch1))
        validate(metric=val_metric, net=net, val_data=val_data, batch_fn=batch_fn,
                 data_source_needs_reset=data_source_needs_reset, dtype=dtype, ctx=ctx)
        val_accuracy_msg = report_accuracy(metric=val_metric)
        logging.info("[Epoch {}] validation: {}".format(start_epoch1 - 1, val_accuracy_msg))

    gtic = time.time()
    for epoch in range(start_epoch1 - 1, num_epochs):
        train_loss = train_epoch(
            epoch=epoch, net=net, train_metric=train_metric, train_data=train_data,
            batch_fn=batch_fn, data_source_needs_reset=data_source_needs_reset,
            dtype=dtype, ctx=ctx, loss_func=loss_func, trainer=trainer,
            lr_scheduler=lr_scheduler, batch_size=batch_size, log_interval=log_interval,
            mixup=mixup, mixup_epoch_tail=mixup_epoch_tail, label_smoothing=label_smoothing,
            num_classes=num_classes, num_epochs=num_epochs, grad_clip_value=grad_clip_value,
            batch_size_scale=batch_size_scale)

        validate(metric=val_metric, net=net, val_data=val_data, batch_fn=batch_fn,
                 data_source_needs_reset=data_source_needs_reset, dtype=dtype, ctx=ctx)
        val_accuracy_msg = report_accuracy(metric=val_metric)
        logging.info("[Epoch {}] validation: {}".format(epoch + 1, val_accuracy_msg))

        if lp_saver is not None:
            lp_saver_kwargs = {"net": net, "trainer": trainer}
            val_acc_values = val_metric.get()[1]
            train_acc_values = train_metric.get()[1]
            val_acc_values = val_acc_values if type(val_acc_values) == list else [val_acc_values]
            train_acc_values = train_acc_values if type(train_acc_values) == list else [train_acc_values]
            lp_saver.epoch_test_end_callback(
                epoch1=(epoch + 1),
                params=(val_acc_values + train_acc_values + [train_loss, trainer.learning_rate]),
                **lp_saver_kwargs)

    logging.info("Total time cost: {:.2f} sec".format(time.time() - gtic))
    if lp_saver is not None:
        opt_metric_name = get_metric_name(val_metric, lp_saver.acc_ind)
        logging.info("Best {}: {:.4f} at {} epoch".format(
            opt_metric_name, lp_saver.best_eval_metric_value, lp_saver.best_eval_metric_epoch))
def train_epoch(epoch, net, train_metric, train_data, batch_fn, data_source_needs_reset,
                dtype, ctx, loss_func, trainer, lr_scheduler, batch_size, log_interval,
                mixup, mixup_epoch_tail, label_smoothing, num_classes, num_epochs,
                grad_clip_value, batch_size_scale):
    """
    Train model on particular epoch. Returns the average training loss.
    """
    labels_list_inds = None
    batch_size_extend_count = 0
    tic = time.time()
    if data_source_needs_reset:
        train_data.reset()
    train_metric.reset()
    train_loss = 0.0

    btic = time.time()
    for i, batch in enumerate(train_data):
        data_list, labels_list = batch_fn(batch, ctx)

        if label_smoothing:
            eta = 0.1
            on_value = 1 - eta + eta / num_classes
            off_value = eta / num_classes
            labels_list_inds = labels_list
            labels_list = [Y.one_hot(depth=num_classes, on_value=on_value, off_value=off_value)
                           for Y in labels_list]
        if mixup:
            if not label_smoothing:
                labels_list_inds = labels_list
                labels_list = [Y.one_hot(depth=num_classes) for Y in labels_list]
            if epoch < num_epochs - mixup_epoch_tail:
                alpha = 1
                lam = np.random.beta(alpha, alpha)
                data_list = [lam * X + (1 - lam) * X[::-1] for X in data_list]
                labels_list = [lam * Y + (1 - lam) * Y[::-1] for Y in labels_list]

        with ag.record():
            outputs_list = [net(X.astype(dtype, copy=False)) for X in data_list]
            loss_list = [loss_func(yhat, y.astype(dtype, copy=False))
                         for yhat, y in zip(outputs_list, labels_list)]
        for loss in loss_list:
            loss.backward()
        lr_scheduler.update(i, epoch)

        if grad_clip_value is not None:
            grads = [v.grad(ctx[0]) for v in net.collect_params().values() if v._grad is not None]
            gluon.utils.clip_global_norm(grads, max_norm=grad_clip_value)

        if batch_size_scale == 1:
            trainer.step(batch_size)
        else:
            if (i + 1) % batch_size_scale == 0:
                batch_size_extend_count = 0
                trainer.step(batch_size * batch_size_scale)
                for p in net.collect_params().values():
                    p.zero_grad()
            else:
                batch_size_extend_count += 1

        train_loss += sum([loss.mean().asscalar() for loss in loss_list]) / len(loss_list)

        train_metric.update(
            labels=(labels_list if not (mixup or label_smoothing) else labels_list_inds),
            preds=outputs_list)

        if log_interval and not (i + 1) % log_interval:
            speed = batch_size * log_interval / (time.time() - btic)
            btic = time.time()
            train_accuracy_msg = report_accuracy(metric=train_metric)
            logging.info(
                "Epoch[{}] Batch [{}]\tSpeed: {:.2f} samples/sec\t{}\tlr={:.5f}".format(
                    epoch + 1, i, speed, train_accuracy_msg, trainer.learning_rate))

    if (batch_size_scale != 1) and (batch_size_extend_count > 0):
        trainer.step(batch_size * batch_size_extend_count)
        for p in net.collect_params().values():
            p.zero_grad()

    throughput = int(batch_size * (i + 1) / (time.time() - tic))
    logging.info("[Epoch {}] speed: {:.2f} samples/sec\ttime cost: {:.2f} sec".format(
        epoch + 1, throughput, time.time() - tic))

    train_loss /= (i + 1)
    train_accuracy_msg = report_accuracy(metric=train_metric)
    logging.info("[Epoch {}] training: {}\tloss={:.4f}".format(
        epoch + 1, train_accuracy_msg, train_loss))

    return train_loss
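
# Hedged illustration (hypothetical helper, not part of the original training code) of the
# mixup step above: a single lambda is drawn from Beta(alpha, alpha) and the batch is blended
# with its own reversed copy, for both the inputs and the one-hot targets.
def _mixup_batch_sketch(x, y_one_hot, alpha=1.0):
    """NumPy-only sketch of lam * batch + (1 - lam) * reversed batch."""
    import numpy as np
    lam = np.random.beta(alpha, alpha)
    x_mixed = lam * x + (1.0 - lam) * x[::-1]
    y_mixed = lam * y_one_hot + (1.0 - lam) * y_one_hot[::-1]
    return x_mixed, y_mixed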
def train_epoch(epoch, net, train_metric, train_data, batch_fn, data_source_needs_reset,
                dtype, ctx, loss_func, trainer, lr_scheduler, batch_size, log_interval,
                mixup, mixup_epoch_tail, label_smoothing, num_classes, num_epochs,
                grad_clip_value, batch_size_scale):
    """
    Train model on particular epoch.

    Parameters:
    ----------
    epoch : int
        Epoch number.
    net : HybridBlock
        Model.
    train_metric : EvalMetric
        Metric object instance.
    train_data : DataLoader or ImageRecordIter
        Data loader or ImRec-iterator.
    batch_fn : func
        Function for splitting data after extraction from data loader.
    data_source_needs_reset : bool
        Whether to reset data (if train_data is ImageRecordIter).
    dtype : str
        Base data type for tensors.
    ctx : Context
        MXNet context.
    loss_func : Loss
        Loss function.
    trainer : Trainer
        Trainer.
    lr_scheduler : LRScheduler
        Learning rate scheduler.
    batch_size : int
        Training batch size.
    log_interval : int
        Batch count period for logging.
    mixup : bool
        Whether to use mixup.
    mixup_epoch_tail : int
        Number of epochs without mixup at the end of training.
    label_smoothing : bool
        Whether to use label-smoothing.
    num_classes : int
        Number of model classes.
    num_epochs : int
        Number of training epochs.
    grad_clip_value : float
        Threshold for gradient clipping.
    batch_size_scale : int
        Manual batch-size increasing factor.

    Returns:
    -------
    float
        Loss value.
    """
    labels_list_inds = None
    batch_size_extend_count = 0
    tic = time.time()
    if data_source_needs_reset:
        train_data.reset()
    train_metric.reset()
    train_loss = 0.0

    btic = time.time()
    for i, batch in enumerate(train_data):
        data_list, labels_list = batch_fn(batch, ctx)

        # print("--> {}".format(labels_list[0][0].asscalar()))
        # import cv2
        # img = data_list[0][0].transpose((1, 2, 0)).asnumpy().copy()[:, :, [2, 1, 0]]
        # img -= img.min()
        # img *= 255 / img.max()
        # cv2.imshow(winname="img", mat=img.astype(np.uint8))
        # cv2.waitKey()

        if label_smoothing:
            eta = 0.1
            on_value = 1 - eta + eta / num_classes
            off_value = eta / num_classes
            labels_list_inds = labels_list
            labels_list = [Y.one_hot(depth=num_classes, on_value=on_value, off_value=off_value)
                           for Y in labels_list]
        if mixup:
            if not label_smoothing:
                labels_list_inds = labels_list
                labels_list = [Y.one_hot(depth=num_classes) for Y in labels_list]
            if epoch < num_epochs - mixup_epoch_tail:
                alpha = 1
                lam = np.random.beta(alpha, alpha)
                data_list = [lam * X + (1 - lam) * X[::-1] for X in data_list]
                labels_list = [lam * Y + (1 - lam) * Y[::-1] for Y in labels_list]

        with ag.record():
            outputs_list = [net(X.astype(dtype, copy=False)) for X in data_list]
            loss_list = [loss_func(yhat, y.astype(dtype, copy=False))
                         for yhat, y in zip(outputs_list, labels_list)]
        for loss in loss_list:
            loss.backward()
        lr_scheduler.update(i, epoch)

        if grad_clip_value is not None:
            grads = [v.grad(ctx[0]) for v in net.collect_params().values() if v._grad is not None]
            gluon.utils.clip_global_norm(grads, max_norm=grad_clip_value)

        if batch_size_scale == 1:
            trainer.step(batch_size)
        else:
            if (i + 1) % batch_size_scale == 0:
                batch_size_extend_count = 0
                trainer.step(batch_size * batch_size_scale)
                for p in net.collect_params().values():
                    p.zero_grad()
            else:
                batch_size_extend_count += 1

        train_loss += sum([loss.mean().asscalar() for loss in loss_list]) / len(loss_list)

        train_metric.update(
            labels=(labels_list if not (mixup or label_smoothing) else labels_list_inds),
            preds=outputs_list)

        if log_interval and not (i + 1) % log_interval:
            speed = batch_size * log_interval / (time.time() - btic)
            btic = time.time()
            train_accuracy_msg = report_accuracy(metric=train_metric)
            logging.info(
                "Epoch[{}] Batch [{}]\tSpeed: {:.2f} samples/sec\t{}\tlr={:.5f}".format(
                    epoch + 1, i, speed, train_accuracy_msg, trainer.learning_rate))

    if (batch_size_scale != 1) and (batch_size_extend_count > 0):
        trainer.step(batch_size * batch_size_extend_count)
        for p in net.collect_params().values():
            p.zero_grad()

    throughput = int(batch_size * (i + 1) / (time.time() - tic))
    logging.info("[Epoch {}] speed: {:.2f} samples/sec\ttime cost: {:.2f} sec".format(
        epoch + 1, throughput, time.time() - tic))

    train_loss /= (i + 1)
    train_accuracy_msg = report_accuracy(metric=train_metric)
    logging.info("[Epoch {}] training: {}\tloss={:.4f}".format(
        epoch + 1, train_accuracy_msg, train_loss))

    return train_loss