Example #1
def get_estimator(epochs=10, batch_size=32, alpha=1.0, warmup=0, model_dir=tempfile.mkdtemp()):
    (x_train, y_train), (x_eval, y_eval) = tf.keras.datasets.cifar10.load_data()
    data = {"train": {"x": x_train, "y": y_train}, "eval": {"x": x_eval, "y": y_eval}}
    num_classes = 10
    pipeline = fe.Pipeline(batch_size=batch_size, data=data, ops=Minmax(inputs="x", outputs="x"))

    model = FEModel(model_def=lambda: LeNet(input_shape=x_train.shape[1:], classes=num_classes),
                    model_name="LeNet",
                    optimizer="adam")

    mixup_map = {warmup: MixUpBatch(inputs="x", outputs=["x", "lambda"], alpha=alpha, mode="train")}
    mixup_loss = {
        0: SparseCategoricalCrossentropy(y_true="y", y_pred="y_pred", mode="train"),
        warmup: MixUpLoss(KerasCrossentropy(), lam="lambda", y_true="y", y_pred="y_pred", mode="train")
    }
    network = fe.Network(ops=[
        Scheduler(mixup_map),
        ModelOp(inputs="x", model=model, outputs="y_pred"),
        Scheduler(mixup_loss),
        SparseCategoricalCrossentropy(y_true="y", y_pred="y_pred", mode="eval")
    ])

    traces = [
        Accuracy(true_key="y", pred_key="y_pred"),
        ConfusionMatrix(true_key="y", pred_key="y_pred", num_classes=num_classes),
        ModelSaver(model_name="LeNet", save_dir=model_dir, save_best=True)
    ]

    estimator = fe.Estimator(network=network, pipeline=pipeline, epochs=epochs, traces=traces)
    return estimator
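
The warmup argument swaps the training loss through a Scheduler: the plain cross-entropy keyed at 0 is active until epoch warmup, after which the MixUp pair takes over. A minimal sketch of the assumed Scheduler lookup rule (keys mark the epoch at which a value takes effect, and get_current_value returns the most recent entry at or before the queried epoch):

    # hypothetical values, purely to illustrate the lookup rule
    schedule = Scheduler({0: "plain_loss", 3: "mixup_loss"})
    assert schedule.get_current_value(2) == "plain_loss"   # before warmup
    assert schedule.get_current_value(7) == "mixup_loss"   # warmup onward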
Example #2
    def prepare(self, mode_list, distribute_strategy):
        """This function constructs the model specified in model definition and create replica of model
         for distributed training across multiple devices if there are multiple GPU available.

        Args:
            mode_list : can be either 'train' or 'eval'
            distribute_strategy : Tensorflow class that defines distribution strategy (e.g. tf.distribute.MirroredStrategy)
        """
        all_output_keys = []
        for mode in mode_list:
            signature_epoch, mode_ops = self._get_signature_epoch(mode)
            epoch_ops_map = {}
            epoch_model_map = {}
            for epoch in signature_epoch:
                epoch_ops = []
                epoch_model = []
                # generate ops for specific mode and epoch
                for op in mode_ops:
                    if isinstance(op, Scheduler):
                        scheduled_op = op.get_current_value(epoch)
                        if scheduled_op:
                            epoch_ops.append(scheduled_op)
                    else:
                        epoch_ops.append(op)
                # check the ops
                verify_ops(epoch_ops, "Network")
                # create model list
                for op in epoch_ops:
                    all_output_keys.append(op.outputs)
                    if isinstance(op, ModelOp):
                        if op.model.keras_model is None:
                            with distribute_strategy.scope() if distribute_strategy else NonContext():
                                op.model.keras_model = op.model.model_def()
                                op.model.keras_model.optimizer = op.model.optimizer
                                op.model.keras_model.loss_name = op.model.loss_name
                                op.model.keras_model.model_name = op.model.model_name
                                assert op.model.model_name not in self.model, \
                                    "duplicated model name: {}".format(op.model.model_name)
                                self.model[op.model.model_name] = op.model.keras_model
                                if op.model.loss_name not in self.all_losses:
                                    self.all_losses.append(op.model.loss_name)
                        if op.model.keras_model not in epoch_model:
                            epoch_model.append(op.model.keras_model)
                assert epoch_model, "Network has no model for epoch {}".format(epoch)
                epoch_ops_map[epoch] = epoch_ops
                epoch_model_map[epoch] = epoch_model
            self.op_schedule[mode] = Scheduler(epoch_dict=epoch_ops_map)
            self.model_schedule[mode] = Scheduler(epoch_dict=epoch_model_map)
        self.all_output_keys = set(flatten_list(all_output_keys)) - {None}
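
Once prepare() finishes, each mode owns two Schedulers keyed by signature epoch, so the per-epoch resolution at run time reduces to a pair of lookups (a sketch, assuming the same get_current_value semantics used above):

    ops_for_epoch = self.op_schedule["train"].get_current_value(epoch)
    models_for_epoch = self.model_schedule["train"].get_current_value(epoch)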
Example #3
    def prepare(self, mode_list):
        """Construct the ops, models, and default update rules needed for each epoch of each mode."""
        all_output_keys = []
        all_models = []
        for mode in mode_list:
            signature_epoch, mode_ops = self._get_signature_epoch(mode)
            epoch_ops_map = {}
            epoch_model_map = {}
            for epoch in signature_epoch:
                epoch_ops = []
                epoch_model = []
                epoch_model_update = defaultdict(lambda: False)
                # generate ops for the specific mode and epoch
                for op in mode_ops:
                    if isinstance(op, Scheduler):
                        scheduled_op = op.get_current_value(epoch)
                        if scheduled_op:
                            epoch_ops.append(scheduled_op)
                    else:
                        epoch_ops.append(op)
                # check the ops
                verify_ops(epoch_ops, "Network")
                # create the model list
                for op in epoch_ops:
                    all_output_keys.append(op.outputs)
                    if isinstance(op, ModelOp):
                        if op.model not in epoch_model:
                            epoch_model.append(op.model)
                            # touch the defaultdict so the model is registered with has_update=False
                            epoch_model_update[op.model] = epoch_model_update[op.model]
                        if op.model not in all_models:
                            all_models.append(op.model)
                    if isinstance(op, UpdateOp):
                        epoch_model_update[op.model] = True
                # in train mode, append a default UpdateOp for any model without an explicit one
                if mode == "train":
                    for model, has_update in epoch_model_update.items():
                        if not has_update:
                            epoch_ops.append(UpdateOp(model=model))
                assert epoch_model, "Network has no model for epoch {}".format(epoch)
                epoch_ops_map[epoch] = epoch_ops
                epoch_model_map[epoch] = epoch_model
            self.op_schedule[mode] = Scheduler(epoch_dict=epoch_ops_map)
            self.model_schedule[mode] = Scheduler(epoch_dict=epoch_model_map)
        self.all_output_keys = set(flatten_list(all_output_keys)) - {None}
        for model in all_models:
            assert model.model_name not in self.model, "duplicated model name: {}".format(model.model_name)
            self.model[model.model_name] = model
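
The defaultdict bookkeeping above guarantees that, in train mode, any model that appears in a ModelOp without a matching UpdateOp receives a default UpdateOp, so it is still optimized. In other words (a sketch with a hypothetical model m and loss op), these two networks behave identically:

    fe.Network(ops=[ModelOp(inputs="x", model=m, outputs="y_pred"), loss_op])
    fe.Network(ops=[ModelOp(inputs="x", model=m, outputs="y_pred"), loss_op, UpdateOp(model=m)])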
Example #4
    def _transform_dataset(self, mode):
        all_output_keys = []
        signature_epoch, mode_ops = self._get_signature_epoch(mode)
        extracted_ds = self.extracted_dataset[mode]
        state = {"mode": mode}
        dataset_map = {}
        for epoch in signature_epoch:
            epoch_ops_all = []
            forward_ops_epoch = []
            filter_ops_epoch = []
            forward_ops_between_filter = []
            # get the batch size for the epoch
            global_batch_size = self.get_global_batch_size(epoch)
            # generate ops for the specific mode and epoch
            for op in mode_ops:
                if isinstance(op, Scheduler):
                    scheduled_op = op.get_current_value(epoch)
                    if scheduled_op:
                        epoch_ops_all.append(scheduled_op)
                else:
                    epoch_ops_all.append(op)
            # check the ops
            epoch_ops_without_filter = [op for op in epoch_ops_all if not isinstance(op, TensorFilter)]
            verify_ops(epoch_ops_without_filter, "Pipeline")
            # arrange the operations according to the filter locations
            for op in epoch_ops_all:
                all_output_keys.append(op.outputs)
                if not isinstance(op, TensorFilter):
                    forward_ops_between_filter.append(op)
                else:
                    forward_ops_epoch.append(forward_ops_between_filter)
                    filter_ops_epoch.append(op)
                    forward_ops_between_filter = []
            forward_ops_epoch.append(forward_ops_between_filter)
            # execute the operations
            dataset = self._execute_ops(extracted_ds, forward_ops_epoch, filter_ops_epoch, state)
            if self.expand_dims:
                dataset = dataset.flat_map(tf.data.Dataset.from_tensor_slices)
            if self.batch:
                if self.padded_batch:
                    _ = dataset.map(self._get_padded_shape)
                    dataset = dataset.padded_batch(global_batch_size, padded_shapes=self.padded_shape)
                else:
                    dataset = dataset.batch(global_batch_size)
            dataset = dataset.prefetch(buffer_size=1)
            if fe.distribute_strategy:
                dataset = fe.distribute_strategy.experimental_distribute_dataset(dataset)
            dataset_map[epoch] = iter(dataset)
        self.dataset_schedule[mode] = Scheduler(epoch_dict=dataset_map)
        self.all_output_keys = self.all_output_keys | set(flatten_list(all_output_keys))
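
The filter-arrangement loop groups forward ops into runs separated by TensorFilters, so each filter is applied right after the run of ops that precedes it. For example:

    # with epoch_ops_all = [A, F1, B, C, F2, D], where F1 and F2 are
    # TensorFilters, the loop above produces
    #   forward_ops_epoch = [[A], [B, C], [D]]
    #   filter_ops_epoch  = [F1, F2]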
Example #5
def get_estimator(data_dir=None, save_dir=None):
    train_csv, data_path = load_data(data_dir)

    imreader = ImageReader(inputs="x", parent_path=data_path, grey_scale=True)
    writer_128 = RecordWriter(
        save_dir=os.path.join(data_path, "tfrecord_128"),
        train_data=train_csv,
        ops=[imreader,
             ResizeRecord(target_size=(128, 128), outputs="x")])
    writer_1024 = RecordWriter(
        save_dir=os.path.join(data_path, "tfrecord_1024"),
        train_data=train_csv,
        ops=[imreader,
             ResizeRecord(target_size=(1024, 1024), outputs="x")])
    # We create Schedulers for batch_size, mapping the epochs at which it changes to the new values.
    batchsize_scheduler_128 = Scheduler({0: 128, 5: 64, 15: 32, 25: 16, 35: 8, 45: 4})
    batchsize_scheduler_1024 = Scheduler({55: 4, 65: 2, 75: 1})
    # pipeline ops
    resize_scheduler_128 = Scheduler({
        0: Resize(inputs="x", size=(4, 4), outputs="x"),
        5: Resize(inputs="x", size=(8, 8), outputs="x"),
        15: Resize(inputs="x", size=(16, 16), outputs="x"),
        25: Resize(inputs="x", size=(32, 32), outputs="x"),
        35: Resize(inputs="x", size=(64, 64), outputs="x"),
        45: None
    })
    resize_scheduler_1024 = Scheduler({
        55: Resize(inputs="x", size=(256, 256), outputs="x"),
        65: Resize(inputs="x", size=(512, 512), outputs="x"),
        75: None
    })
    lowres_op = CreateLowRes(inputs="x", outputs="x_lowres")
    rescale_x = Rescale(inputs="x", outputs="x")
    rescale_lowres = Rescale(inputs="x_lowres", outputs="x_lowres")
    pipeline_128 = fe.Pipeline(
        batch_size=batchsize_scheduler_128,
        data=writer_128,
        ops=[resize_scheduler_128, lowres_op, rescale_x, rescale_lowres])
    pipeline_1024 = fe.Pipeline(
        batch_size=batchsize_scheduler_1024,
        data=writer_1024,
        ops=[resize_scheduler_1024, lowres_op, rescale_x, rescale_lowres])

    pipeline_scheduler = Scheduler({0: pipeline_128, 55: pipeline_1024})

    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001,
                                         beta_1=0.0,
                                         beta_2=0.99,
                                         epsilon=1e-8)

    fade_in_alpha = tf.Variable(initial_value=1.0,
                                dtype='float32',
                                trainable=False)

    d2, d3, d4, d5, d6, d7, d8, d9, d10 = fe.build(
        model_def=lambda: build_D(
            fade_in_alpha=fade_in_alpha, target_resolution=10, num_channels=1),
        model_name=["d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10"],
        optimizer=[optimizer] * 9,
        loss_name=["dloss"] * 9)

    g2, g3, g4, g5, g6, g7, g8, g9, g10, G = fe.build(
        model_def=lambda: build_G(
            fade_in_alpha=fade_in_alpha, target_resolution=10, num_channels=1),
        model_name=[
            "g2", "g3", "g4", "g5", "g6", "g7", "g8", "g9", "g10", "G"
        ],
        optimizer=[optimizer] * 10,
        loss_name=["gloss"] * 10)

    g_scheduler = Scheduler({
        0: ModelOp(model=g2, outputs="x_fake"),
        5: ModelOp(model=g3, outputs="x_fake"),
        15: ModelOp(model=g4, outputs="x_fake"),
        25: ModelOp(model=g5, outputs="x_fake"),
        35: ModelOp(model=g6, outputs="x_fake"),
        45: ModelOp(model=g7, outputs="x_fake"),
        55: ModelOp(model=g8, outputs="x_fake"),
        65: ModelOp(model=g9, outputs="x_fake"),
        75: ModelOp(model=g10, outputs="x_fake")
    })

    fake_score_scheduler = Scheduler({
        0: ModelOp(inputs="x_fake", model=d2, outputs="fake_score"),
        5: ModelOp(inputs="x_fake", model=d3, outputs="fake_score"),
        15: ModelOp(inputs="x_fake", model=d4, outputs="fake_score"),
        25: ModelOp(inputs="x_fake", model=d5, outputs="fake_score"),
        35: ModelOp(inputs="x_fake", model=d6, outputs="fake_score"),
        45: ModelOp(inputs="x_fake", model=d7, outputs="fake_score"),
        55: ModelOp(inputs="x_fake", model=d8, outputs="fake_score"),
        65: ModelOp(inputs="x_fake", model=d9, outputs="fake_score"),
        75: ModelOp(inputs="x_fake", model=d10, outputs="fake_score")
    })

    real_score_scheduler = Scheduler({
        0: ModelOp(model=d2, outputs="real_score"),
        5: ModelOp(model=d3, outputs="real_score"),
        15: ModelOp(model=d4, outputs="real_score"),
        25: ModelOp(model=d5, outputs="real_score"),
        35: ModelOp(model=d6, outputs="real_score"),
        45: ModelOp(model=d7, outputs="real_score"),
        55: ModelOp(model=d8, outputs="real_score"),
        65: ModelOp(model=d9, outputs="real_score"),
        75: ModelOp(model=d10, outputs="real_score")
    })

    interp_score_scheduler = Scheduler({
        0: ModelOp(inputs="x_interp", model=d2, outputs="interp_score", track_input=True),
        5: ModelOp(inputs="x_interp", model=d3, outputs="interp_score", track_input=True),
        15: ModelOp(inputs="x_interp", model=d4, outputs="interp_score", track_input=True),
        25: ModelOp(inputs="x_interp", model=d5, outputs="interp_score", track_input=True),
        35: ModelOp(inputs="x_interp", model=d6, outputs="interp_score", track_input=True),
        45: ModelOp(inputs="x_interp", model=d7, outputs="interp_score", track_input=True),
        55: ModelOp(inputs="x_interp", model=d8, outputs="interp_score", track_input=True),
        65: ModelOp(inputs="x_interp", model=d9, outputs="interp_score", track_input=True),
        75: ModelOp(inputs="x_interp", model=d10, outputs="interp_score", track_input=True)
    })

    network = fe.Network(ops=[
        RandomInput(inputs=lambda: 512),
        g_scheduler,
        fake_score_scheduler,
        ImageBlender(inputs=("x", "x_lowres"), alpha=fade_in_alpha),
        real_score_scheduler,
        Interpolate(inputs=("x_fake", "x"), outputs="x_interp"),
        interp_score_scheduler,
        GradientPenalty(inputs=("x_interp", "interp_score"), outputs="gp"),
        GLoss(inputs="fake_score", outputs="gloss"),
        DLoss(inputs=("real_score", "fake_score", "gp"), outputs="dloss")
    ])

    if save_dir is None:
        save_dir = os.path.join(str(Path.home()), 'fastestimator_results',
                                'NIH_CXR_PGGAN')
        os.makedirs(save_dir, exist_ok=True)

    estimator = fe.Estimator(
        network=network,
        pipeline=pipeline_scheduler,
        epochs=85,
        traces=[
            AlphaController(alpha=fade_in_alpha,
                            fade_start=[5, 15, 25, 35, 45, 55, 65, 75, 85],
                            duration=[5, 5, 5, 5, 5, 5, 5, 5, 5]),
            ResetOptimizer(reset_epochs=[5, 15, 25, 35, 45, 55, 65, 75],
                           optimizer=optimizer),
            ImageSaving(epoch_model={
                4: g2,
                14: g3,
                24: g4,
                34: g5,
                44: g6,
                54: g7,
                64: g8,
                74: g9,
                84: G
            },
                        save_dir=save_dir,
                        num_channels=1),
            ModelSaving(epoch_model={84: G}, save_dir=save_dir)
        ])
    return estimator
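
Reading the resize, batch-size, and pipeline schedulers together gives the progressive-growing plan this estimator encodes (derived from the code above, shown only as a reading aid):

    resolution_plan = {  # epoch at which each image size takes effect
        0: 4, 5: 8, 15: 16, 25: 32, 35: 64,
        45: 128,   # resize becomes None, so the 128x128 records pass through
        55: 256, 65: 512,
        75: 1024,  # resize becomes None on the 1024x1024 records
    }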
Example #6
def get_estimator():
    train_csv, data_path = load_data()
    writer = RecordWriter(save_dir=os.path.join(data_path, "tfrecord"),
                          train_data=train_csv,
                          ops=[
                              ImageReader(inputs="x", parent_path=data_path),
                              ResizeRecord(target_size=(128, 128), outputs="x")
                          ])

    # We create a Scheduler for batch_size, mapping the epochs at which it changes to the new values.
    batchsize_scheduler = Scheduler({0: 64, 5: 32, 15: 16, 25: 8, 35: 4})
    # batchsize_scheduler = Scheduler({0: 64, 5: 64, 15: 64, 25: 64, 35: 32})

    # We create a scheduler for the Resize ops.
    resize_scheduler = Scheduler({
        0: Resize(inputs="x", size=(4, 4), outputs="x"),
        5: Resize(inputs="x", size=(8, 8), outputs="x"),
        15: Resize(inputs="x", size=(16, 16), outputs="x"),
        25: Resize(inputs="x", size=(32, 32), outputs="x"),
        35: Resize(inputs="x", size=(64, 64), outputs="x"),
        45: None
    })

    # In Pipeline, we use the schedulers for batch_size and ops.
    pipeline = fe.Pipeline(batch_size=batchsize_scheduler,
                           data=writer,
                           ops=[
                               resize_scheduler,
                               CreateLowRes(inputs="x", outputs="x_lowres"),
                               Rescale(inputs="x", outputs="x"),
                               Rescale(inputs="x_lowres", outputs="x_lowres")
                           ])

    # six identical Adam optimizers, one per progressive-growing stage
    opt2, opt3, opt4, opt5, opt6, opt7 = [
        tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.0, beta_2=0.99, epsilon=1e-8)
        for _ in range(6)
    ]
    fade_in_alpha = tf.Variable(initial_value=1.0,
                                dtype='float32',
                                trainable=False)
    d2, d3, d4, d5, d6, d7 = fe.build(
        model_def=lambda: build_D(fade_in_alpha=fade_in_alpha, target_resolution=7),
        model_name=["d2", "d3", "d4", "d5", "d6", "d7"],
        optimizer=[opt2, opt3, opt4, opt5, opt6, opt7],
        loss_name=["dloss"] * 6)

    g2, g3, g4, g5, g6, g7, G = fe.build(
        model_def=lambda: build_G(fade_in_alpha=fade_in_alpha, target_resolution=7),
        model_name=["g2", "g3", "g4", "g5", "g6", "g7", "G"],
        optimizer=[opt2, opt3, opt4, opt5, opt6, opt7, opt7],
        loss_name=["gloss"] * 7)

    g_scheduler = Scheduler({
        0: ModelOp(model=g2, outputs="x_fake"),
        5: ModelOp(model=g3, outputs="x_fake"),
        15: ModelOp(model=g4, outputs="x_fake"),
        25: ModelOp(model=g5, outputs="x_fake"),
        35: ModelOp(model=g6, outputs="x_fake"),
        45: ModelOp(model=g7, outputs="x_fake"),
    })

    fake_score_scheduler = Scheduler({
        0: ModelOp(inputs="x_fake", model=d2, outputs="fake_score"),
        5: ModelOp(inputs="x_fake", model=d3, outputs="fake_score"),
        15: ModelOp(inputs="x_fake", model=d4, outputs="fake_score"),
        25: ModelOp(inputs="x_fake", model=d5, outputs="fake_score"),
        35: ModelOp(inputs="x_fake", model=d6, outputs="fake_score"),
        45: ModelOp(inputs="x_fake", model=d7, outputs="fake_score")
    })

    real_score_scheduler = Scheduler({
        0: ModelOp(model=d2, outputs="real_score"),
        5: ModelOp(model=d3, outputs="real_score"),
        15: ModelOp(model=d4, outputs="real_score"),
        25: ModelOp(model=d5, outputs="real_score"),
        35: ModelOp(model=d6, outputs="real_score"),
        45: ModelOp(model=d7, outputs="real_score")
    })

    interp_score_scheduler = Scheduler({
        0: ModelOp(inputs="x_interp", model=d2, outputs="interp_score", track_input=True),
        5: ModelOp(inputs="x_interp", model=d3, outputs="interp_score", track_input=True),
        15: ModelOp(inputs="x_interp", model=d4, outputs="interp_score", track_input=True),
        25: ModelOp(inputs="x_interp", model=d5, outputs="interp_score", track_input=True),
        35: ModelOp(inputs="x_interp", model=d6, outputs="interp_score", track_input=True),
        45: ModelOp(inputs="x_interp", model=d7, outputs="interp_score", track_input=True)
    })

    network = fe.Network(ops=[
        RandomInput(inputs=lambda: 512),
        g_scheduler,
        fake_score_scheduler,
        ImageBlender(inputs=("x", "x_lowres"), alpha=fade_in_alpha),
        real_score_scheduler,
        Interpolate(inputs=("x_fake", "x"), outputs="x_interp"),
        interp_score_scheduler,
        GradientPenalty(inputs=("x_interp", "interp_score"), outputs="gp"),
        GLoss(inputs="fake_score", outputs="gloss"),
        DLoss(inputs=("real_score", "fake_score", "gp"), outputs="dloss")
    ])

    estimator = fe.Estimator(network=network,
                             pipeline=pipeline,
                             epochs=55,
                             traces=[
                                 AlphaController(
                                     alpha=fade_in_alpha,
                                     fade_start=[5, 15, 25, 35, 45, 55],
                                     duration=[5, 5, 5, 5, 5, 5]),
                                 ImageSaving(epoch_model={
                                     4: "g2",
                                     14: "g3",
                                     24: "g4",
                                     34: "g5",
                                     44: "g6",
                                     54: "g7"
                                 },
                                             save_dir="/data/Xiaomeng/images")
                             ])
    return estimator
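
During each resolution transition, AlphaController ramps fade_in_alpha while ImageBlender mixes the full-resolution image with its upsampled low-resolution copy, the standard progressive-GAN fade-in. A minimal sketch of the blend this implies (an assumption about the ops' internals, which are not shown here):

    # alpha ramps from 0 to 1 over the `duration` epochs following each fade_start
    blended = fade_in_alpha * x + (1.0 - fade_in_alpha) * x_lowres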
Example #7
def get_estimator(epochs=10, batch_size=32, epsilon=0.01, warmup=0, model_dir=tempfile.mkdtemp()):
    (x_train, y_train), (x_eval, y_eval) = tf.keras.datasets.cifar10.load_data()
    data = {"train": {"x": x_train, "y": y_train}, "eval": {"x": x_eval, "y": y_eval}}
    num_classes = 10

    pipeline = Pipeline(batch_size=batch_size,
                        data=data,
                        ops=Minmax(inputs="x", outputs="x"))

    model = FEModel(model_def=lambda: LeNet(input_shape=x_train.shape[1:],
                                            classes=num_classes),
                    model_name="LeNet",
                    optimizer="adam")

    adv_img = {
        warmup: AdversarialSample(inputs=("loss", "x"), outputs="x_adverse", epsilon=epsilon, mode="train")
    }
    adv_eval = {
        warmup: ModelOp(inputs="x_adverse", model=model, outputs="y_pred_adverse", mode="train")
    }
    adv_loss = {
        warmup: SparseCategoricalCrossentropy(y_true="y", y_pred="y_pred_adverse", outputs="adverse_loss", mode="train")
    }
    adv_avg = {
        warmup: Average(inputs=("loss", "adverse_loss"), outputs="loss", mode="train")
    }

    network = Network(ops=[
        ModelOp(inputs="x", model=model, outputs="y_pred", track_input=True),
        SparseCategoricalCrossentropy(
            y_true="y", y_pred="y_pred", outputs="loss"),
        Scheduler(adv_img),
        Scheduler(adv_eval),
        Scheduler(adv_loss),
        Scheduler(adv_avg)
    ])

    traces = [
        Accuracy(true_key="y", pred_key="y_pred"),
        ConfusionMatrix(true_key="y",
                        pred_key="y_pred",
                        num_classes=num_classes),
        ModelSaver(model_name="LeNet", save_dir=model_dir, save_freq=2)
    ]

    estimator = Estimator(network=network,
                          pipeline=pipeline,
                          epochs=epochs,
                          traces=traces)

    return estimator
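
The adversarial branch is gated on warmup exactly like the MixUp example: once active, AdversarialSample perturbs the input using the gradient of the loss with respect to x (which track_input=True on the first ModelOp makes available). A minimal sketch of the FGSM-style perturbation this suggests (an assumption; the op's internals are not shown here):

    # hypothetical standalone version using plain TensorFlow
    with tf.GradientTape() as tape:
        tape.watch(x)
        loss = loss_fn(y, model(x))
    x_adverse = x + epsilon * tf.sign(tape.gradient(loss, x))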