Example #1
0
class SequentialSupportGraph(object):
  """An analog of Keras Sequential model for test/support models.

  Maintains two parallel tensors, ``test`` and ``support``. Each starts as the
  atom-feature placeholder of its own GraphTopology. Layers are applied to one
  or both tensors and recorded, in call order, in ``self.layers``.

  Graph layers are layers whose class is named GraphConv, GraphGather or
  GraphPool. They are called with a list made of the current tensor followed
  by the corresponding topology's ``topology`` entries. Every other layer is
  called with the current tensor alone.
  """

  # Class names of layers that consume graph topology in addition to x.
  # Matching is by exact class name (not isinstance), so a subclass of one of
  # these layers is treated as an ordinary layer.
  _GRAPH_LAYER_NAMES = frozenset(['GraphConv', 'GraphGather', 'GraphPool'])

  def __init__(self, n_feat):
    """
    Parameters
    ----------
    n_feat: int
      Number of atomic features.
    """
    # Create graph topology and x
    self.test_graph_topology = GraphTopology(n_feat, name='test')
    self.support_graph_topology = GraphTopology(n_feat, name='support')
    self.test = self.test_graph_topology.get_atom_features_placeholder()
    self.support = self.support_graph_topology.get_atom_features_placeholder()

    # Every layer passed to add/add_test/add_support/join, in call order.
    self.layers = []
    # True until a GraphGather layer has been added through `add`. After that,
    # graph layers are rejected by add/add_test/add_support.
    self.bool_pre_gather = True

  def _is_graph_layer(self, layer):
    """Return True if `layer` must also receive the graph topology tensors."""
    return type(layer).__name__ in self._GRAPH_LAYER_NAMES

  def _check_can_add(self, layer):
    """Reject graph layers once a GraphGather has been added via `add`.

    Raises
    ------
    AssertionError
      If `layer` is a graph layer and `bool_pre_gather` is False. The error is
      raised explicitly, not with an `assert` statement, so the check still
      runs under ``python -O``. The exception type stays AssertionError for
      callers that already catch it.
    """
    if self._is_graph_layer(layer) and not self.bool_pre_gather:
      raise AssertionError("Cannot apply graphical layers after gather.")

  def _call_layer(self, layer, x, graph_topology):
    """Apply `layer` to `x`, appending the topology tensors for graph layers."""
    if self._is_graph_layer(layer):
      return layer([x] + graph_topology.topology)
    return layer(x)

  def add(self, layer):
    """Adds a layer to both test/support stacks.

    Note that the layer transformation is performed independently on the
    test/support tensors. Adding a GraphGather layer here sets
    `bool_pre_gather` to False. After that, further graph layers are rejected.

    Raises
    ------
    AssertionError
      If `layer` is a graph layer added after a GraphGather. In that case the
      layer is not recorded in `self.layers`.
    """
    # Validate before mutating, so a rejected layer leaves no trace.
    self._check_can_add(layer)
    self.layers.append(layer)

    # Update new value of x
    self.test = self._call_layer(layer, self.test, self.test_graph_topology)
    self.support = self._call_layer(layer, self.support,
                                    self.support_graph_topology)

    if type(layer).__name__ == 'GraphGather':
      self.bool_pre_gather = False  # Set flag to stop adding topology

  def add_test(self, layer):
    """Adds a layer to the test stack only.

    Graph layers are rejected (AssertionError) once a GraphGather has been
    added via `add`. Adding a GraphGather here does not change
    `bool_pre_gather`, because that flag is shared by both stacks.
    """
    self._check_can_add(layer)
    self.layers.append(layer)
    self.test = self._call_layer(layer, self.test, self.test_graph_topology)

  def add_support(self, layer):
    """Adds a layer to the support stack only.

    Graph layers are rejected (AssertionError) once a GraphGather has been
    added via `add`. Adding a GraphGather here does not change
    `bool_pre_gather`, because that flag is shared by both stacks.
    """
    self._check_can_add(layer)
    self.layers.append(layer)
    self.support = self._call_layer(layer, self.support,
                                    self.support_graph_topology)

  def join(self, layer):
    """Joins test and support to a two input two output layer.

    `layer` is called with ``[test, support]``. Its result is unpacked into
    the new (test, support) pair.
    """
    self.layers.append(layer)
    self.test, self.support = layer([self.test, self.support])

  def get_test_output(self):
    """Return the current test tensor."""
    return self.test

  def get_support_output(self):
    """Return the current support tensor."""
    return self.support

  def return_outputs(self):
    """Return ``[test, support]``, the current outputs of both stacks."""
    return [self.test, self.support]

  def return_inputs(self):
    """Return the test topology inputs followed by the support topology inputs."""
    return (self.test_graph_topology.get_inputs()
            + self.support_graph_topology.get_inputs())