def _compute_moments(self, u_X):
    """
    Return the parent's moments with their plate axes tiled.

    Each moment array is tiled by ``tiles`` (taken from the enclosing
    scope), extended with ones for the variable dimensions of that moment.
    Broadcasting is exploited in two ways: tile factors for leading axes
    that the moment array does not have are dropped, and axes on which the
    moment has unit length are not tiled. Zero-dimensional moments are
    passed through unchanged.
    """
    tiled = []
    for ind, moment in enumerate(u_X):
        if np.ndim(moment) == 0:
            # A scalar moment has no axes to tile.
            tiled.append(moment)
            continue
        # Tile factors for the plates, then ones for the variable dimensions.
        reps = tiles + (1,) * len(self.dims[ind])
        # Leading tile factors beyond the moment's rank are covered by
        # broadcasting, so keep only the trailing ones.
        keep = min(len(reps), np.ndim(moment))
        reps = reps[-keep:]
        reps, moment_shape = utils.make_equal_length(reps, np.shape(moment))
        # A unit-length axis broadcasts, so it does not need to be tiled.
        reps = [1 if size <= 1 else rep
                for rep, size in zip(reps, moment_shape)]
        tiled.append(np.tile(moment, reps))
    return tiled
def _compute_mask_to_parent(self, index, mask):
    """
    Compute the mask for the parent by folding the tiled copies together.

    The mask is reshaped so that every tiled plate axis is split into a
    (tile, parent plate) pair of axes. The tile axes are then reduced with
    logical OR, so an entry of the parent's mask is True if the mask is
    True for any of its tiled copies.

    NOTE(review): an identical ``_compute_mask_to_parent`` definition also
    appears later in this file; if both live in the same class body, the
    later one shadows this one. Confirm and delete the redundant copy.

    Parameters
    ----------
    index : int
        Index of the parent the mask is computed for.
    mask : array_like
        Mask over this node's plates (presumably boolean); axes may be
        broadcast, i.e. unit-length or missing leading axes.

    Returns
    -------
    ndarray
        OR-reduced mask, squeezed with ``utils.squeeze_to_dim`` to the
        number of plates of the parent.
    """
    # Make the parent plates, tile factors and mask shape equal length.
    plates = self._plates_to_parent(index)
    shape_m = np.shape(mask)
    (plates, tiles_m, shape_m) = utils.make_equal_length(plates, tiles, shape_m)
    # Broadcasting: where the mask has unit length, it is the same for every
    # tile and every parent plate, so collapse both factors to 1.
    plates = list(plates)
    tiles_m = list(tiles_m)
    for j in range(len(plates)):
        if shape_m[j] == 1:
            plates[j] = 1
            tiles_m[j] = 1
    # Interleave as (t0, p0, t1, p1, ...). Unlike functools.reduce without
    # an initializer, this also works when there are no axes at all.
    shape_split = tuple(size for pair in zip(tiles_m, plates) for size in pair)
    mask = np.reshape(mask, shape_split)
    # OR over the tile axes (even positions), keeping the parent plate axes.
    axes = tuple(range(0, len(shape_split), 2))
    mask = np.any(mask, axis=axes)
    # Remove extra leading axes
    ndim_parent = len(self.parents[index].plates)
    mask = utils.squeeze_to_dim(mask, ndim_parent)
    return mask
