def get_config(model):
    nr_tower = max(get_num_gpu(), 1)
    batch = args.batch // nr_tower

    logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch))

    callbacks = [ThroughputTracker(args.batch)]
    if args.fake:
        data = QueueInput(FakeData(
            [[batch, 224, 224, 3], [batch]], 1000, random=False, dtype='uint8'))
    else:
        data = QueueInput(
            get_imagenet_dataflow(args.data, 'train', batch),
            # use a larger queue
            queue=tf.FIFOQueue(200, [tf.uint8, tf.int32], [[batch, 224, 224, 3], [batch]])
        )

        BASE_LR = 30
        # linear LR scaling: scale the base LR by the total batch size relative to 256
        SCALED_LR = BASE_LR * (args.batch / 256.0)
        callbacks.extend([
            ModelSaver(),
            EstimatedTimeLeft(),
            ScheduledHyperParamSetter(
                'learning_rate', [
                    (0, SCALED_LR),
                    (60, SCALED_LR * 1e-1),
                    (70, SCALED_LR * 1e-2),
                    (80, SCALED_LR * 1e-3),
                    (90, SCALED_LR * 1e-4),
                ]),
        ])

        dataset_val = get_imagenet_dataflow(args.data, 'val', 64)
        infs = [ClassificationError('wrong-top1', 'val-error-top1'),
                ClassificationError('wrong-top5', 'val-error-top5')]
        if nr_tower == 1:
            callbacks.append(InferenceRunner(QueueInput(dataset_val), infs))
        else:
            callbacks.append(DataParallelInferenceRunner(
                dataset_val, infs, list(range(nr_tower))))

    if args.load.endswith(".npz"):
        # a released model in npz format
        init = SmartInit(args.load)
    else:
        # a pre-trained checkpoint
        init = SaverRestore(args.load, ignore=("learning_rate", "global_step"))
    return TrainConfig(
        model=model,
        data=data,
        callbacks=callbacks,
        steps_per_epoch=100 if args.fake else 1281167 // args.batch,
        session_init=init,
        max_epoch=100,
    )
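A minimal driver sketch for this config. The `Model` class and the exact trainer choice are assumptions, not taken from the source script:

if __name__ == '__main__':
    # hypothetical driver; Model and args come from the surrounding script, and the
    # multi-GPU trainer choice here is an assumption
    model = Model()
    config = get_config(model)
    launch_train_with_config(config, SyncMultiGPUTrainerReplicated(max(get_num_gpu(), 1)))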
Example #2
def get_config(model):
    input_sig = model.get_input_signature()
    nr_tower = max(hvd.size(), 1)
    batch = args.batch // nr_tower
    logger.info("Running on {} towers. Batch size per tower: {}".format(
        nr_tower, batch))

    callbacks = [ThroughputTracker(args.batch), UpdateMomentumEncoder()]

    if args.fake:
        data = QueueInput(
            FakeData([x.shape for x in input_sig],
                     1000,
                     random=False,
                     dtype='uint8'))
    else:
        zmq_addr = 'ipc://@imagenet-train-b{}'.format(batch)
        data = ZMQInput(zmq_addr, 25, bind=False)

        dataset = data.to_dataset(input_sig).repeat().prefetch(15)
        dataset = dataset.apply(
            tf.data.experimental.prefetch_to_device('/gpu:0'))
        data = TFDatasetInput(dataset)

        callbacks.extend([
            ModelSaver(),
            EstimatedTimeLeft(),
        ])

        if not args.v2:
            # step-wise LR in v1
            SCALED_LR = BASE_LR * (args.batch / 256.0)
            callbacks.append(
                ScheduledHyperParamSetter('learning_rate',
                                          [(0, min(BASE_LR, SCALED_LR)),
                                           (120, SCALED_LR * 1e-1),
                                           (160, SCALED_LR * 1e-2)]))
            if SCALED_LR > BASE_LR:
                callbacks.append(
                    ScheduledHyperParamSetter('learning_rate',
                                              [(0, BASE_LR), (5, SCALED_LR)],
                                              interp='linear'))

    return TrainConfig(
        model=model,
        data=data,
        callbacks=callbacks,
        steps_per_epoch=100 if args.fake else 1281167 // args.batch,
        max_epoch=200,
    )
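Since the tower count comes from hvd.size(), this variant is presumably launched with tensorpack's HorovodTrainer under horovodrun/mpirun; a hedged sketch, with `Model` and `args` assumed from the surrounding script:

if __name__ == '__main__':
    hvd.init()       # hvd.size() inside get_config() requires horovod to be initialized
    model = Model()  # assumed: a ModelDesc subclass defined elsewhere in the script
    launch_train_with_config(get_config(model), HorovodTrainer())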
Example #3
def launch_train_with_config(config):
    """
    Train with a :class:`TrainConfig` and a :class:`Trainer`, to
    present a simple training interface. It basically does the following
    3 things (and you can easily do them by yourself if you need more control):
    1. Setup the input with automatic prefetching heuristics,
       from `config.data` or `config.dataflow`.
    2. Call `trainer.setup_graph` with the input as well as `config.model`.
    3. Call `trainer.train` with rest of the attributes of config.
    Args:
        config (TrainConfig):
        trainer (Trainer): an instance of :class:`BNNTrainer`.
    Example:
    .. code-block:: python
        launch_train_with_config(
            config, SyncMultiGPUTrainerParameterServer(8, ps_device='gpu'))
    """
    assert config.model is not None
    assert config.dataflow is not None

    model = config.model
    input = QueueInput(config.dataflow)
    trainer = BNNTrainer(input, model)

    trainer.train_with_defaults(
        callbacks=config.callbacks,
        monitors=config.monitors,
        session_creator=config.session_creator,
        session_init=config.session_init,
        steps_per_epoch=config.steps_per_epoch,
        starting_epoch=config.starting_epoch,
        max_epoch=config.max_epoch,
        extra_callbacks=config.extra_callbacks)
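A hedged usage sketch: because of the asserts above, the TrainConfig handed to this variant must set `model` and `dataflow` (not `data`). The names below are illustrative stand-ins:

config = TrainConfig(
    model=model,                # assumed: a ModelDesc subclass built elsewhere
    dataflow=train_dataflow,    # assumed: a tensorpack DataFlow
    callbacks=[ModelSaver()],
    steps_per_epoch=1000,
    max_epoch=100,
)
launch_train_with_config(config)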
Example #4
def eval_classification(model, sessinit, dataflow):

    pred_config = PredictConfig(
        model=model,
        session_init=sessinit,
        input_names=['input', 'label'],
        output_names=['wrong-top1', 'bn5/output:0', 'conv5/output:0']
    )
    acc1 = RatioCounter()


    pred = FeedfreePredictor(pred_config, StagingInput(QueueInput(dataflow), device='/gpu:0'))

    for idx in tqdm.trange(dataflow.size()):
        top1, afbn5, beforebn5 = pred()
        # collect this batch's activations after bn5 and before bn5 (conv5 output)
        dic = {}
        dic['bn5/output:0'] = afbn5
        dic['conv5/output:0'] = beforebn5

        batch_size = top1.shape[0]
        acc1.feed(top1.sum(), batch_size)

        # save this batch's activations to an .npz in the logger directory; the batch
        # index keeps files written within the same second from overwriting each other
        log_dir = logger.get_logger_dir()
        fname = os.path.join(
            log_dir, 'afbn5-{}-{}.npz'.format(idx, int(time.time())))
        np.savez(fname, **dic)


    print("Top1 Error: {}".format(acc1.ratio))
Example #5
    def pred_config(self, args, df, callbacks) -> TrainConfig:
        return TrainConfig(
            model=self.train_model(args),
            data=StagingInput(QueueInput(df)),
            callbacks=callbacks,
            max_epoch=args.epochs,
            steps_per_epoch=args.steps,
            session_init=SaverRestore(args.load) if args.load else None,
        )
Example #6
def get_config(
    files_list,
    input_names=["state_1", "state_2"],
    output_names=["Qvalue_1", "Qvalue_2"],
    agents=2,
):
    """This is only used during training."""
    expreplay = ExpReplay(
        predictor_io_names=(input_names, output_names),
        player=get_player(task="train", files_list=files_list, agents=agents),
        state_shape=IMAGE_SIZE,
        batch_size=BATCH_SIZE,
        memory_size=MEMORY_SIZE,
        init_memory_size=INIT_MEMORY_SIZE,
        init_exploration=1.0,
        update_frequency=UPDATE_FREQ,
        history_len=FRAME_HISTORY,
        agents=agents,
    )

    return TrainConfig(
        # dataflow=expreplay,
        data=QueueInput(expreplay),
        model=Model(agents=agents),
        callbacks=[
            ModelSaver(),
            PeriodicTrigger(
                RunOp(DQNModel.update_target_param, verbose=True),
                # update target network every 10k steps
                every_k_steps=10000 // UPDATE_FREQ,
            ),
            expreplay,
            ScheduledHyperParamSetter("learning_rate", [(60, 4e-4),
                                                        (100, 2e-4)]),
            ScheduledHyperParamSetter(
                ObjAttrParam(expreplay, "exploration"),
                # 1->0.1 in the first million steps
                [(0, 1), (10, 0.1), (320, 0.01)],
                interp="linear",
            ),
            PeriodicTrigger(
                Evaluator(
                    nr_eval=EVAL_EPISODE,
                    input_names=input_names,
                    output_names=output_names,
                    files_list=files_list,
                    get_player_fn=get_player,
                    agents=agents,
                ),
                every_k_epochs=EPOCHS_PER_EVAL,
            ),
            HumanHyperParamSetter("learning_rate"),
        ],
        steps_per_epoch=STEPS_PER_EPOCH,
        max_epoch=1000,
    )
Example #7
def test(net,
         session_init,
         val_dataflow,
         do_calc_flops=False,
         extended_log=False):
    """
    Main test routine.

    Parameters:
    ----------
    net : obj
        Model.
    session_init : SessionInit
        Session initializer.
    do_calc_flops : bool, default False
        Whether to calculate count of weights.
    extended_log : bool, default False
        Whether to log more precise accuracy values.
    """
    pred_config = PredictConfig(
        model=net,
        session_init=session_init,
        input_names=["input", "label"],
        output_names=["wrong-top1", "wrong-top5"]
    )
    err_top1 = RatioCounter()
    err_top5 = RatioCounter()

    tic = time.time()
    pred = FeedfreePredictor(pred_config, StagingInput(QueueInput(val_dataflow), device="/gpu:0"))

    for _ in tqdm.trange(val_dataflow.size()):
        err_top1_val, err_top5_val = pred()
        batch_size = err_top1_val.shape[0]
        err_top1.feed(err_top1_val.sum(), batch_size)
        err_top5.feed(err_top5_val.sum(), batch_size)

    err_top1_val = err_top1.ratio
    err_top5_val = err_top5.ratio

    if extended_log:
        logging.info("Test: err-top1={top1:.4f} ({top1})\terr-top5={top5:.4f} ({top5})".format(
            top1=err_top1_val, top5=err_top5_val))
    else:
        logging.info("Test: err-top1={top1:.4f}\terr-top5={top5:.4f}".format(
            top1=err_top1_val, top5=err_top5_val))
    logging.info("Time cost: {:.4f} sec".format(
        time.time() - tic))

    if do_calc_flops:
        calc_flops(model=net)
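A hedged call example; `net`, `val_dataflow`, and the checkpoint path below are illustrative stand-ins for objects built elsewhere in the script:

test(net=net,
     session_init=SaverRestore('train_log/checkpoint'),  # hypothetical checkpoint path
     val_dataflow=val_dataflow,
     do_calc_flops=False,
     extended_log=True)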
Example #8
def get_config():
    """This is only used during training."""
    expreplay = ExpReplay(predictor_io_names=(['state'], ['Qvalue']),
                          player=get_player(directory=data_dir,
                                            task='train',
                                            files_list=train_data_fpaths),
                          state_shape=OBSERVATION_DIMS,
                          batch_size=BATCH_SIZE,
                          memory_size=MEMORY_SIZE,
                          init_memory_size=INIT_MEMORY_SIZE,
                          init_exploration=1.0,
                          update_frequency=UPDATE_FREQ,
                          frame_history_len=FRAME_HISTORY)

    return TrainConfig(
        # dataflow=expreplay,
        data=QueueInput(expreplay),
        model=Model(),
        callbacks=[  # TODO: periodically save videos
            ModelSaver(checkpoint_dir="model_checkpoints",
                       keep_checkpoint_every_n_hours=0.25,
                       max_to_keep=1000),  # TODO: originally was just ModelSaver()
            PeriodicTrigger(
                RunOp(DQNModel.update_target_param, verbose=True),
                # update target network every 10k/freq steps
                every_k_steps=10000 // UPDATE_FREQ),
            # expreplay,
            ScheduledHyperParamSetter('learning_rate', [(60, 4e-4),
                                                        (100, 2e-4)]),
            ScheduledHyperParamSetter(
                ObjAttrParam(expreplay, 'exploration'),
                # 1->0.1 in the first 10M steps
                [(0, 1), (100, 0.1), (120, 0.01)],
                interp='linear'),
            PeriodicTrigger(  # runs expreplay._trigger()
                expreplay, every_k_steps=5000),
            PeriodicTrigger(
                # eval_model_multithread(pred, EVAL_EPISODE, get_player)
                Evaluator(nr_eval=EVAL_EPISODE,
                          input_names=['state'],
                          output_names=['Qvalue'],
                          directory=data_dir,
                          files_list=test_data_fpaths,
                          get_player_fn=get_player),
                every_k_steps=10000 // UPDATE_FREQ),
            HumanHyperParamSetter('learning_rate'),
        ],
        steps_per_epoch=STEPS_PER_EPOCH,
        max_epoch=NUM_EPOCHS,
    )
Example #9
def train_net(net,
              session_init,
              batch_size,
              num_epochs,
              train_dataflow,
              val_dataflow):

    num_towers = max(get_num_gpu(), 1)
    batch_per_tower = batch_size // num_towers
    logger.info("Running on {} towers. Batch size per tower: {}".format(num_towers, batch_per_tower))

    num_training_samples = 1281167
    step_size = num_training_samples // batch_size
    max_iter = (num_epochs - 1) * step_size
    callbacks = [
        ModelSaver(),
        ScheduledHyperParamSetter(
            'learning_rate',
            [(0, 0.5), (max_iter, 0)],
            interp='linear',
            step_based=True),
        EstimatedTimeLeft()]

    infs = [ClassificationError('wrong-top1', 'val-error-top1'),
            ClassificationError('wrong-top5', 'val-error-top5')]
    if num_towers == 1:
        # single-GPU inference with queue prefetch
        callbacks.append(InferenceRunner(
            input=QueueInput(val_dataflow),
            infs=infs))
    else:
        # multi-GPU inference (with mandatory queue prefetch)
        callbacks.append(DataParallelInferenceRunner(
            input=val_dataflow,
            infs=infs,
            gpus=list(range(num_towers))))

    config = TrainConfig(
        dataflow=train_dataflow,
        model=net,
        callbacks=callbacks,
        session_init=session_init,
        steps_per_epoch=step_size,
        max_epoch=num_epochs)

    launch_train_with_config(
        config=config,
        trainer=SyncMultiGPUTrainerParameterServer(num_towers))
Example #10
def get_config():
    """This is only used during training."""
    expreplay = ExpReplay(predictor_io_names=(['state'], ['Qvalue']),
                          player=get_player(directory=data_dir,
                                            task='train',
                                            files_list=train_list),
                          state_shape=IMAGE_SIZE,
                          batch_size=BATCH_SIZE,
                          memory_size=MEMORY_SIZE,
                          init_memory_size=INIT_MEMORY_SIZE,
                          init_exploration=1.0,
                          update_frequency=UPDATE_FREQ,
                          history_len=FRAME_HISTORY)

    return TrainConfig(
        # dataflow=expreplay,
        data=QueueInput(expreplay),
        model=Model(),
        callbacks=[
            ModelSaver(),
            PeriodicTrigger(
                RunOp(DQNModel.update_target_param, verbose=True),
                # update target network every 10k steps
                every_k_steps=10000 // UPDATE_FREQ),
            expreplay,
            ScheduledHyperParamSetter('learning_rate', [(60, 4e-4),
                                                        (100, 2e-4)]),
            ScheduledHyperParamSetter(
                ObjAttrParam(expreplay, 'exploration'),
                # 1->0.1 in the first million steps
                [(0, 1), (10, 0.1), (320, 0.01)],
                interp='linear'),
            PeriodicTrigger(Evaluator(nr_eval=EVAL_EPISODE,
                                      input_names=['state'],
                                      output_names=['Qvalue'],
                                      directory=data_dir,
                                      files_list=test_list,
                                      get_player_fn=get_player),
                            every_k_epochs=EPOCHS_PER_EVAL),
            HumanHyperParamSetter('learning_rate'),
        ],
        steps_per_epoch=STEPS_PER_EPOCH,
        max_epoch=1000,
    )
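Following the tensorpack DQN examples, a config like this is typically launched with SimpleTrainer; a hedged sketch in which the --load handling is an assumption:

config = get_config()
if args.load:
    config.session_init = SmartInit(args.load)  # assumed: optional checkpoint/npz to resume from
launch_train_with_config(config, SimpleTrainer())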
Example #11
def test(net,
         session_init,
         val_dataflow,
         do_calc_flops=False,
         extended_log=False):

    pred_config = PredictConfig(
        model=net,
        session_init=session_init,
        input_names=['input', 'label'],
        output_names=['wrong-top1', 'wrong-top5']
    )
    err_top1 = RatioCounter()
    err_top5 = RatioCounter()

    tic = time.time()
    pred = FeedfreePredictor(pred_config, StagingInput(QueueInput(val_dataflow), device='/gpu:0'))

    # import tensorflow as tf
    # summ_writer = tf.summary.FileWriter("/home/semery/projects/imgclsmob_data/gl-squeezenet_v1_1/", pred._sess.graph)

    for _ in tqdm.trange(val_dataflow.size()):
        err_top1_val, err_top5_val = pred()
        batch_size = err_top1_val.shape[0]
        err_top1.feed(err_top1_val.sum(), batch_size)
        err_top5.feed(err_top5_val.sum(), batch_size)
        # print("err_top1_val={}".format(err_top1_val.sum() / batch_size))
        # print("err_top5_val={}".format(err_top5_val.sum() / batch_size))

    err_top1_val = err_top1.ratio
    err_top5_val = err_top5.ratio

    if extended_log:
        logging.info('Test: err-top1={top1:.4f} ({top1})\terr-top5={top5:.4f} ({top5})'.format(
            top1=err_top1_val, top5=err_top5_val))
    else:
        logging.info('Test: err-top1={top1:.4f}\terr-top5={top5:.4f}'.format(
            top1=err_top1_val, top5=err_top5_val))
    logging.info('Time cost: {:.4f} sec'.format(
        time.time() - tic))

    if do_calc_flops:
        calc_flops(model=net)
Example #12
def eval_on_ILSVRC12(model, sessinit, dataflow):
    pred_config = PredictConfig(model=model,
                                session_init=sessinit,
                                input_names=['input', 'label'],
                                output_names=['wrong-top1', 'wrong-top5'])
    acc1, acc5 = RatioCounter(), RatioCounter()

    # This does not have a visible improvement over naive predictor,
    # but will have an improvement if image_dtype is set to float32.
    pred = FeedfreePredictor(
        pred_config, StagingInput(QueueInput(dataflow), device='/gpu:0'))
    for _ in tqdm.trange(dataflow.size()):
        top1, top5 = pred()
        batch_size = top1.shape[0]
        acc1.feed(top1.sum(), batch_size)
        acc5.feed(top5.sum(), batch_size)

    print("Top1 Error: {}".format(acc1.ratio))
    print("Top5 Error: {}".format(acc5.ratio))
Example #13
def test(net,
         session_init,
         val_dataflow,
         do_calc_flops=False,
         extended_log=False):

    pred_config = PredictConfig(
        model=net,
        session_init=session_init,
        input_names=["input", "label"],
        output_names=["wrong-top1", "wrong-top5"]
    )
    err_top1 = RatioCounter()
    err_top5 = RatioCounter()

    tic = time.time()
    pred = FeedfreePredictor(pred_config, StagingInput(QueueInput(val_dataflow), device="/gpu:0"))

    for _ in tqdm.trange(val_dataflow.size()):
        err_top1_val, err_top5_val = pred()
        batch_size = err_top1_val.shape[0]
        err_top1.feed(err_top1_val.sum(), batch_size)
        err_top5.feed(err_top5_val.sum(), batch_size)

    err_top1_val = err_top1.ratio
    err_top5_val = err_top5.ratio

    if extended_log:
        logging.info("Test: err-top1={top1:.4f} ({top1})\terr-top5={top5:.4f} ({top5})".format(
            top1=err_top1_val, top5=err_top5_val))
    else:
        logging.info("Test: err-top1={top1:.4f}\terr-top5={top5:.4f}".format(
            top1=err_top1_val, top5=err_top5_val))
    logging.info("Time cost: {:.4f} sec".format(
        time.time() - tic))

    if do_calc_flops:
        calc_flops(model=net)
Example #14
def eval_classification(model, sessinit, dataflow):
    """
    Eval a classification model on the dataset. It assumes the model inputs are
    named "input" and "label", and contains "wrong-top1" and "wrong-top5" in the graph.
    """
    pred_config = PredictConfig(model=model,
                                session_init=sessinit,
                                input_names=['input', 'label'],
                                output_names=['wrong-top1', 'wrong-top5'])
    acc1, acc5 = RatioCounter(), RatioCounter()

    # This does not have a visible improvement over naive predictor,
    # but will have an improvement if image_dtype is set to float32.
    pred = FeedfreePredictor(
        pred_config, StagingInput(QueueInput(dataflow), device='/gpu:0'))
    for _ in tqdm.trange(dataflow.size()):
        top1, top5 = pred()
        batch_size = top1.shape[0]
        acc1.feed(top1.sum(), batch_size)
        acc5.feed(top5.sum(), batch_size)

    print("Top1 Error: {}".format(acc1.ratio))
    print("Top5 Error: {}".format(acc5.ratio))
Example #15
def get_config(files_list, data_type, trainable_variables):
    """This is only used during training."""
    expreplay = ExpReplay(
        predictor_io_names=(['state'], ['Qvalue']),
        player=get_player(task='train',
                          files_list=files_list,
                          data_type=data_type),
        state_shape=IMAGE_SIZE,
        batch_size=BATCH_SIZE,
        memory_size=MEMORY_SIZE,
        init_memory_size=INIT_MEMORY_SIZE,
        init_exploration=0.8,  #0.0
        ###############################################################################
        # HITL UPDATE
        update_frequency=INIT_UPDATE_FREQ,
        ###############################################################################
        history_len=FRAME_HISTORY,
        arg_type=data_type)

    return TrainConfig(
        # dataflow=expreplay,
        data=QueueInput(expreplay),
        model=Model(IMAGE_SIZE, FRAME_HISTORY, METHOD, NUM_ACTIONS, GAMMA,
                    trainable_variables),
        callbacks=[
            ModelSaver(),
            PeriodicTrigger(
                RunOp(DQNModel.update_target_param, verbose=True),
                # update target network every 10k steps
                every_k_steps=10000 // UPDATE_FREQ),
            expreplay,
            ScheduledHyperParamSetter('learning_rate', [(60, 4e-4),
                                                        (100, 2e-4)]),
            ScheduledHyperParamSetter(
                ObjAttrParam(expreplay, 'exploration'),
                # 0.8 -> 0.1 over the first million steps
                [(0, 0.8), (1000000, 0.1), (32000000, 0.01)],
                interp='linear',
                step_based=True),
            ###############################################################################
            # HITL UPDATE
            # The update frequency starts at INIT_UPDATE_FREQ (0) during the pretraining
            # phase and is raised to UPDATE_FREQ (4) after NUM_PRETRAIN steps, so the
            # agent then takes 4 environment steps between successive TD updates.
            ScheduledHyperParamSetter(ObjAttrParam(expreplay,
                                                   'update_frequency'),
                                      [(0, INIT_UPDATE_FREQ),
                                       (NUM_PRETRAIN, UPDATE_FREQ)],
                                      interp=None,
                                      step_based=True),

            ###############################################################################
            PeriodicTrigger(Evaluator(nr_eval=EVAL_EPISODE,
                                      input_names=['state'],
                                      output_names=['Qvalue'],
                                      files_list=files_list,
                                      data_type=data_type,
                                      get_player_fn=get_player),
                            every_k_steps=STEPS_PER_EVAL),
            HumanHyperParamSetter('learning_rate'),
        ],
        steps_per_epoch=STEPS_PER_EPOCH,
        max_epoch=MAX_EPOCHS,
    )
Example #16
    return train, test


if __name__ == '__main__':
    logger.auto_set_dir()
    M = keras.models.Sequential()
    M.add(KL.Conv2D(32, 3, activation='relu', input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], padding='same'))
    M.add(KL.MaxPooling2D())
    M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
    M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
    M.add(KL.MaxPooling2D())
    M.add(KL.Conv2D(32, 3, padding='same', activation='relu'))
    M.add(KL.Flatten())
    M.add(KL.Dense(512, activation='relu', kernel_regularizer=keras.regularizers.l2(1e-5)))
    M.add(KL.Dropout(0.5))
    M.add(KL.Dense(10, activation=None, kernel_regularizer=keras.regularizers.l2(1e-5)))
    M.add(KL.Activation('softmax'))

    dataset_train, dataset_test = get_data()

    M = KerasModel(M, QueueInput(dataset_train))
    M.compile(
        optimizer=tf.train.AdamOptimizer(1e-3),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    M.fit(
        validation_data=dataset_test,
        steps_per_epoch=dataset_train.size(),
    )
Example #17
    M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
    M.add(KL.MaxPooling2D())
    M.add(KL.Conv2D(32, 3, padding='same', activation='relu'))
    M.add(KL.Flatten())
    M.add(
        KL.Dense(512,
                 activation='relu',
                 kernel_regularizer=regularizers.l2(1e-5)))
    M.add(KL.Dropout(0.5))
    M.add(
        KL.Dense(10, activation=None,
                 kernel_regularizer=regularizers.l2(1e-5)))
    M.add(KL.Activation('softmax'))

    trainer = SimpleTrainer()

    setup_keras_trainer(trainer,
                        model=M,
                        input=QueueInput(dataset_train),
                        optimizer=tf.train.AdamOptimizer(1e-3),
                        loss='categorical_crossentropy',
                        metrics=['accuracy'])
    trainer.train_with_defaults(
        callbacks=[
            ModelSaver(),
            InferenceRunner(dataset_test,
                            [ScalarStats(['total_loss', 'accuracy'])]),
        ],
        steps_per_epoch=dataset_train.size(),
    )