def get_estimator(level,
                  num_augment,
                  epochs=24,
                  batch_size=512,
                  max_train_steps_per_epoch=None,
                  max_eval_steps_per_epoch=None):
    """Build a FastEstimator ``Estimator`` with a stacked random-augmentation pipeline.

    Args:
        level: Magnitude forwarded to every augmentation op. Must satisfy
            ``0 <= level <= 10``.
        num_augment: Number of ``OneOf`` augmentation stages appended to the
            training pipeline. A value ``<= 0`` adds no augmentation stages.
        epochs: Number of training epochs, forwarded to ``fe.Estimator``.
        batch_size: Batch size, forwarded to ``fe.Pipeline``.
        max_train_steps_per_epoch: Forwarded unchanged to ``fe.Estimator``.
        max_eval_steps_per_epoch: Forwarded unchanged to ``fe.Estimator``.

    Returns:
        The configured ``fe.Estimator``.

    Raises:
        ValueError: If ``level`` is outside ``[0, 10]``.
    """
    # Explicit check rather than ``assert``: assertions are stripped under
    # ``python -O``. Without this check, an invalid level would reach every op.
    if not 0 <= level <= 10:
        raise ValueError("the level should be between 0 and 10")

    train_data, test_data = load_data()

    # Candidate ops for each stage. Every op shares the same level/keys/mode.
    # NOTE(review): presumably OneOf picks one of these at random per sample;
    # confirm against the OneOf implementation.
    candidate_ops = (
        Rotate,
        Identity,
        AutoContrast,
        Equalize,
        Posterize,
        Solarize,
        Sharpness,
        Contrast,
        Color,
        Brightness,
        ShearX,
        ShearY,
        TranslateX,
        TranslateY,
    )
    # Each stage gets its own freshly constructed op instances.
    aug_ops = [
        OneOf(*[op_cls(level=level, inputs="x", outputs="x", mode="train") for op_cls in candidate_ops])
        for _ in range(num_augment)
    ]

    pipeline = fe.Pipeline(
        train_data=train_data,
        test_data=test_data,
        batch_size=batch_size,
        ops=aug_ops + [
            # NOTE(review): these mean/std constants look like per-channel
            # CIFAR-10 statistics; confirm they match load_data().
            Normalize(inputs="x", outputs="x", mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616)),
            # NOTE(review): presumably converts HWC images to CHW for the model.
            ChannelTranspose(inputs="x", outputs="x"),
        ])

    model = fe.build(model_fn=MyModel, optimizer_fn="adam")
    network = fe.Network(ops=[
        ModelOp(model=model, inputs="x", outputs="y_pred"),
        CrossEntropy(inputs=("y_pred", "y"), outputs="ce"),
        UpdateOp(model=model, loss_name="ce"),
    ])
    estimator = fe.Estimator(
        pipeline=pipeline,
        network=network,
        epochs=epochs,
        traces=Accuracy(true_key="y", pred_key="y_pred"),
        max_train_steps_per_epoch=max_train_steps_per_epoch,
        max_eval_steps_per_epoch=max_eval_steps_per_epoch)
    return estimator