def _compute_message_to_parent(self, index, m, u_X):
    """
    Compute the message to the parent by summing over the tiled copies.

    Each message array is reshaped so that every axis is split into a
    (tile, parent size) pair of axes, and the tile axes are then summed
    out. If the message has unit length on an axis (it is broadcast), all
    tiled copies along that axis are identical. In that case the sum is
    replaced by multiplying with the number of tiles, the plate multiplier
    ``r``.

    Parameters
    ----------
    index : int
        Index of the parent the message is sent to.
    m : sequence of array_like
        One message array per moment; axes may be broadcast, i.e.
        unit-length or missing leading axes.
    u_X : object
        Presumably the parent's moments; not used by this method.

    Returns
    -------
    list
        Messages reduced onto the parent's shape. Each is squeezed with
        ``utils.squeeze_to_dim`` to ``len(parent.get_shape(ind))``
        dimensions.
    """
    m = list(m)
    for ind in range(len(m)):
        # Target shape in the parent: its plates followed by the variable
        # dimensions of this moment.
        shape_ind = self._plates_to_parent(index) + self.dims[ind]
        # Tile factors, with ones for the variable dimensions.
        tiles_ind = tiles + (1,) * len(self.dims[ind])
        # Make the shape tuples equal length.
        shape_m = np.shape(m[ind])
        (tiles_ind, shape, shape_m) = utils.make_equal_length(tiles_ind, shape_ind, shape_m)
        # Broadcasting: where the message has unit length, do not split the
        # axis. Account for the identical tiled copies through the plate
        # multiplier r instead.
        r = 1
        shape = list(shape)
        tiles_ind = list(tiles_ind)
        for j in range(len(shape)):
            if shape_m[j] == 1:
                r *= tiles_ind[j]
                shape[j] = 1
                tiles_ind[j] = 1
        # Interleave as (t0, s0, t1, s1, ...). Unlike functools.reduce
        # without an initializer, this also works when there are no axes.
        shape_split = tuple(size for pair in zip(tiles_ind, shape) for size in pair)
        m[ind] = np.reshape(m[ind], shape_split)
        # Sum over the tile axes (even positions), keeping the parent axes.
        axes = tuple(range(0, len(shape_split), 2))
        m[ind] = r * np.sum(m[ind], axis=axes)
        # Remove extra leading axes
        ndim_parent = len(self.parents[index].get_shape(ind))
        m[ind] = utils.squeeze_to_dim(m[ind], ndim_parent)
    return m
def _compute_mask_to_parent(self, index, mask):
    """
    Compute the mask for the parent by folding the tiled copies together.

    The mask is reshaped so that every tiled plate axis is split into a
    (tile, parent plate) pair of axes. The tile axes are then reduced with
    logical OR, so an entry of the parent's mask is True if the mask is
    True for any of its tiled copies.

    NOTE(review): an identical ``_compute_mask_to_parent`` definition also
    appears earlier in this file; if both live in the same class body, this
    later one is the one that takes effect. Confirm and delete the
    redundant copy.

    Parameters
    ----------
    index : int
        Index of the parent the mask is computed for.
    mask : array_like
        Mask over this node's plates (presumably boolean); axes may be
        broadcast, i.e. unit-length or missing leading axes.

    Returns
    -------
    ndarray
        OR-reduced mask, squeezed with ``utils.squeeze_to_dim`` to the
        number of plates of the parent.
    """
    # Make the parent plates, tile factors and mask shape equal length.
    plates = self._plates_to_parent(index)
    shape_m = np.shape(mask)
    (plates, tiles_m, shape_m) = utils.make_equal_length(plates, tiles, shape_m)
    # Broadcasting: where the mask has unit length, it is the same for every
    # tile and every parent plate, so collapse both factors to 1.
    plates = list(plates)
    tiles_m = list(tiles_m)
    for j in range(len(plates)):
        if shape_m[j] == 1:
            plates[j] = 1
            tiles_m[j] = 1
    # Interleave as (t0, p0, t1, p1, ...). Unlike functools.reduce without
    # an initializer, this also works when there are no axes at all.
    shape_split = tuple(size for pair in zip(tiles_m, plates) for size in pair)
    mask = np.reshape(mask, shape_split)
    # OR over the tile axes (even positions), keeping the parent plate axes.
    axes = tuple(range(0, len(shape_split), 2))
    mask = np.any(mask, axis=axes)
    # Remove extra leading axes
    ndim_parent = len(self.parents[index].plates)
    mask = utils.squeeze_to_dim(mask, ndim_parent)
    return mask