Example #1
def test_loss_sensitivity(net_const: Callable, inp_arr: numpy.ndarray,
                          labs_arr: numpy.ndarray):
    with tf_compat.Graph().as_default():
        out, inp = net_const()
        labels = tf_compat.placeholder(tf_compat.float32,
                                       [None, *labs_arr.shape[1:]],
                                       name="logits")
        loss = batch_cross_entropy_loss(out, labels)
        op_vars = pruning_loss_sens_op_vars()

        with tf_compat.Session() as sess:
            sess.run(tf_compat.global_variables_initializer())

            def add_ops_creator(step: int):
                return []

            def feed_dict_creator(step: int):
                return {inp: inp_arr, labels: labs_arr}

            analysis = pruning_loss_sens_one_shot(op_vars, loss, 5,
                                                  add_ops_creator,
                                                  feed_dict_creator)

            for res in analysis.results:
                assert res.name
                assert isinstance(res.index, int)
                assert len(res.sparse_measurements) > 0
                assert len(res.averages) > 0
                assert res.sparse_average > 0
                assert res.sparse_integral > 0
Example #2
def create_split_iterators_handle(split_datasets: Iterable) -> Tuple[Any, Any, List]:
    """
    Create an iterators handle for switching between datasets easily while training.

    :param split_datasets: the datasets to create the splits and handle for
    :return: a tuple containing the handle that should be set with a feed dict,
        the iterator used to get the next batch,
        and a list of the iterators created from the split_datasets
    """
    output_types = None
    output_shapes = None
    split_iterators = []

    for split_dataset in split_datasets:
        # get_output_types and shapes are not available in TF 1.13 and prior
        # hence the following conditional assignments
        output_types = (
            tf_compat.data.get_output_types(split_dataset)
            if hasattr(tf_compat.data, "get_output_types")
            else split_dataset.output_types
        )
        output_shapes = (
            tf_compat.data.get_output_shapes(split_dataset)
            if hasattr(tf_compat.data, "get_output_shapes")
            else split_dataset.output_shapes
        )
        split_iterators.append(_make_initializable_iterator(split_dataset))

    handle = tf_compat.placeholder(tf_compat.string, shape=[])
    iterator = tf_compat.data.Iterator.from_string_handle(
        handle, output_types, output_shapes
    )

    return handle, iterator, split_iterators
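The string handle returned above is meant to be fed through a feed_dict so that a single iterator.get_next() can pull batches from whichever split is currently selected. A minimal usage sketch of that pattern (assuming two tf.data datasets, train_dataset and val_dataset, built elsewhere; the train() example further below follows the same flow):

handle, iterator, (train_iter, val_iter) = create_split_iterators_handle(
    [train_dataset, val_dataset])
images, labels = iterator.get_next()

with tf_compat.Session() as sess:
    # resolve the string handles once, then switch splits per sess.run call
    train_handle, val_handle = sess.run(
        [train_iter.string_handle(), val_iter.string_handle()])

    sess.run(train_iter.initializer)
    train_images, train_labels = sess.run([images, labels],
                                          feed_dict={handle: train_handle})

    sess.run(val_iter.initializer)
    val_images, val_labels = sess.run([images, labels],
                                      feed_dict={handle: val_handle})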
Example #3
def test_mnist_registry(key: str, pretrained: Union[bool, str],
                        test_input: bool):
    with tf_compat.Graph().as_default():
        inputs = tf_compat.placeholder(tf_compat.float32, [None, 28, 28, 1],
                                       name="inputs")
        logits = ModelRegistry.create(key, inputs)

        with tf_compat.Session() as sess:
            if test_input:
                sess.run(tf_compat.global_variables_initializer())
                out = sess.run(
                    logits,
                    feed_dict={inputs: numpy.random.random((1, 28, 28, 1))})
                assert out.sum() != 0

            if pretrained:
                ModelRegistry.load_pretrained(key, pretrained)

                if test_input:
                    out = sess.run(logits,
                                   feed_dict={
                                       inputs: numpy.random.random(
                                           (1, 28, 28, 1))
                                   })
                    assert out.sum() != 0
Example #4
def simple_matmul_net(init_weights):
    tf_compat.reset_default_graph()
    n_inputs = 28 * 28
    n_hidden1 = 300
    n_hidden2 = 100
    n_outputs = 10
    X = tf_compat.placeholder(tf_compat.float32, shape=(None, n_inputs), name="X")

    def neuron_layer(X, n_neurons, name, activation=None):
        with tf_compat.name_scope(name):
            n_inputs = int(X.get_shape()[1])
            stddev = 2 / np.sqrt(n_inputs)
            init = tf_compat.truncated_normal((n_inputs, n_neurons), stddev=stddev)
            W = tf_compat.Variable(init, name="kernel")
            b = tf_compat.Variable(tf_compat.zeros([n_neurons]), name="bias")
            Z = tf_compat.matmul(X, W) + b
            if activation is not None:
                return activation(Z)
            else:
                return Z

    with tf_compat.name_scope("dnn"):
        hidden1 = neuron_layer(
            X, n_hidden1, name="hidden1", activation=tf_compat.nn.relu
        )
        hidden2 = neuron_layer(
            hidden1, n_hidden2, name="hidden2", activation=tf_compat.nn.relu
        )
        neuron_layer(hidden2, n_outputs, name="outputs")
        return tf_compat.get_default_graph()
def test_trainable_params_modifier_with_training():
    modifier = TrainableParamsModifier(
        params=["mlp_net/fc1/weight"],
        trainable=False,
        params_strict=False,
    )
    manager = ScheduledModifierManager([modifier])
    steps_per_epoch = 5
    batch_size = 2

    with tf_compat.Graph().as_default() as graph:
        logits, inputs = mlp_net()
        labels = tf_compat.placeholder(tf_compat.float32,
                                       [None, *logits.shape[1:]])
        loss = batch_cross_entropy_loss(logits, labels)

        global_step = tf_compat.train.get_or_create_global_step()
        num_trainable_variables_init = len(tf_compat.trainable_variables())

        mod_ops, mod_extras = manager.create_ops(steps_per_epoch)
        assert len(
            tf_compat.trainable_variables()) < num_trainable_variables_init
        # Get the variables returned by the trainable_params modifier
        non_trainable_vars = mod_extras[EXTRAS_KEY_VAR_LIST]
        trainable_vars = tf_compat.trainable_variables()
        train_op = tf_compat.train.AdamOptimizer(learning_rate=1e-4).minimize(
            loss, global_step=global_step)

        with tf_compat.Session(graph=graph) as sess:
            sess.run(tf_compat.global_variables_initializer())
            manager.initialize_session(sess)
            init_non_trainable_vars = [
                var.eval(session=sess) for var in non_trainable_vars
            ]
            init_trainable_vars = [
                var.eval(session=sess) for var in trainable_vars
            ]
            batch_lab = numpy.random.random((batch_size, *logits.shape[1:]))
            batch_inp = numpy.random.random((batch_size, *inputs.shape[1:]))

            for epoch in range(10):
                for step in range(steps_per_epoch):
                    sess.run(train_op,
                             feed_dict={
                                 inputs: batch_inp,
                                 labels: batch_lab
                             })
                    sess.run(global_step)
            # Compare initial and final variable values
            for idx, init_non_trainable_var in enumerate(
                    init_non_trainable_vars):
                final_non_trainable_var = non_trainable_vars[idx].eval(
                    session=sess)
                assert numpy.array_equal(init_non_trainable_var,
                                         final_non_trainable_var)
            for idx, init_trainable_var in enumerate(init_trainable_vars):
                final_trainable_var = trainable_vars[idx].eval(session=sess)
                assert not numpy.array_equal(init_trainable_var,
                                             final_trainable_var)
            manager.complete_graph()
    def test_lifecycle(
        self,
        modifier_lambda: Callable[[], GMPruningModifier],
        graph_lambda: Callable[[], tf_compat.Graph],
        steps_per_epoch: int,
    ):
        modifier = modifier_lambda()
        graph = graph_lambda()
        with graph.as_default():
            global_step = tf_compat.train.get_or_create_global_step()
            step_placeholder = tf_compat.placeholder(dtype=tf_compat.int64,
                                                     name="step")
            global_assign = global_step.assign(step_placeholder)

            inp = graph.get_tensor_by_name("inp:0")
            out = graph.get_tensor_by_name("out:0")

            mod_ops, mod_extras = modifier.create_ops(steps_per_epoch,
                                                      global_step, graph)
            assert len(mod_ops) == 1
            assert mod_ops[0] is not None
            assert len(mod_extras) == 1
            assert EXTRAS_KEY_SUMMARIES in mod_extras
            assert modifier.prune_op_vars
            assert len(modifier.prune_op_vars) > 0
            last_sparsities = [0.0 for _ in range(len(modifier.prune_op_vars))]

            with tf_compat.Session(graph=graph) as sess:
                sess.run(tf_compat.global_variables_initializer())
                modifier.initialize_session(sess)
                step_counter = 0
                inp_arr = numpy.random.random((1, *inp.shape[1:]))

                for epoch in range(int(modifier.end_epoch + 5.0)):
                    for step in range(steps_per_epoch):
                        res = sess.run(out, feed_dict={inp: inp_arr})
                        assert res.sum() > 0

                        step_counter += 1
                        sess.run(global_assign,
                                 feed_dict={step_placeholder: step_counter})
                        sess.run(mod_ops)

                        for index, op_vars in enumerate(
                                modifier.prune_op_vars):
                            mask_sparsity = eval_tensor_sparsity(op_vars.mask)
                            masked_sparsity = eval_tensor_sparsity(
                                op_vars.masked)

                            assert abs(mask_sparsity - masked_sparsity) < 1e-5

                            if epoch < modifier.start_epoch:
                                assert masked_sparsity < 1e-2
                            else:
                                assert masked_sparsity >= last_sparsities[
                                    index]
                                last_sparsities[index] = masked_sparsity

                modifier.complete_graph(graph, sess)
Example #7
    def test_lifecycle(
        self,
        modifier_lambda: Callable[[], SetLearningRateModifier],
        graph_lambda: Callable[[], tf_compat.Graph],
        steps_per_epoch: int,
        optim_lambda,
    ):
        modifier = modifier_lambda()
        graph = graph_lambda()

        with graph.as_default():
            global_step = tf_compat.train.get_or_create_global_step()

            # Further set up for loss, optimizer and training op
            x_batch = graph.get_tensor_by_name("inp:0")
            y_pred = graph.get_tensor_by_name("out:0")
            n_inputs = x_batch.shape[1]
            n_outputs = y_pred.shape[1]
            y_lab = tf_compat.placeholder(tf_compat.float32,
                                          shape=(None, n_outputs),
                                          name="y")
            mod_ops, mod_extras = modifier.create_ops(steps_per_epoch,
                                                      global_step=global_step,
                                                      graph=graph)
            assert len(mod_ops) == 0
            assert len(mod_extras) == 2
            assert EXTRAS_KEY_LEARNING_RATE in mod_extras
            assert EXTRAS_KEY_SUMMARIES in mod_extras
            learning_rate = mod_extras[EXTRAS_KEY_LEARNING_RATE]

            with tf_compat.name_scope("train"):
                optimizer = optim_lambda(learning_rate=learning_rate)
                loss = tf_compat.losses.mean_squared_error(y_lab, y_pred)
                training_op = optimizer.minimize(loss, global_step=global_step)

        np.random.seed(12)
        batch_size = 8
        batch_x = np.random.randn(batch_size, n_inputs)
        batch_lab = np.random.randn(batch_size, n_outputs)

        with tf_compat.Session(graph=graph) as sess:
            sess.run(tf_compat.global_variables_initializer())
            for epoch in range(
                    int(max(modifier.start_epoch, modifier.end_epoch)) + 5):
                for step in range(steps_per_epoch):
                    gs = sess.run(global_step)
                    expected = modifier.learning_rate
                    optim_lr = sess.run(_get_lr(optimizer))
                    assert (
                        abs(optim_lr - expected) <= EPSILON
                    ), "Failed at epoch:{} step:{} global_step:{}".format(
                        epoch, step, gs)
                    sess.run(
                        training_op,
                        feed_dict={
                            x_batch: batch_x,
                            y_lab: batch_lab
                        },
                    )
Example #8
def pruning_loss_sens_op_vars(
    graph: tf_compat.Graph = None,
    var_names: Union[List[str], Tuple[str]] = ("re:.*", ),
    mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured",
) -> List[SparsePruningOpVars]:
    """
    Edit the graph to inject pruning ops and vars to allow for a ks loss
    sensitivity analysis.

    Note: this must be run outside of a session for it to take effect.

    :param graph: the graph to inject pruning ops and vars into,
        if not supplied uses get_default_graph()
    :param var_names: List of variable names or regex patterns of variables to get
        the op vars for.  Defaults to matching all variables
    :param mask_type: String to define type of sparsity (options: ['unstructured',
        'channel', 'filter']), List to define block shape of a parameter's in and out
        channels, or a SparsityMaskCreator object. default is 'unstructured'
    :return: the created pruning op vars to be used in approx_ks_loss_sensitivity and
        one_shot_ks_loss_sensitivity
    """

    if not graph:
        graph = tf_compat.get_default_graph()

    mask_creator = mask_type
    if not isinstance(mask_type, PruningMaskCreator):
        mask_creator = load_mask_creator(mask_type)

    ks_group = pruning_loss_sens_one_shot.__name__
    prunable_ops_and_inputs = get_ops_and_inputs_by_name_or_regex(
        var_names, graph)
    op_vars = []

    with graph.as_default():
        for prune_op, prune_op_input in prunable_ops_and_inputs:
            with tf_compat.name_scope(
                    PruningScope.model(prune_op, ks_group,
                                       trailing_slash=True)):
                sparsity = tf_compat.placeholder(dtype=tf_compat.float32,
                                                 name="sparsity_placeholder")
                update = tf_compat.constant(True, tf_compat.bool)
            prune_op_var = create_op_pruning(
                prune_op,
                prune_op_input,
                sparsity,
                update,
                True,
                None,
                ks_group,
                mask_creator,
            )
            op_vars.append(SparsePruningOpVars(prune_op_var, sparsity))

    return op_vars
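As the docstring notes, the returned op vars are consumed by the one-shot loss sensitivity run. A condensed sketch of that flow, mirroring Example #1 (inp_arr and labs_arr are assumed sample batches; mlp_net and batch_cross_entropy_loss are the helpers shown elsewhere on this page):

with tf_compat.Graph().as_default():
    logits, inputs = mlp_net()
    labels = tf_compat.placeholder(tf_compat.float32, [None, *logits.shape[1:]])
    loss = batch_cross_entropy_loss(logits, labels)
    # must be created before the session so the injected ops become part of the graph
    op_vars = pruning_loss_sens_op_vars()

    with tf_compat.Session() as sess:
        sess.run(tf_compat.global_variables_initializer())
        analysis = pruning_loss_sens_one_shot(
            op_vars,
            loss,
            5,  # steps per measurement
            lambda step: [],  # add_ops_creator
            lambda step: {inputs: inp_arr, labels: labs_arr},  # feed_dict_creator
        )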
Example #9
def test_apply_op_vars_masks(
    sparsity_val: float,
    net_const: Callable,
    inp_arr: numpy.ndarray,
    var_names: List[str],
):
    group = "test-group"

    with tf_compat.Graph().as_default() as graph:
        out, inp = net_const()
        sparsity = tf_compat.placeholder(dtype=tf_compat.float32,
                                         name="sparsity_placeholder")
        update_ready = tf_compat.placeholder(dtype=tf_compat.bool,
                                             name="update_ready")
        pruning_op_vars = get_or_create_graph_ops_pruning(
            graph,
            var_names,
            sparsity,
            update_ready,
            True,
            None,
            group,
            UnstructuredPruningMaskCreator(),
        )

        with tf_compat.Session() as sess:
            sess.run(tf_compat.global_variables_initializer())

            for op_vars in pruning_op_vars:
                sess.run(
                    op_vars.update,
                    feed_dict={
                        sparsity: sparsity_val,
                        update_ready: True
                    },
                )

            apply_op_vars_masks(pruning_op_vars, group, sess)

            for op_vars in pruning_op_vars:
                var_sparsity = eval_tensor_sparsity(op_vars.op_input)
                assert abs(var_sparsity - sparsity_val) < 1e-2
Example #10
def test_resnets(key: str, pretrained: Union[bool, str], test_input: bool,
                 const: Callable):
    input_shape = ModelRegistry.input_shape(key)
    # test out the stand alone constructor
    with tf_compat.Graph().as_default():
        inputs = tf_compat.placeholder(tf_compat.float32, [None, *input_shape],
                                       name="inputs")
        logits = const(inputs, training=False)

        if test_input:
            with tf_compat.Session() as sess:
                sess.run(tf_compat.global_variables_initializer())
                out = sess.run(
                    logits,
                    feed_dict={inputs: numpy.random.random((1, *input_shape))})
                assert out.sum() != 0

    # test out the registry
    with tf_compat.Graph().as_default():
        inputs = tf_compat.placeholder(tf_compat.float32, [None, *input_shape],
                                       name="inputs")
        logits = ModelRegistry.create(key, inputs, training=False)

        with tf_compat.Session() as sess:
            if test_input:
                sess.run(tf_compat.global_variables_initializer())
                out = sess.run(
                    logits,
                    feed_dict={inputs: numpy.random.random((1, *input_shape))})
                assert out.sum() != 0

            if pretrained:
                ModelRegistry.load_pretrained(key, pretrained)

                if test_input:
                    out = sess.run(
                        logits,
                        feed_dict={
                            inputs: numpy.random.random((1, *input_shape))
                        },
                    )
                    assert out.sum() != 0
Example #11
def mlp_net():
    inp = tf_compat.placeholder(tf_compat.float32, [None, 16], name="inp")

    with tf_compat.name_scope("mlp_net"):
        fc1 = _fc("fc1", inp, 16, 32)
        fc2 = _fc("fc2", fc1, 32, 64)
        fc3 = _fc("fc3", fc2, 64, 64, add_relu=False)

    out = tf_compat.sigmoid(fc3, name="out")

    return out, inp
Example #12
def test_mnist():
    with tf_compat.Graph().as_default():
        inputs = tf_compat.placeholder(tf_compat.float32, [None, 28, 28, 1],
                                       name="inputs")
        logits = mnist_net(inputs)

        with tf_compat.Session() as sess:
            sess.run(tf_compat.global_variables_initializer())
            out = sess.run(
                logits,
                feed_dict={inputs: numpy.random.random((1, 28, 28, 1))})
            assert out.sum() != 0
Example #13
def simple_conv2d_net(init_weights):
    tf_compat.reset_default_graph()
    X = tf_compat.placeholder(tf_compat.float32, [None, 32, 40, 1])
    W = tf_compat.Variable(
        tf_compat.convert_to_tensor(init_weights, dtype=tf_compat.float32)
    )
    b = tf_compat.Variable(tf_compat.random_normal([64]), dtype=tf_compat.float32)
    conv1 = tf_compat.nn.conv2d(X, W, strides=[1, 1, 1, 1], padding="VALID")
    conv1 = tf_compat.nn.bias_add(conv1, b)
    conv1 = tf_compat.nn.max_pool(
        conv1, ksize=[1, 1, 3, 1], strides=[1, 1, 1, 1], padding="VALID"
    )
    return tf_compat.get_default_graph()
Example #14
def conv_net():
    inp = tf_compat.placeholder(tf_compat.float32, [None, 28, 28, 1], name="inp")

    with tf_compat.name_scope("conv_net"):
        conv1 = _conv("conv1", inp, 1, 32, 3, 2, "SAME")
        conv2 = _conv("conv2", conv1, 32, 32, 3, 2, "SAME")
        avg_pool = tf_compat.reduce_mean(conv2, axis=[1, 2])
        reshape = tf_compat.reshape(avg_pool, [-1, 32])
        mlp = _fc("mlp", reshape, 32, 10, add_relu=False)

    out = tf_compat.sigmoid(mlp, name="out")

    return out, inp
def export(
    args,
    save_dir,
    checkpoint_path=None,
    skip_samples=False,
    num_classes=None,
    opset=None,
):
    assert not skip_samples or num_classes
    # dataset creation
    if not skip_samples:
        val_dataset, num_classes = _create_dataset(args, train=False)

    with tf_compat.Graph().as_default():
        input_shape = ModelRegistry.input_shape(args.arch_key)
        inputs = tf_compat.placeholder(tf_compat.float32,
                                       [None] + list(input_shape),
                                       name="inputs")
        outputs = _create_model(args, num_classes, inputs)

        with tf_compat.Session() as sess:
            _load_model(args,
                        sess,
                        checkpoint_path=checkpoint_path
                        or args.checkpoint_path)

            exporter = GraphExporter(save_dir)

            if not skip_samples:
                # Export a batch of samples and expected outputs
                tf_dataset = val_dataset.build(args.num_samples,
                                               repeat_count=1,
                                               num_parallel_calls=1)
                tf_iter = tf_compat.data.make_one_shot_iterator(tf_dataset)
                features, _ = tf_iter.get_next()
                inputs_val = sess.run(features)
                exporter.export_samples([inputs], [inputs_val], [outputs],
                                        sess)

            # Export model to tensorflow checkpoint format
            LOGGER.info("exporting tensorflow in {}".format(save_dir))
            exporter.export_checkpoint(sess=sess)

            # Export model to pb format
            LOGGER.info("exporting pb in {}".format(exporter.pb_path))
            exporter.export_pb(outputs=[outputs])

    # Export model to onnx format
    LOGGER.info("exporting onnx in {}".format(exporter.onnx_path))
    exporter.export_onnx([inputs], [outputs], opset=opset or args.onnx_opset)
Example #16
def test_multi_step_lr_schedule(start_step: int, milestone_steps: List[int],
                                init_lr: float, gamma: float):
    with tf_compat.Graph().as_default():
        global_step = tf_compat.placeholder(dtype=tf_compat.int64, shape=[])
        learning_rate = multi_step_lr_schedule(global_step, start_step,
                                               milestone_steps, init_lr, gamma)

        with tf_compat.Session() as sess:
            for step in range(start_step + milestone_steps[-1] + 10):
                measured = sess.run(learning_rate,
                                    feed_dict={global_step: step})

                gammas = sum([
                    1 for mile in milestone_steps if step >= mile + start_step
                ])
                expected = init_lr * gamma**gammas

                assert abs(measured - expected) < 1e-5
Example #17
def test_step_lr_schedule(start_step: int, end_step: int, init_lr: float,
                          step_size: int, gamma: float):
    with tf_compat.Graph().as_default():
        global_step = tf_compat.placeholder(dtype=tf_compat.int64, shape=[])
        learning_rate = step_lr_schedule(global_step, start_step, end_step,
                                         step_size, init_lr, gamma)

        with tf_compat.Session() as sess:
            expected = init_lr

            for step in range(end_step + 10):
                measured = sess.run(learning_rate,
                                    feed_dict={global_step: step})

                if (step - start_step
                    ) % step_size == 0 and start_step < step <= end_step:
                    expected = expected * gamma

                assert abs(measured - expected) < 1e-5
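The running "expected" value in this test can also be written in closed form. A small illustrative helper (not part of the library) that mirrors the test's bookkeeping, assuming the same argument meanings as step_lr_schedule:

def expected_step_lr(step, start_step, end_step, step_size, init_lr, gamma):
    # count completed decay intervals, clamped to the schedule's active range
    clamped_step = min(max(step, start_step), end_step)
    num_decays = (clamped_step - start_step) // step_size
    return init_lr * gamma ** num_decays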
Example #18
def test_get_or_create_graph_ops_pruning(
    sparsity_val: float,
    net_const: Callable,
    inp_arr: numpy.ndarray,
    var_names: List[str],
    mask_creator: PruningMaskCreator,
):
    group = "test-group"
    is_grouped_mask = isinstance(mask_creator, GroupedPruningMaskCreator)

    with tf_compat.Graph().as_default() as graph:
        out, inp = net_const()
        sparsity = tf_compat.placeholder(dtype=tf_compat.float32,
                                         name="sparsity_placeholder")
        update_ready = tf_compat.placeholder(dtype=tf_compat.bool,
                                             name="update_ready")
        pruning_op_vars = get_or_create_graph_ops_pruning(
            graph,
            var_names,
            sparsity,
            update_ready,
            True,
            None,
            group,
            mask_creator,
        )
        pruning_op_vars_sec = get_or_create_graph_ops_pruning(
            graph,
            var_names,
            sparsity,
            update_ready,
            True,
            None,
            group,
            mask_creator,
        )

        assert len(pruning_op_vars) >= len(
            var_names)  # get at least 1 match per regex
        assert len(pruning_op_vars) == len(pruning_op_vars_sec)

        for op_vars, op_vars_sec in zip(pruning_op_vars, pruning_op_vars_sec):
            assert op_vars.op == op_vars_sec.op
            assert op_vars.update == op_vars_sec.update
            assert op_vars.mask == op_vars_sec.mask
            assert op_vars.masked == op_vars_sec.masked

        with tf_compat.Session() as sess:
            sess.run(tf_compat.global_variables_initializer())
            for op_vars in pruning_op_vars:
                sess.run(
                    op_vars.update,
                    feed_dict={
                        sparsity: sparsity_val,
                        update_ready: False
                    },
                )
                print(op_vars.mask.shape)
                # When we reduce the values a mask can take, there can be higher error
                err_threshold = 1e-2 if not is_grouped_mask else 1e-1
                mask_sparsity = eval_tensor_sparsity(op_vars.mask)
                weight_sparsity = eval_tensor_sparsity(op_vars.op_input)
                assert mask_sparsity < err_threshold
                assert weight_sparsity == mask_sparsity

                masked_sparsity = eval_tensor_sparsity(op_vars.masked)
                assert masked_sparsity < err_threshold

                sess.run(
                    op_vars.update,
                    feed_dict={
                        sparsity: sparsity_val,
                        update_ready: True
                    },
                )

                mask_sparsity = eval_tensor_sparsity(op_vars.mask)
                assert abs(mask_sparsity - sparsity_val) < err_threshold

                masked_sparsity = eval_tensor_sparsity(op_vars.masked)
                assert abs(masked_sparsity - sparsity_val) < err_threshold

                res = sess.run(out, feed_dict={inp: inp_arr})
                assert res.sum() > 0.0

                if is_grouped_mask:
                    # Check that every value in the mask_creator grouping
                    # is the same within the mask.  Assumes grouping applies
                    # an absolute mean to each grouping
                    grouped_mask = mask_creator.group_tensor(op_vars.mask)
                    mask_vals_are_grouped = tf_compat.reduce_all(
                        tf_compat.logical_or(
                            tf_compat.equal(grouped_mask, 0.0),
                            tf_compat.equal(grouped_mask, 1.0),
                        ))
                    assert sess.run(mask_vals_are_grouped)
Example #19
def test_get_or_create_ks_schedule_ops(
    begin_step: int,
    end_step: int,
    update_step_freq: int,
    init_sparsity: float,
    final_sparsity: float,
    exponent: float,
):
    group = "test-group"

    with tf_compat.Graph().as_default():
        global_step = tf_compat.train.get_or_create_global_step()
        step_placeholder = tf_compat.placeholder(dtype=tf_compat.int64,
                                                 name="step")
        global_assign = global_step.assign(step_placeholder)

        update_ready, sparsity = get_or_create_ks_schedule_ops(
            global_step,
            begin_step,
            end_step,
            update_step_freq,
            init_sparsity,
            final_sparsity,
            exponent,
            group,
        )
        update_ready_sec, sparsity_sec = get_or_create_ks_schedule_ops(
            global_step,
            begin_step,
            end_step,
            update_step_freq,
            init_sparsity,
            final_sparsity,
            exponent,
            group,
        )

        assert update_ready == update_ready_sec
        assert sparsity == sparsity_sec

        with tf_compat.Session() as sess:
            sess.run(tf_compat.global_variables_initializer())
            last_update_step = None
            last_update_sparsity = None

            for step in range(end_step + 10):
                sess.run(global_assign, feed_dict={step_placeholder: step})
                update_ready_val = sess.run(update_ready)
                sparsity_val = sess.run(sparsity)

                if step < begin_step:
                    assert not update_ready_val
                    assert abs(sparsity_val) < 1e-5
                elif step == begin_step:
                    assert update_ready_val
                    assert abs(sparsity_val - init_sparsity) < 1e-5
                    last_update_step = step
                    last_update_sparsity = sparsity_val
                elif step == end_step:
                    assert update_ready_val
                    assert abs(sparsity_val - final_sparsity) < 1e-5
                    last_update_step = step
                    last_update_sparsity = sparsity_val
                elif step > end_step:
                    assert not update_ready_val
                    assert abs(sparsity_val - final_sparsity) < 1e-5
                else:
                    # check if update should be ready
                    check_ready = (last_update_step is None or
                                   step >= last_update_step + update_step_freq)
                    assert sparsity_val > last_update_sparsity

                    if check_ready:
                        assert update_ready_val
                        last_update_step = step
                        last_update_sparsity = sparsity_val
                    else:
                        assert not update_ready_val
Example #20
def test_create_op_pruning_fc(sparsity_val):
    group = "test-group"

    with tf_compat.Graph().as_default() as graph:
        inp = tf_compat.placeholder(tf_compat.float32, [None, 64])

        with tf_compat.name_scope("fc"):
            weights = tf_compat.Variable(tf_compat.random_normal([64, 64]),
                                         name="weights")
            bias = tf_compat.Variable(tf_compat.random_normal([64]),
                                      name="bias")
            matmul = tf_compat.matmul(inp, weights, name="matmul")
            add = tf_compat.add(matmul, bias, name="bias_add")
            relu = tf_compat.nn.relu(add, name="relu")

        sparsity = tf_compat.placeholder(dtype=tf_compat.float32,
                                         name="sparsity_placeholder")
        update_ready = tf_compat.placeholder(dtype=tf_compat.bool,
                                             name="update_ready")

        matmul_op = graph.get_operation_by_name("fc/matmul")
        matmul_op_input = get_op_input_var(matmul_op, VAR_INDEX_FROM_TRAINABLE)
        pruning_op_vars = create_op_pruning(
            matmul_op,
            matmul_op_input,
            sparsity,
            update_ready,
            True,
            None,
            group,
            UnstructuredPruningMaskCreator(),
        )

        with tf_compat.Session() as sess:
            sess.run(tf_compat.global_variables_initializer())
            sess.run(
                pruning_op_vars.update,
                feed_dict={
                    sparsity: sparsity_val,
                    update_ready: False
                },
            )

            mask_sparsity = eval_tensor_sparsity(pruning_op_vars.mask)
            weight_sparsity = eval_tensor_sparsity(pruning_op_vars.op_input)
            assert mask_sparsity < 1e-3
            assert mask_sparsity == weight_sparsity

            masked_sparsity = eval_tensor_sparsity(pruning_op_vars.masked)
            assert masked_sparsity < 1e-3

            sess.run(
                pruning_op_vars.update,
                feed_dict={
                    sparsity: sparsity_val,
                    update_ready: True
                },
            )

            mask_sparsity = eval_tensor_sparsity(pruning_op_vars.mask)
            assert abs(mask_sparsity - sparsity_val) < 1e-3

            masked_sparsity = eval_tensor_sparsity(pruning_op_vars.masked)
            assert abs(masked_sparsity - sparsity_val) < 1e-3

            res = sess.run(relu, feed_dict={inp: numpy.random.random((4, 64))})
            assert res.sum() > 0.0
def pruning_loss_sensitivity(args, save_dir):
    input_shape = ModelRegistry.input_shape(args.arch_key)
    train_dataset, num_classes = _create_dataset(args,
                                                 train=True,
                                                 image_size=input_shape[1])
    with tf_compat.Graph().as_default() as graph:
        # create model graph
        inputs = tf_compat.placeholder(tf_compat.float32,
                                       [None] + list(input_shape),
                                       name="inputs")
        outputs = _create_model(args, num_classes, inputs)

        with tf_compat.Session() as sess:
            _load_model(args, sess, checkpoint_path=args.checkpoint_path)
            if args.approximate:
                LOGGER.info(
                    "Running weight magnitude loss sensitivity analysis...")
                analysis = pruning_loss_sens_magnitude(graph, sess)
            else:
                op_vars = pruning_loss_sens_op_vars(graph)
                train_steps = math.ceil(len(train_dataset) / args.batch_size)
                train_dataset = _build_dataset(args, train_dataset,
                                               args.batch_size)
                handle, iterator, dataset_iter = create_split_iterators_handle(
                    [train_dataset])
                dataset_iter = dataset_iter[0]
                images, labels = iterator.get_next()
                loss = batch_cross_entropy_loss(outputs, labels)
                tensor_names = ["inputs:0", labels.name]
                sess.run(dataset_iter.initializer)

                def feed_dict_creator(
                        step: int) -> Dict[str, tf_compat.Tensor]:
                    assert step < train_steps
                    batch_data = [
                        tens.eval(session=sess)
                        for tens in dataset_iter.get_next()
                    ]
                    return dict(zip(tensor_names, batch_data))

                LOGGER.info("Running one shot loss sensitivity analysis...")
                analysis = pruning_loss_sens_one_shot(
                    op_vars=op_vars,
                    loss_tensor=loss,
                    steps_per_measurement=args.steps_per_measurement,
                    feed_dict_creator=feed_dict_creator,
                    sess=sess,
                )
    # saving and printing results
    LOGGER.info("completed...")
    LOGGER.info("Saving results in {}".format(save_dir))
    analysis.save_json(
        os.path.join(
            save_dir,
            "ks_approx_sensitivity.json"
            if args.approximate else "ks_one_shot_sensitivity.json",
        ))
    analysis.plot(
        os.path.join(
            save_dir,
            "ks_approx_sensitivity.png"
            if args.approximate else "ks_one_shot_sensitivity.png",
        ),
        plot_integral=True,
    )
    analysis.print_res()
def train(args, save_dir, logs_dir):
    # setup dataset
    with tf_compat.device("/cpu:0"):
        train_dataset, _ = _create_dataset(args, train=True)
        val_dataset, num_classes = _create_dataset(args, train=False)
        # calc steps
        train_steps = math.ceil(len(train_dataset) / args.train_batch_size)
        val_steps = math.ceil(len(val_dataset) / args.test_batch_size)
        # build datasets
        train_dataset = _build_dataset(args, train_dataset,
                                       args.train_batch_size)
        val_dataset = _build_dataset(args, val_dataset, args.test_batch_size)
    handle, iterator, (train_iter, val_iter) = create_split_iterators_handle(
        [train_dataset, val_dataset])

    # set up model graph
    images, labels = iterator.get_next()
    training = tf_compat.placeholder(dtype=tf_compat.bool, shape=[])
    outputs = _create_model(args, num_classes, images, training)

    # set up training objects
    loss = batch_cross_entropy_loss(outputs, labels)
    acc = accuracy(outputs, labels)
    global_step = tf_compat.train.get_or_create_global_step()
    train_op = tf_compat.train.AdamOptimizer(learning_rate=args.init_lr,
                                             **args.optim_args).minimize(
                                                 loss, global_step=global_step)
    update_ops = tf_compat.get_collection(tf_compat.GraphKeys.UPDATE_OPS)
    LOGGER.info("Created update ops for training")

    # set up sparseml modifier ops
    add_mods = (ConstantPruningModifier(
        params="__ALL__") if args.sparse_transfer_learn else None)
    manager = ScheduledModifierManager.from_yaml(file_path=args.recipe_path,
                                                 add_modifiers=add_mods)
    mod_ops, mod_extras = manager.create_ops(train_steps, global_step)

    with tf_compat.Session() as sess:
        # set up tensorboard logging
        summary_writer = tf_compat.summary.FileWriter(logs_dir, sess.graph)
        summaries = tf_compat.summary.merge_all()
        LOGGER.info("Logging to tensorboard at {}".format(logs_dir))

        # initialize variables, load pretrained weights, initialize modifiers
        train_iter_handle, val_iter_handle = sess.run(
            [train_iter.string_handle(),
             val_iter.string_handle()])
        LOGGER.info("Initialized graph variables")
        _load_model(args, sess)
        manager.initialize_session()
        LOGGER.info("Initialized SparseML modifiers")

        best_loss = None
        for epoch in range(manager.max_epochs):
            # train
            LOGGER.info("Training for epoch {}...".format(epoch))
            sess.run(train_iter.initializer)
            train_acc, train_loss = [], []
            for step in range(train_steps):
                _, __, meas_step, meas_loss, meas_acc, meas_summ = sess.run(
                    [train_op, update_ops, global_step, loss, acc, summaries],
                    feed_dict={
                        handle: train_iter_handle,
                        training: True
                    },
                )
                if step >= train_steps - 1:
                    # log the general summaries on the last training step
                    summary_writer.add_summary(meas_summ, meas_step)
                # run modifier ops
                sess.run(mod_ops)
                # summarize
                write_simple_summary(summary_writer, "Train/Loss", meas_loss,
                                     meas_step)
                write_simple_summary(summary_writer, "Train/Acc",
                                     meas_acc * 100.0, meas_step)
                train_acc.append(meas_acc)
                train_loss.append(meas_loss)
            LOGGER.info("Epoch {} - Train Loss: {}, Train Acc: {}".format(
                epoch,
                numpy.mean(train_loss).item(),
                numpy.mean(train_acc).item()))

            # val
            LOGGER.info("Validating for epoch {}...".format(epoch))
            sess.run(val_iter.initializer)
            val_acc, val_loss = [], []
            for step in range(val_steps):
                meas_loss, meas_acc = sess.run(
                    [loss, acc],
                    feed_dict={
                        handle: val_iter_handle,
                        training: False
                    },
                )
                val_acc.append(meas_acc)
                val_loss.append(meas_loss)
                write_simple_summary(summary_writer, "Val/Loss",
                                     numpy.mean(val_loss).item(), epoch)
                write_simple_summary(summary_writer, "Val/Acc",
                                     numpy.mean(val_acc).item(), epoch)
            val_loss = numpy.mean(val_loss).item()
            LOGGER.info("Epoch {} - Val Loss: {}, Val Acc: {}".format(
                epoch, val_loss,
                numpy.mean(val_acc).item()))
            if epoch >= args.save_best_after and (best_loss is None
                                                  or val_loss <= best_loss):
                _save_checkpoint(args, sess, save_dir, "checkpoint-best")
                best_loss = val_loss
            if args.save_epochs and epoch in args.save_epochs:
                _save_checkpoint(args, sess, save_dir,
                                 "checkpoint-epoch-{}".format(epoch))

        # cleanup graph and save final checkpoint
        manager.complete_graph()
        checkpoint_path = _save_checkpoint(args, sess, save_dir,
                                           "final-checkpoint")
    LOGGER.info("Running ONNX export flow")
    export(
        args,
        save_dir,
        checkpoint_path=checkpoint_path,
        skip_samples=True,
        num_classes=num_classes,
        opset=11,
    )
Example #23
def test_get_or_create_ks_scheduled_graph_ops(
    begin_step: int,
    end_step: int,
    update_step_freq: int,
    init_sparsity: float,
    final_sparsity: float,
    exponent: float,
    net_const: Callable,
    inp_arr: numpy.ndarray,
    var_names: List[str],
):
    group = "test-group"

    with tf_compat.Graph().as_default() as graph:
        global_step = tf_compat.train.get_or_create_global_step()
        step_placeholder = tf_compat.placeholder(dtype=tf_compat.int64,
                                                 name="step")
        global_assign = global_step.assign(step_placeholder)

        out, inp = net_const()

        (
            update_op,
            pruning_op_vars,
            update_ready,
            sparsity,
        ) = get_or_create_ks_scheduled_graph_ops(
            graph,
            global_step,
            var_names,
            begin_step,
            end_step,
            update_step_freq,
            init_sparsity,
            final_sparsity,
            exponent,
            True,
            group,
            UnstructuredPruningMaskCreator(),
        )
        (
            update_op_sec,
            pruning_op_vars_sec,
            update_ready_sec,
            sparsity_sec,
        ) = get_or_create_ks_scheduled_graph_ops(
            graph,
            global_step,
            var_names,
            begin_step,
            end_step,
            update_step_freq,
            init_sparsity,
            final_sparsity,
            exponent,
            True,
            group,
            UnstructuredPruningMaskCreator(),
        )

        assert update_op == update_op_sec
        assert update_ready == update_ready_sec
        assert sparsity == sparsity_sec
        assert len(pruning_op_vars) == 3
        assert len(pruning_op_vars) >= len(
            var_names)  # at least 1 regex match per name

        for op_vars, op_vars_sec in zip(pruning_op_vars, pruning_op_vars_sec):
            assert op_vars.op == op_vars_sec.op
            assert op_vars.update == op_vars_sec.update
            assert op_vars.mask == op_vars_sec.mask
            assert op_vars.masked == op_vars_sec.masked

        with tf_compat.Session() as sess:
            sess.run(tf_compat.global_variables_initializer())
            last_update_sparsity = None

            for step in range(end_step + 10):
                sess.run(global_assign, feed_dict={step_placeholder: step})
                update_ready_val = sess.run(update_ready)
                sparsity_val = sess.run(sparsity)
                sess.run(update_op)

                for op_var in pruning_op_vars:
                    mask_sparsity = eval_tensor_sparsity(op_var.mask)
                    masked_sparsity = eval_tensor_sparsity(op_var.masked)
                    weight_sparsity = eval_tensor_sparsity(op_var.op_input)

                    assert abs(mask_sparsity - masked_sparsity) < 1e-5

                    if step < begin_step:
                        assert abs(masked_sparsity) < 1e-2
                        assert not update_ready_val
                    elif step == begin_step:
                        assert abs(masked_sparsity - init_sparsity) < 1e-2
                        assert abs(sparsity_val - init_sparsity) < 1e-5
                        assert update_ready_val
                        last_update_sparsity = masked_sparsity
                    elif step == end_step:
                        assert update_ready_val
                        assert abs(masked_sparsity - final_sparsity) < 1e-2
                        assert abs(sparsity_val - final_sparsity) < 1e-5
                        last_update_sparsity = masked_sparsity
                    elif step > end_step:
                        assert not update_ready_val
                        assert abs(masked_sparsity - final_sparsity) < 1e-2
                    else:
                        assert masked_sparsity >= last_update_sparsity - 1e-2
                        assert sparsity_val >= last_update_sparsity - 1e-2
                        assert abs(weight_sparsity - masked_sparsity) <= 1e-2
                        last_update_sparsity = masked_sparsity
                        if step < end_step and update_ready_val:
                            steps_count = sess.run(global_step) - begin_step
                            steps_range = end_step - begin_step
                            expected = _expected_sparsity(
                                steps_count,
                                steps_range,
                                init_sparsity,
                                final_sparsity,
                                exponent,
                            )
                            assert abs(sparsity_val - expected) < 1e-5

                res = sess.run(out, feed_dict={inp: inp_arr})
                assert res.sum() >= 0.0
def test_gm_pruning_training_with_manager():
    modifier = GMPruningModifier(
        params=["mlp_net/fc1/weight", "mlp_net/fc3/weight"],
        init_sparsity=0.05,
        final_sparsity=0.8,
        start_epoch=2.0,
        end_epoch=7.0,
        update_frequency=1.0,
    )
    sec_modifier = GMPruningModifier(
        params=["mlp_net/fc2/weight"],
        init_sparsity=0.05,
        final_sparsity=0.8,
        start_epoch=2.0,
        end_epoch=7.0,
        update_frequency=1.0,
    )
    manager = ScheduledModifierManager([modifier, sec_modifier])
    steps_per_epoch = 5
    batch_size = 2

    with tf_compat.Graph().as_default() as graph:
        logits, inputs = mlp_net()
        labels = tf_compat.placeholder(tf_compat.float32,
                                       [None, *logits.shape[1:]])
        loss = batch_cross_entropy_loss(logits, labels)

        global_step = tf_compat.train.get_or_create_global_step()
        train_op = tf_compat.train.AdamOptimizer(learning_rate=1e-4).minimize(
            loss, global_step=global_step)

        mod_ops, mod_extras = manager.create_ops(steps_per_epoch)
        last_sparsities = [0.0 for _ in range(len(modifier.prune_op_vars))]

        with tf_compat.Session(graph=graph) as sess:
            sess.run(tf_compat.global_variables_initializer())
            manager.initialize_session(sess)
            batch_lab = numpy.random.random((batch_size, *logits.shape[1:]))
            batch_inp = numpy.random.random((batch_size, *inputs.shape[1:]))

            for epoch in range(int(modifier.end_epoch + 2.0)):
                for step in range(steps_per_epoch):
                    sess.run(train_op,
                             feed_dict={
                                 inputs: batch_inp,
                                 labels: batch_lab
                             })
                    sess.run(global_step)

                    sess.run(mod_ops)
                    update_ready_val = sess.run(modifier.update_ready)
                    sess.run(modifier.sparsity)

                    for index, op_vars in enumerate(modifier.prune_op_vars):
                        mask_sparsity = eval_tensor_sparsity(op_vars.mask)
                        masked_sparsity = eval_tensor_sparsity(op_vars.masked)

                        assert abs(mask_sparsity - masked_sparsity) < 1e-5

                        if epoch < modifier.start_epoch:
                            assert masked_sparsity < 1e-2
                            assert not update_ready_val
                        elif epoch >= modifier.end_epoch:
                            assert abs(masked_sparsity -
                                       modifier.final_sparsity) < 1e-2
                            assert not update_ready_val
                        else:
                            assert masked_sparsity >= last_sparsities[
                                index] - 1e-2
                            last_sparsities[index] = masked_sparsity

            manager.complete_graph()

            for op_vars in modifier.prune_op_vars:
                assert (abs(modifier.final_sparsity -
                            eval_tensor_sparsity(op_vars.op_input)) < 1e-2)
Example #25
def test_create_op_pruning_conv(sparsity_val: float,
                                mask_creator: PruningMaskCreator):
    group = "test-group"
    is_grouped_mask = isinstance(mask_creator, GroupedPruningMaskCreator)
    with tf_compat.Graph().as_default() as graph:
        inp = tf_compat.placeholder(tf_compat.float32, [None, 8, 8, 64])

        with tf_compat.name_scope("conv"):
            weights = tf_compat.Variable(tf_compat.random_normal(
                [3, 3, 64, 64]),
                                         name="weights")
            bias = tf_compat.Variable(tf_compat.random_normal([64]),
                                      name="bias")
            conv = tf_compat.nn.conv2d(inp,
                                       weights,
                                       strides=[1, 1, 1, 1],
                                       padding="SAME",
                                       name="conv")
            add = tf_compat.add(conv, bias, name="bias_add")
            relu = tf_compat.nn.relu(add, name="relu")

        sparsity = tf_compat.placeholder(dtype=tf_compat.float32,
                                         name="sparsity_placeholder")
        update_ready = tf_compat.placeholder(dtype=tf_compat.bool,
                                             name="update_ready")

        conv_op = graph.get_operation_by_name("conv/conv")
        conv_op_input = get_op_input_var(conv_op, VAR_INDEX_FROM_TRAINABLE)
        pruning_op_vars = create_op_pruning(
            conv_op,
            conv_op_input,
            sparsity,
            update_ready,
            True,
            None,
            group,
            mask_creator=mask_creator,
        )

        with tf_compat.Session() as sess:
            sess.run(tf_compat.global_variables_initializer())
            sess.run(
                pruning_op_vars.update,
                feed_dict={
                    sparsity: sparsity_val,
                    update_ready: False
                },
            )

            err_threshold = 1e-3 if not is_grouped_mask else 0.05

            mask_sparsity = eval_tensor_sparsity(pruning_op_vars.mask)
            weight_sparsity = eval_tensor_sparsity(pruning_op_vars.op_input)
            assert mask_sparsity < err_threshold
            assert abs(mask_sparsity - weight_sparsity) <= 1e-4

            masked_sparsity = eval_tensor_sparsity(pruning_op_vars.masked)
            assert masked_sparsity < err_threshold

            sess.run(
                pruning_op_vars.update,
                feed_dict={
                    sparsity: sparsity_val,
                    update_ready: True
                },
            )

            mask_sparsity = eval_tensor_sparsity(pruning_op_vars.mask)
            assert abs(mask_sparsity - sparsity_val) < err_threshold

            masked_sparsity = eval_tensor_sparsity(pruning_op_vars.masked)
            assert abs(masked_sparsity - sparsity_val) < err_threshold

            res = sess.run(relu,
                           feed_dict={inp: numpy.random.random((4, 8, 8, 64))})
            assert res.sum() > 0.0

            if is_grouped_mask:
                # Check that every value in the mask_creator grouping
                # is the same within the mask.  Assumes grouping applies
                # an absolute mean to each grouping
                grouped_mask = mask_creator.group_tensor(pruning_op_vars.mask)
                mask_vals_are_grouped = tf_compat.reduce_all(
                    tf_compat.logical_or(
                        tf_compat.equal(grouped_mask, 0.0),
                        tf_compat.equal(grouped_mask, 1.0),
                    ))
                assert sess.run(mask_vals_are_grouped)
Example #26
def test_lrs_with_manager(optim_lambda):
    manager = ScheduledModifierManager(modifiers=[
        SetLearningRateModifier(learning_rate=0.1, start_epoch=0),
        LearningRateModifier(
            lr_class="ExponentialLR",
            lr_kwargs={"gamma": 0.9},
            start_epoch=5,
            end_epoch=10,
            init_lr=0.01,
        ),
        LearningRateModifier(
            lr_class="MultiStepLR",
            lr_kwargs={
                "gamma": 0.95,
                "milestones": [15, 18]
            },
            start_epoch=12,
            end_epoch=20,
            init_lr=0.05,
        ),
    ])
    assert manager.max_epochs == 20
    assert manager.min_epochs == 0
    graph = mlp_graph_lambda()
    steps_per_epoch = 100

    with graph.as_default():
        global_step = tf_compat.train.get_or_create_global_step()

        # Further set up for loss, optimizer and training op
        x_batch = graph.get_tensor_by_name("inp:0")
        y_pred = graph.get_tensor_by_name("out:0")
        n_inputs = x_batch.shape[1]
        n_outputs = y_pred.shape[1]
        y_lab = tf_compat.placeholder(tf_compat.float32,
                                      shape=(None, n_outputs),
                                      name="y")
        mod_ops, mod_extras = manager.create_ops(steps_per_epoch,
                                                 global_step=global_step,
                                                 graph=graph)
        assert len(mod_ops) == 1
        assert len(mod_extras) == 2
        assert EXTRAS_KEY_LEARNING_RATE in mod_extras
        assert EXTRAS_KEY_SUMMARIES in mod_extras
        assert isinstance(mod_extras[EXTRAS_KEY_SUMMARIES], list)
        assert len(mod_extras[EXTRAS_KEY_SUMMARIES]) == 1
        learning_rate = mod_extras[EXTRAS_KEY_LEARNING_RATE]

        with tf_compat.name_scope("train"):
            optimizer = optim_lambda(learning_rate=learning_rate)
            loss = tf_compat.losses.mean_squared_error(y_lab, y_pred)
            training_op = optimizer.minimize(loss, global_step=global_step)

    np.random.seed(12)
    batch_size = 8
    batch_x = np.random.randn(batch_size, n_inputs)
    batch_lab = np.random.randn(batch_size, n_outputs)

    with tf_compat.Session(graph=graph) as sess:
        sess.run(tf_compat.global_variables_initializer())

        for epoch in range(manager.max_epochs + 5):
            # for now hardcoding the tests to get out the door
            if epoch < 5:
                expected = 0.1
            elif epoch < 10:
                expected = 0.01 * 0.9**(epoch - 5)
            elif epoch < 12:
                expected = 0.01 * 0.9**4
            elif epoch < 15:
                expected = 0.05
            elif epoch < 18:
                expected = 0.05 * 0.95
            else:
                expected = 0.05 * 0.95**2

            for step in range(steps_per_epoch):
                gs = sess.run(global_step)
                optim_lr = sess.run(_get_lr(optimizer))
                assert (abs(optim_lr - expected) <= EPSILON
                        ), "Failed at epoch:{} step:{} global_step:{}".format(
                            epoch, step, gs)
                sess.run(
                    training_op,
                    feed_dict={
                        x_batch: batch_x,
                        y_lab: batch_lab
                    },
                )