def getaxes_broadcast(obj, indices):
    """ Compute the axes produced by numpy-style ("broadcast") indexing.

    Integer and array indices are broadcast against each other, as numpy
    does for advanced indexing. When at most one array-like index remains
    after broadcasting, every dimension with a non-scalar index is simply
    indexed on its own. Otherwise all array-indexed dimensions collapse
    into a single new axis. That axis is placed where the first of them
    was if they are contiguous, or in front of all other axes if they are
    not (numpy's rule).

    Parameters
    ----------
    obj : dimarray-like object exposing an ``axes`` sequence whose items
        support indexing and carry ``values`` and ``name``
    indices : sequence with one index (slice, integer or array) per axis
        of ``obj``

    Returns
    -------
    newaxes : list of axes (at most one array index after broadcasting),
        or ``Axes`` instance (several array indices)

    Examples
    --------
    >>> import dimarray as da
    >>> a = da.zeros(shape=(3,4,5,6))
    >>> a.take((slice(None),[0, 1],slice(None),2), broadcast=True).shape
    (2, 3, 5)
    >>> a.take((slice(None),[0, 1],2,slice(None)), broadcast=True).shape
    (3, 2, 6)
    """
    from dimarray import Axis, Axes

    # Broadcast integers and arrays against each other. Integer indices
    # presumably become arrays here whenever an array index is present
    # (see ``broadcast_indices``, defined elsewhere).
    bindices = broadcast_indices(indices)

    # Per dimension: is the broadcast index array-like?
    advanced = np.array([np.iterable(ix) for ix in bindices])

    # Simple case: nothing to merge, index each non-scalar dimension.
    # (Original and broadcast indices are interchangeable here.)
    if advanced.sum() <= 1:
        return [obj.axes[dim][ix]
                for dim, ix in enumerate(indices)
                if not np.isscalar(ix)]

    # Dimensions that carried an array index *before* broadcasting.
    orig_positions = np.where(np.array([np.iterable(ix) for ix in indices]))[0]

    # Placement of the merged axis follows numpy, e.g. for shape (3,4,5,6):
    #   obj[:,[1,2],:,0].shape ==> (2, 3, 5)   non-contiguous: goes first
    #   obj[:,[1,2],0,:].shape ==> (3, 2, 6)   contiguous: stays in place
    # Integers count as array indices here, since they were broadcast.
    positions = np.where(advanced)[0]
    contiguous = not np.any(np.diff(positions) > 1)
    insert_at = positions[0] if contiguous else 0

    # Values and name of the merged axis.
    if len(orig_positions) == 1:
        # A single original array: reuse that axis' values and name.
        dim = orig_positions[0]
        merged_values = obj.axes[dim].values[indices[dim]]
        merged_name = obj.axes[dim].name
    else:
        # Several original arrays: one tuple per broadcast element.
        merged_values = list(zip(*[obj.axes[dim].values[bindices[dim]]
                                   for dim in orig_positions]))
        merged_name = ",".join([obj.axes[dim].name for dim in orig_positions])

    merged_axis = Axis(merged_values, merged_name)

    # Keep the dimensions that were not array-indexed, in order.
    # Then slot the merged axis in at its numpy position.
    newaxes = Axes()
    for dim, ax in enumerate(obj.axes):
        if not advanced[dim]:
            newaxes.append(ax[bindices[dim]])
    newaxes.insert(insert_at, merged_axis)

    return newaxes