Пример #1
0
def get_estimator(level,
                  num_augment,
                  epochs=24,
                  batch_size=512,
                  max_train_steps_per_epoch=None,
                  max_eval_steps_per_epoch=None):
    """Assemble the data pipeline, network and estimator for an augmentation run.

    Args:
        level: Value handed to every augmentation op as its ``level``
            argument. Must satisfy ``0 <= level <= 10``.
        num_augment: Number of ``OneOf`` augmentation stages placed at the
            front of the pipeline.
        epochs: Forwarded to ``fe.Estimator``.
        batch_size: Forwarded to ``fe.Pipeline``.
        max_train_steps_per_epoch: Forwarded to ``fe.Estimator``.
        max_eval_steps_per_epoch: Forwarded to ``fe.Estimator``.

    Returns:
        The configured ``fe.Estimator`` instance.

    Raises:
        ValueError: If ``level`` is outside the closed range [0, 10].
    """
    # Explicit raise instead of `assert`: assertions are removed when Python
    # runs with -O, which would let an invalid level through unnoticed.
    if not 0 <= level <= 10:
        raise ValueError("the level should be between 0 and 10")

    train_data, test_data = load_data()

    # Candidate ops for each augmentation stage; all of them are constructed
    # with the same keyword arguments and are active in "train" mode only.
    candidate_ops = (
        Rotate,
        Identity,
        AutoContrast,
        Equalize,
        Posterize,
        Solarize,
        Sharpness,
        Contrast,
        Color,
        Brightness,
        ShearX,
        ShearY,
        TranslateX,
        TranslateY,
    )
    # Fresh op instances are created for every stage so that no op object is
    # shared between stages.
    # NOTE(review): OneOf presumably selects one candidate per sample --
    # confirm against the FastEstimator documentation.
    aug_ops = [
        OneOf(*[
            op_cls(level=level, inputs="x", outputs="x", mode="train")
            for op_cls in candidate_ops
        ]) for _ in range(num_augment)
    ]

    # Augmentations run first, followed by normalization and channel transpose.
    # NOTE(review): mean/std look like per-channel dataset statistics --
    # verify they match the dataset returned by load_data().
    pipeline = fe.Pipeline(train_data=train_data,
                           test_data=test_data,
                           batch_size=batch_size,
                           ops=aug_ops + [
                               Normalize(inputs="x",
                                         outputs="x",
                                         mean=(0.4914, 0.4822, 0.4465),
                                         std=(0.2471, 0.2435, 0.2616)),
                               ChannelTranspose(inputs="x", outputs="x"),
                           ])

    model = fe.build(model_fn=MyModel, optimizer_fn="adam")
    # Forward pass -> cross-entropy loss stored under "ce" -> model update
    # driven by that loss.
    network = fe.Network(ops=[
        ModelOp(model=model, inputs="x", outputs="y_pred"),
        CrossEntropy(inputs=("y_pred", "y"), outputs="ce"),
        UpdateOp(model=model, loss_name="ce")
    ])

    estimator = fe.Estimator(
        pipeline=pipeline,
        network=network,
        epochs=epochs,
        traces=Accuracy(true_key="y", pred_key="y_pred"),
        max_train_steps_per_epoch=max_train_steps_per_epoch,
        max_eval_steps_per_epoch=max_eval_steps_per_epoch)
    return estimator