def test_dead_layer_remove(self):
    """remove_disconnected_layers keeps only layers that reach "out".

    The only live path is data -> relu1 -> relu2 -> "out". Every other layer
    is dead and should be pruned, shrinking the network from 9 layers to 2:
    - the three constant loaders;
    - the split fed by const2's blob "c2";
    - the squeeze and the relu4 -> relu5 chain hanging off the split.

    The split reads blob "c2" (the output of layer "const2") so that const2
    has a consumer that is itself dead. This exercises transitive removal
    rather than only removing unconsumed leaves.
    """
    input_features = [("data", datatypes.Array(*(3, 4)))]
    output_features = [("out", None)]
    builder = neural_network.NeuralNetworkBuilder(
        input_features, output_features, disable_rank5_shape_mapping=True
    )
    builder.add_activation("relu1", "RELU", "data", "relu1")
    builder.add_load_constant_nd(
        "const1", "c1", constant_value=np.ones((5,)), shape=(5,)
    )
    builder.add_load_constant_nd(
        "const2", "c2", constant_value=np.ones((5,)), shape=(5,)
    )
    # Consume the blob "c2" produced by const2. The previous input name
    # "const2" is a layer name, and no layer writes a blob called "const2".
    builder.add_split_nd(
        "splitnd1", "c2", ["s1", "s2", "s3"], axis=0, num_splits=3
    )
    builder.add_squeeze("squeeze", "s1", "squeeze_out")
    builder.add_activation("relu4", "RELU", "s2", "relu4")
    builder.add_activation("relu5", "RELU", "relu4", "relu5")
    builder.add_load_constant_nd(
        "const3", "c3", constant_value=np.ones((5,)), shape=(5,)
    )
    builder.add_activation("relu2", "RELU", "relu1", "out")

    spec = builder.spec
    np.testing.assert_equal(9, len(spec.neuralNetwork.layers))
    # The pass edits `spec` in place; only relu1 and relu2 should survive.
    remove_disconnected_layers(spec)
    np.testing.assert_equal(2, len(spec.neuralNetwork.layers))
def test_dead_layer_partial_branch(self):
    """Dead layers inside a branch body are pruned; live ones are kept.

    The if-branch holds only layers that feed "relu2_out" and must be left
    unchanged. The else-branch also holds linear_red_1 -> linear_red_2,
    whose output "linear_red2_out" is never consumed. The pass should
    therefore shrink the else-branch from 4 layers to 2.

    The structural assertions run on every platform. The before/after
    prediction comparison only runs on macOS, where predictions can be
    obtained.
    """
    convergence_tolerance = 1e-8
    input_features = [("input", datatypes.Array(*(2,)))]
    output_features = [("out", None)]
    builder = neural_network.NeuralNetworkBuilder(
        input_features, output_features, disable_rank5_shape_mapping=True
    )
    # Branch condition: compare "input" against convergence_tolerance.
    builder.add_less_than("cond", ["input"], "cond", alpha=convergence_tolerance)
    branch_layer = builder.add_branch("branch_layer", "cond")

    # if-branch: both layers contribute to "relu2_out"; nothing is dead here.
    builder_ifbranch = neural_network.NeuralNetworkBuilder(
        nn_spec=branch_layer.branch.ifBranch
    )
    builder_ifbranch.add_activation("relu1", "RELU", "input", "relu1_out")
    builder_ifbranch.add_activation("relu2_out", "RELU", "relu1_out", "relu2_out")

    # else-branch: linear1 -> linear2 produces "relu2_out" (live), while
    # linear_red_1 -> linear_red_2 produces "linear_red2_out", which no
    # layer reads (dead).
    builder_elsebranch = neural_network.NeuralNetworkBuilder(
        nn_spec=branch_layer.branch.elseBranch
    )
    builder_elsebranch.add_activation("linear1", "LINEAR", "input", "linear1_out")
    builder_elsebranch.add_activation(
        "linear_red_1", "LINEAR", "input", "linear_red1_out"
    )
    builder_elsebranch.add_activation(
        "linear_red_2", "LINEAR", "linear_red1_out", "linear_red2_out"
    )
    builder_elsebranch.add_activation(
        "linear2", "LINEAR", "linear1_out", "relu2_out"
    )

    builder.add_squeeze("out", "relu2_out", "out", squeeze_all=True)

    data = np.random.rand(2,)
    data_dict = {"input": data}

    mlmodel = MLModel(builder.spec, compute_units=ComputeUnit.CPU_ONLY)
    # Can not get predictions unless on macOS; the structural checks below
    # still run everywhere.
    before_pass_out = mlmodel.predict(data_dict)["out"] if _IS_MACOS else None

    if DEBUG:
        print("\n mlmodel description before remove disconnected layers pass: \n")
        print_network_spec(builder.spec, style="coding")

    # Snapshot the spec before the pass edits builder.spec in place.
    # deepcopy makes it explicit that the snapshot shares no nested messages
    # with the spec being modified.
    old_spec = copy.deepcopy(builder.spec)
    remove_disconnected_layers(builder.spec)

    if DEBUG:
        print("\n mlmodel description after remove disconnected layers pass: \n")
        print_network_spec(builder.spec, style="coding")

    if _IS_MACOS:
        # Pruning dead layers must not change the model's output.
        mlmodel = MLModel(builder.spec, compute_units=ComputeUnit.CPU_ONLY)
        after_pass_out = mlmodel.predict(data_dict)["out"]
        np.testing.assert_almost_equal(before_pass_out, after_pass_out, decimal=2)

    # layers[0] is "cond", layers[1] is the branch layer, layers[2] is "out".
    np.testing.assert_equal(
        len(old_spec.neuralNetwork.layers[1].branch.ifBranch.layers),
        len(builder.spec.neuralNetwork.layers[1].branch.ifBranch.layers),
    )
    np.testing.assert_equal(
        len(builder.spec.neuralNetwork.layers[1].branch.elseBranch.layers), 2
    )
def test_load_constant_remove():
    """Unused load-constant layers are stripped by remove_disconnected_layers.

    The network is data -> relu1 -> relu2 -> "out", interleaved with three
    load_constant_nd layers whose blobs ("c1", "c2", "c3") are never read.
    Before the pass there are 5 layers; afterwards only relu1 and relu2
    remain.

    NOTE(review): unlike the sibling tests, this takes no ``self``. If it
    lives inside a TestCase class it needs one -- confirm its placement.
    """
    net_builder = neural_network.NeuralNetworkBuilder(
        [("data", datatypes.Array(3, 4))],
        [("out", None)],
        disable_rank5_shape_mapping=True,
    )

    def add_unused_ones(layer_name, blob_name):
        # Length-5 vector of ones that nothing downstream consumes.
        net_builder.add_load_constant_nd(
            layer_name, blob_name, constant_value=np.ones((5,)), shape=(5,)
        )

    net_builder.add_activation("relu1", "RELU", "data", "relu1")
    add_unused_ones("const1", "c1")
    net_builder.add_activation("relu2", "RELU", "relu1", "out")
    for const_layer, const_blob in (("const2", "c2"), ("const3", "c3")):
        add_unused_ones(const_layer, const_blob)

    model_spec = net_builder.spec
    np.testing.assert_equal(5, len(model_spec.neuralNetwork.layers))
    remove_disconnected_layers(model_spec)
    np.testing.assert_equal(2, len(model_spec.neuralNetwork.layers))