def test_assign_grouping_all_inputs_grouped(self):
    """Checks DepthToSpace grouping when every input channel is already grouped.

    The input op id1_op is mapped to four OpSlices that all share
    id1_op_group. The test expects the handler to:
    - place those four slices and the DepthToSpace slice into one group;
    - hand only id2_op back to the manager for processing.
    """
    manager = self.mock_op_reg_manager
    input_slices = (
        self.id1_op_slice0,
        self.id1_op_slice1,
        self.id1_op_slice2,
        self.id1_op_slice3,
    )

    # The input op is split into four slices, and each has a group.
    self.op_slice_dict[self.id1_op] = list(input_slices)
    self.op_group_dict = {
        op_slice: self.id1_op_group for op_slice in input_slices
    }

    # Run the handler under test against the mocked manager.
    dts_handler = depth_to_space_op_handler.DepthToSpaceOpHandler()
    dts_handler.assign_grouping(self.dts_op, manager)

    # Expected OpSlice lookups, in order, split by handler phase.
    neighbor_scan = [self.id1_op, self.id2_op]  # Checking for ops to process.
    reslicing = [self.id1_op, self.dts_op, self.id2_op]  # Reslicing.
    refreshing = [self.dts_op, self.id1_op]  # Refreshing slice data.
    manager.get_op_slices.assert_has_calls(
        [mock.call(op) for op in neighbor_scan + reslicing + refreshing])

    # All four input channels are grouped with the single DepthToSpace slice.
    manager.group_op_slices.assert_called_once_with(
        list(input_slices) + [self.dts_op_slice])

    # id2_op is sent back for processing. Its slices are not in op_group_dict,
    # which is presumably why; confirm against setUp.
    manager.process_ops.assert_called_once_with([self.id2_op])
def test_assign_grouping_no_neighbor_groups(self):
    """Checks DepthToSpace grouping when no neighboring op has a group.

    With an empty op_group_dict, the test expects the handler to:
    - look up slices for id1_op and then id2_op;
    - form no group at all;
    - ask the manager to process id1_op.
    """
    manager = self.mock_op_reg_manager

    # Nothing has been grouped yet.
    self.op_group_dict = {}

    # Run the handler under test against the mocked manager.
    dts_handler = depth_to_space_op_handler.DepthToSpaceOpHandler()
    dts_handler.assign_grouping(self.dts_op, manager)

    # The neighboring ops' slices are looked up in this order.
    expected_lookups = [mock.call(op) for op in (self.id1_op, self.id2_op)]
    manager.get_op_slices.assert_has_calls(expected_lookups)

    # Without any existing groups, no grouping is performed.
    manager.group_op_slices.assert_not_called()

    # The ungrouped input op is handed back to the manager.
    manager.process_ops.assert_called_once_with([self.id1_op])