def test_assign_grouping_all_inputs_grouped(self):
        """All input slices share a group, so they are grouped with DTS.

        Each of the four slices of ``id1_op`` is mapped to ``id1_op_group``.
        The handler is expected to:
        - look up OpSlices in a fixed order;
        - group the four input slices together with ``dts_op_slice``;
        - pass ``id2_op`` to the manager for processing.
        """
        manager = self.mock_op_reg_manager
        input_slices = [
            self.id1_op_slice0,
            self.id1_op_slice1,
            self.id1_op_slice2,
            self.id1_op_slice3,
        ]

        # Register the four slices on the first identity op.  A copy is stored
        # so the list used in the assertions below stays independent of it.
        self.op_slice_dict[self.id1_op] = list(input_slices)

        # Every input slice belongs to the same group.
        self.op_group_dict = {
            op_slice: self.id1_op_group for op_slice in input_slices
        }

        # Run the handler under test.
        dts_handler = depth_to_space_op_handler.DepthToSpaceOpHandler()
        dts_handler.assign_grouping(self.dts_op, manager)

        # Expected OpSlice lookups, phase by phase.
        lookups_for_ops_to_process = [
            mock.call(self.id1_op),
            mock.call(self.id2_op),
        ]
        lookups_for_reslicing = [
            mock.call(self.id1_op),
            mock.call(self.dts_op),
            mock.call(self.id2_op),
        ]
        lookups_for_refresh = [
            mock.call(self.dts_op),
            mock.call(self.id1_op),
        ]
        manager.get_op_slices.assert_has_calls(
            lookups_for_ops_to_process
            + lookups_for_reslicing
            + lookups_for_refresh)

        # The DepthToSpace slice is grouped with the individual input slices.
        manager.group_op_slices.assert_called_once_with(
            input_slices + [self.dts_op_slice])

        # Only the second identity op is handed back for processing.
        manager.process_ops.assert_called_once_with([self.id2_op])

    def test_assign_grouping_no_neighbor_groups(self):
        """With no groups anywhere, the handler must not group any slices.

        The group mapping is empty. The handler is expected to:
        - look up the OpSlices of both identity ops;
        - never call ``group_op_slices``;
        - pass ``id1_op`` to the manager for processing.
        """
        manager = self.mock_op_reg_manager

        # Start from a state in which no OpSlice has a group.
        self.op_group_dict = {}

        # Run the handler under test.
        depth_to_space_op_handler.DepthToSpaceOpHandler().assign_grouping(
            self.dts_op, manager)

        # Both identity ops have their slices looked up, in this order.
        expected_lookups = [mock.call(op) for op in (self.id1_op, self.id2_op)]
        manager.get_op_slices.assert_has_calls(expected_lookups)

        # Nothing gets grouped when no neighbor carries a group.
        manager.group_op_slices.assert_not_called()

        # The first identity op is handed back for processing.
        manager.process_ops.assert_called_once_with([self.id1_op])