def test_create_mask_basic_indexer():
    """create_mask flags -1 keys as masked and non-negative keys as unmasked.

    A BasicIndexer holding ``-1`` must yield a mask of ``True``; one holding
    ``0`` must yield ``False``. Both cases index an array of shape ``(3,)``.
    """
    cases = [
        ((-1, ), True),   # -1 marks a missing value -> masked
        ((0, ), False),   # a real position -> not masked
    ]
    for key, expected in cases:
        mask = indexing.create_mask(indexing.BasicIndexer(key), (3, ))
        np.testing.assert_array_equal(expected, mask)
def get_indexers(shape, mode):
    """Build a randomly-populated explicit indexer of the requested kind.

    Parameters
    ----------
    shape : tuple of int
        Shape of the array the indexer will be applied to.
    mode : str
        Which indexer to build. One of:

        - ``"vectorized"``: VectorizedIndexer of integer arrays, each (3, 4).
        - ``"outer"``: OuterIndexer with one integer array of length
          ``s + 2`` per dimension of size ``s``.
        - ``"outer_scalar"`` / ``"outer_scalar2"``: OuterIndexer mixing an
          integer array, a scalar (``0`` or ``-2``) and a strided slice,
          truncated to ``len(shape)`` keys.
        - ``"outer1vec"``: OuterIndexer of ``slice(2, -3)`` on every axis
          except axis 1, which gets an integer array.
        - ``"basic"``, ``"basic1"``, ``"basic2"``, ``"basic3"``:
          BasicIndexer variants (scalar + slices, a lone scalar, scalars
          only, and negative/positive strided slices).

    Returns
    -------
    indexing.VectorizedIndexer or indexing.OuterIndexer or indexing.BasicIndexer

    Raises
    ------
    ValueError
        If ``mode`` is not one of the names above. Silently returning
        ``None`` would only surface later as a confusing failure in the
        caller.
    """
    if mode == "vectorized":
        indexed_shape = (3, 4)
        indexer = tuple(
            np.random.randint(0, s, size=indexed_shape) for s in shape)
        return indexing.VectorizedIndexer(indexer)
    elif mode == "outer":
        indexer = tuple(np.random.randint(0, s, s + 2) for s in shape)
        return indexing.OuterIndexer(indexer)
    elif mode == "outer_scalar":
        # Index values are drawn from [0, 3); presumably shape[0] >= 3.
        indexer = (np.random.randint(0, 3, 4), 0, slice(None, None, 2))
        return indexing.OuterIndexer(indexer[:len(shape)])
    elif mode == "outer_scalar2":
        indexer = (np.random.randint(0, 3, 4), -2, slice(None, None, 2))
        return indexing.OuterIndexer(indexer[:len(shape)])
    elif mode == "outer1vec":
        # Requires at least two dimensions (axis 1 is replaced below).
        # A repeated slice object is fine: slices are immutable and
        # list entries are reassigned, not mutated.
        indexer = [slice(2, -3)] * len(shape)
        indexer[1] = np.random.randint(0, shape[1], shape[1] + 2)
        return indexing.OuterIndexer(tuple(indexer))
    elif mode == "basic":  # scalar on axis 0, slices elsewhere
        indexer = [slice(2, -3)] * len(shape)
        indexer[0] = 3
        return indexing.BasicIndexer(tuple(indexer))
    elif mode == "basic1":  # a single scalar key
        return indexing.BasicIndexer((3, ))
    elif mode == "basic2":  # scalar keys only
        indexer = [0, 2, 4]
        return indexing.BasicIndexer(tuple(indexer[:len(shape)]))
    elif mode == "basic3":  # strided slices, including a negative step
        # Requires at least two dimensions (axes 0 and 1 are set below).
        indexer = [slice(None)] * len(shape)
        indexer[0] = slice(-2, 2, -2)
        indexer[1] = slice(1, -1, 2)
        return indexing.BasicIndexer(tuple(indexer[:len(shape)]))
    raise ValueError("unknown indexer mode: {!r}".format(mode))
def test_unwrap_explicit_indexer():
    """unwrap_explicit_indexer returns the raw tuple only for allowed types.

    Checks three cases:

    - A BasicIndexer that is in ``allow`` unwraps to its tuple.
    - A BasicIndexer that is not in ``allow`` raises NotImplementedError.
    - A plain tuple (not an explicit indexer) raises TypeError.
    """
    basic_key = indexing.BasicIndexer((1, 2))
    no_target = None

    result = indexing.unwrap_explicit_indexer(
        basic_key, no_target, allow=indexing.BasicIndexer)
    assert result == (1, 2)

    # A BasicIndexer is rejected when only OuterIndexer is permitted.
    with raises_regex(NotImplementedError, 'Load your data'):
        indexing.unwrap_explicit_indexer(
            basic_key, no_target, allow=indexing.OuterIndexer)

    # A bare tuple is not an explicit indexer at all.
    raw_tuple = basic_key.tuple
    with raises_regex(TypeError, 'unexpected key type'):
        indexing.unwrap_explicit_indexer(
            raw_tuple, no_target, allow=indexing.OuterIndexer)