Ejemplo n.º 1
0
class AndOrRegion(object):
  """ One region (layer) of AND/OR nodes laid out over a 2-D image grid.

  Every pixel (x, y) of the image may hold a stack of nodes; a node is
  addressed by idx = (x, y, depth). Nodes are first collected with
  create_node() and then flattened into linear, C-accessible arrays by
  prepare_for_inference(). After that, node (x, y, depth) lives at linear
  index node_start_pos[x, y] + depth.

  Relies on names defined elsewhere in the project: DynamicArray,
  PriorityQueue and the compiled c_methods module.
  """
  # Node type flags as stored in node_types (a 'bool' array).
  AND_NODE = True
  OR_NODE = False

  def __init__(self, image_shape, estimated_num_nodes=1000):
    """ Allocates the per-pixel and per-node storage.

    image_shape -- shape of the pixel grid; the rest of the class indexes it
                   as two-dimensional (x, y).
    estimated_num_nodes -- initial capacity hint passed to each DynamicArray.
    """
    self.image_shape = image_shape
    self.image_stride = self.create_stride(image_shape)

    # Per-pixel bookkeeping: how many nodes sit on each pixel and the linear
    # index of the first of them (filled in by prepare_for_inference).
    self.node_counts = numpy.zeros(image_shape, 'int')
    self.node_start_pos = numpy.zeros(image_shape, 'int')

    # Per-node storage, addressed by linear node index.
    self.node_values = DynamicArray((estimated_num_nodes,), 'float32')
    self.node_types = DynamicArray((estimated_num_nodes,), 'bool')
    self.node_pos = DynamicArray((estimated_num_nodes,3), 'int')

    # Connections of all nodes stored back to back; node i owns the rows
    # node_cxns[cxn_start_pos[i] : cxn_start_pos[i+1]].
    # NOTE(review): each row presumably holds the (x, y, depth) of a node in
    # the child region -- confirm against c_methods.do_inference.
    self.cxn_start_pos = DynamicArray((estimated_num_nodes,), 'int')
    self.node_cxns = DynamicArray((estimated_num_nodes,3), 'int')

    self.num_nodes = 0
    self.nodes = PriorityQueue()

    # No child until set_child() is called; do_inference() checks for this.
    self.child_region = None

    # c buffers: one-element array the C helpers write a count into.
    self.num_vals_buf = numpy.empty((1,), 'int')

  def set_child(self, region):
    """ Sets the region whose node values feed this region's inference. """
    self.child_region = region

  def prepare_for_inference(self):
    """ Creates all C-accessible properties. This assumes that all nodes have
    been created already.

    Drains a copy of the node queue (self.nodes itself is left intact) and
    writes each node's type, value, position and connections into the linear
    arrays, then derives node_start_pos from node_counts.
    """
    lin_idx = 0
    cxn_pos = 0
    self.num_nodes = len(self.nodes)

    # The first node's connections start at offset 0. Written explicitly so
    # get_node_cxns()/get_num_cxns() do not depend on DynamicArray's initial
    # contents (the loop below only ever writes entries 1..num_nodes).
    self.cxn_start_pos[0] = 0

    # NOTE(review): the linear layout is only consistent with node_start_pos
    # if the queue pops nodes in ascending flat pixel index and, within one
    # pixel, in creation order -- confirm against PriorityQueue.
    pending = self.nodes.copy()
    while not pending.is_empty():
      position, node_type, cxns, value = pending.pop()
      self.node_types[lin_idx] = node_type
      self.node_values[lin_idx] = value
      self.node_pos[lin_idx] = position

      num_cxns = 0 if cxns is None else len(cxns)
      if num_cxns:
        # Skipped for empty connection lists: assigning an empty sequence to
        # a zero-length (0, 3) slice may fail to broadcast.
        self.node_cxns[cxn_pos : cxn_pos+num_cxns] = cxns

      lin_idx += 1
      cxn_pos += num_cxns
      # Doubles as the end offset of this node and the start of the next.
      self.cxn_start_pos[lin_idx] = cxn_pos

    # Define all node start positions: an exclusive prefix sum of the
    # per-pixel counts in flattened order. Writing through .flat keeps the
    # array's shape, and element 0 stays at its initial value of 0.
    self.node_start_pos.flat[1:] = numpy.cumsum(self.node_counts)[:-1]

  def create_node(self, position, type=AND_NODE, cxns=None, value=0):
    """ Queues a new node on the pixel at `position` and returns its full
    index (x, y, depth).

    position -- pixel coordinates; any sequence is accepted and converted to
                a tuple.
    type     -- AND_NODE or OR_NODE (the name shadows the builtin but is
                kept because callers may pass it by keyword).
    cxns     -- sequence of connection rows, or None for no connections.
    value    -- initial node value.

    The node only becomes visible in the linear arrays after
    prepare_for_inference() has been called.
    """
    # A tuple indexes exactly one element of node_counts; a list or array
    # would trigger fancy indexing, and could not be extended with
    # `+ (depth,)` below.
    position = tuple(position)
    flat_idx = numpy.dot(position, self.image_stride)
    depth = self.node_counts[position]
    self.node_counts[position] = depth + 1
    full_pos = position + (depth,)
    # The flat pixel index is the queue priority.
    self.nodes.push((full_pos, type, cxns, value), flat_idx)
    return full_pos

  def create_stride(self, shape):
    """ Create the stride vector for this shape, i.e. the element strides of
    a row-major array: (4, 5) -> [5, 1], (2, 3, 4) -> [12, 4, 1]. """
    # Products of the trailing dimensions, largest first ...
    s_row = numpy.cumprod(shape[::-1])[::-1]
    # ... shifted by one so that the last axis gets stride 1.
    s_row = numpy.concatenate((s_row[1:], [1]))
    return s_row

  def _get_lin_idx(self, idx):
    """ Maps idx = (x, y, depth) to the node's linear index. """
    # tuple() forces element indexing even when idx is a list or a numpy row
    # (e.g. an entry returned by get_active_nodes()).
    return self.node_start_pos[tuple(idx[:2])] + idx[2]

  def _check_window(self, pos, shape):
    """ Raises ValueError unless the window with origin `pos` and size
    `shape` lies completely inside the image. Uses explicit raises rather
    than assert so the check survives `python -O`; the window is handed to C
    code afterwards. """
    for axis in (0, 1):
      start, extent = pos[axis], shape[axis]
      if start < 0 or extent < 0:
        raise ValueError("window position and shape must be non-negative, "
                         "got pos=%r, shape=%r"
                         % (tuple(pos[:2]), tuple(shape[:2])))
      if start + extent > self.image_shape[axis]:
        raise ValueError("window pos=%r, shape=%r exceeds image shape %r"
                         % (tuple(pos[:2]), tuple(shape[:2]),
                            tuple(self.image_shape)))

  def do_inference(self):
    """ Computes the values of all nodes given the input activations.

    The computation itself is done by c_methods.do_inference, which receives
    this region's node arrays together with the child region's
    node_start_pos and node_values. Raises AttributeError if no child region
    has been set.
    """
    child = getattr(self, 'child_region', None)
    if child is None:
      raise AttributeError("child_region is not set; call set_child() "
                           "before do_inference()")

    args = [ self.num_nodes,
             self.node_values,
             self.node_types,
             self.cxn_start_pos,
             self.node_cxns,

             child.node_start_pos,
             child.node_values ]

    c_methods.do_inference(*DynamicArray.csafe(args))

  def get_node_value(self, idx):
    """ Returns the value of the node at idx = (x, y, depth). """
    return self.node_values[self._get_lin_idx(idx)]

  def set_node_value(self, idx, value):
    """ Sets the value of the node at idx = (x, y, depth). """
    self.node_values[self._get_lin_idx(idx)] = value

  def get_node_cxns(self, idx):
    """ Gets all incoming connections to the node at idx = (x, y, depth) """
    lin_idx = self._get_lin_idx(idx)
    start = self.cxn_start_pos[lin_idx]
    end = self.cxn_start_pos[lin_idx+1]
    return self.node_cxns[start : end]

  def get_active_nodes(self):
    """ Returns a list of the indices of all active nodes, i.e. the node_pos
    entry of every node whose value is nonzero. """
    nzn = self.node_values.nonzero()[0]
    return [self.node_pos[i] for i in nzn]

  def get_num_cxns(self):
    """ Returns the total number of connections across all nodes (the end
    offset of the last node, as of the last prepare_for_inference()). """
    return self.cxn_start_pos[self.num_nodes]

  def get_window_values(self, pos, shape, values=None):
    """ Returns a numpy array containing all values within the window at the
    given position and with width, height = shape. If values is not None,
    then the values given will be used instead of self.node_values.

    Raises ValueError if the window does not lie inside the image.
    """
    self._check_window(pos, shape)
    if values is None:
      values = self.node_values

    # First compute the number of elements in the array; the C helper writes
    # the count into num_vals_buf[0]. Both array arguments are plain numpy
    # arrays here.
    c_methods.get_num_window_values(pos[0], pos[1], shape[0], shape[1],
                                    self.node_counts, self.num_vals_buf)

    # Create the array, and assign all values to it
    window_values = numpy.empty((self.num_vals_buf[0],), 'float32')
    args = [pos[0], pos[1], shape[0], shape[1],
            self.node_start_pos,
            self.node_counts,
            values,
            window_values]
    c_methods.get_window_values(*DynamicArray.csafe(args))

    return window_values

  def get_window_nonzeros(self, pos, shape):
    """ Returns an (n, 3) int array describing the nonzero nodes within the
    given window, as filled in by c_methods.get_window_nonzeros.

    NOTE(review): each row is presumably the (x, y, depth) of a node with a
    nonzero value -- confirm against the C implementation.

    Raises ValueError if the window does not lie inside the image.
    """
    self._check_window(pos, shape)
    values = self.node_values

    # First compute the number of elements in the array; the C helper writes
    # the count into num_vals_buf[0].
    args = [pos[0], pos[1], shape[0], shape[1],
            self.node_start_pos,
            self.node_counts,
            values,
            self.num_vals_buf]
    c_methods.get_num_window_nonzeros(*DynamicArray.csafe(args))

    # Create the array, and assign all values to it
    nonzeros = numpy.empty((self.num_vals_buf[0],3), 'int')
    args = [pos[0], pos[1], shape[0], shape[1],
            self.node_start_pos,
            self.node_counts,
            values,
            nonzeros]
    c_methods.get_window_nonzeros(*DynamicArray.csafe(args))

    return nonzeros

  def __len__(self):
    """ Reports the number of nodes in this region, as counted by the last
    call to prepare_for_inference() (0 before that). """
    return self.num_nodes