class Ggnn(classifier_base.ClassifierBase):
  """A gated graph neural network."""

  def __init__(self, *args, **kwargs):
    """Constructor."""
    super(Ggnn, self).__init__(*args, **kwargs)

    # Set some global config values.
    self.dev = (
      torch.device("cuda")
      if torch.cuda.is_available() and FLAGS.cuda
      else torch.device("cpu")
    )
    app.Log(
      1, "Using device %s with dtype %s", self.dev, torch.get_default_dtype()
    )

    # Instantiate the model.
    config = GGNNConfig(
      num_classes=self.y_dimensionality,
      has_graph_labels=self.graph_db.graph_y_dimensionality > 0,
    )
    inst2vec_embeddings = node_encoder.GraphEncoder().embeddings_tables[0]
    inst2vec_embeddings = torch.from_numpy(
      np.array(inst2vec_embeddings, dtype=np.float32)
    )

    self.model = GGNNModel(
      config,
      pretrained_embeddings=inst2vec_embeddings,
      test_only=FLAGS.test_only,
    )
    if DEBUG:
      for submodule in self.model.modules():
        submodule.register_forward_hook(nan_hook)
    self.model.to(self.dev)

  def MakeBatch(
    self,
    epoch_type: epoch.Type,
    graphs: Iterable[graph_tuple_database.GraphTuple],
    ctx: progress.ProgressContext = progress.NullContext,
  ) -> batches.Data:
    """Create a mini-batch of data from an iterator of graphs.

    Returns:
      A single batch of data for feeding into RunBatch(). A batch consists of
      a list of graph IDs and a model-defined blob of data. If the list of
      graph IDs is empty, the batch is discarded and not fed into RunBatch().
    """
    # TODO(github.com/ChrisCummins/ProGraML/issues/24): The new graph batcher
    # implementation is not well suited for reading the graph IDs, hence this
    # somewhat clumsy iterator wrapper. A neater approach would be to create
    # a graph batcher which returns a list of graphs in the batch.
    class GraphIterator(object):
      """A wrapper around a graph iterator which records graph IDs."""

      def __init__(self, graphs: Iterable[graph_tuple_database.GraphTuple]):
        self.input_graphs = graphs
        self.graphs_read: List[graph_tuple_database.GraphTuple] = []

      def __iter__(self):
        return self

      def __next__(self):
        graph: graph_tuple_database.GraphTuple = next(self.input_graphs)
        self.graphs_read.append(graph)
        return graph.tuple

    graph_iterator = GraphIterator(graphs)

    # Create a disjoint graph out of one or more input graphs.
    batcher = graph_batcher.GraphBatcher.CreateFromFlags(
      graph_iterator, ctx=ctx
    )
    try:
      disjoint_graph = next(batcher)
    except StopIteration:
      # We have run out of graphs, return an empty batch.
      return batches.Data(graph_ids=[], data=None)

    # Workaround for the fact that the graph batcher may read one more graph
    # than actually gets included in the batch.
    if batcher.last_graph:
      graphs = graph_iterator.graphs_read[:-1]
    else:
      graphs = graph_iterator.graphs_read

    return batches.Data(
      graph_ids=[graph.id for graph in graphs],
      data=GgnnBatchData(disjoint_graph=disjoint_graph, graphs=graphs),
    )
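  # A sketch of how MakeBatch() is driven (hypothetical caller code, shown
  # for illustration only; the real train/test loop lives in
  # classifier_base.ClassifierBase):
  #
  #   reader = model.GraphReader(epoch.Type.TRAIN, model.graph_db)
  #   batch = model.MakeBatch(epoch.Type.TRAIN, iter(reader))
  #   if batch.graph_ids:
  #     results = model.RunBatch(epoch.Type.TRAIN, batch)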
  def GraphReader(
    self,
    epoch_type: epoch.Type,
    graph_db: graph_tuple_database.Database,
    filters: Optional[List[Callable[[], bool]]] = None,
    limit: Optional[int] = None,
    ctx: progress.ProgressContext = progress.NullContext,
  ) -> graph_database_reader.BufferedGraphReader:
    """Construct a buffered graph reader.

    Args:
      epoch_type: The type of epoch to return a graph reader for.
      graph_db: The graph database to read graphs from.
      filters: A list of filters to impose on the graph database reader.
      limit: The maximum number of rows to read.
      ctx: A logging context.

    Returns:
      A buffered graph reader instance.
    """
    filters = filters or []

    # Only read graphs with data_flow_steps <= message_passing_step_count if
    # --limit_max_data_flow_steps_during_training is set and we are not in a
    # test epoch.
    if (
      FLAGS.limit_max_data_flow_steps_during_training
      and self.graph_db.has_data_flow
      and (epoch_type == epoch.Type.TRAIN or epoch_type == epoch.Type.VAL)
    ):
      filters.append(
        lambda: graph_tuple_database.GraphTuple.data_flow_steps
        <= self.message_passing_step_count
      )

    return super(Ggnn, self).GraphReader(
      epoch_type=epoch_type,
      graph_db=graph_db,
      filters=filters,
      limit=limit,
      ctx=ctx,
    )

  @property
  def message_passing_step_count(self) -> int:
    return self.layer_timesteps.sum()

  @property
  def layer_timesteps(self) -> np.array:
    return np.array([int(x) for x in FLAGS.layer_timesteps])

  # TODO(github.com/ChrisCummins/ProGraML/issues/27): Split this into a
  # separate unroll_strategy.py module.
  def GetUnrollFactor(
    self,
    epoch_type: epoch.Type,
    batch: batches.Data,
    unroll_strategy: str,
    unroll_factor: float,
  ) -> int:
    """Determine the unroll factor from the --unroll_strategy and
    --unroll_factor flags, and the batch log.
    """
    # Determine the unrolling strategy.
    if unroll_strategy == "none" or epoch_type == epoch.Type.TRAIN:
      # Perform no unrolling. The inputs are processed for a single run of
      # message_passing_step_count. This is required during training to
      # propagate gradients.
      return 1
    elif unroll_strategy == "constant":
      # Unroll by a constant number of steps. The total number of steps is
      # (unroll_factor * message_passing_step_count).
      return int(unroll_factor)
    elif unroll_strategy == "data_flow_max_steps":
      max_data_flow_steps = max(
        graph.data_flow_steps for graph in batch.data.graphs
      )
      unroll_factor = math.ceil(
        max_data_flow_steps / self.message_passing_step_count
      )
      app.Log(
        2,
        "Determined unroll factor %d from max data flow steps %d",
        unroll_factor,
        max_data_flow_steps,
      )
      return unroll_factor
    elif unroll_strategy == "edge_count":
      max_edge_count = max(graph.edge_count for graph in batch.data.graphs)
      unroll_factor = math.ceil(
        (max_edge_count * unroll_factor) / self.message_passing_step_count
      )
      app.Log(
        2,
        "Determined unroll factor %d from max edge count %d",
        unroll_factor,
        max_edge_count,
      )
      return unroll_factor
    elif unroll_strategy == "label_convergence":
      return 0
    else:
      raise app.UsageError(f"Unknown unroll strategy '{unroll_strategy}'")
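  # Worked example of the unroll arithmetic above (the flag values are
  # illustrative): with --layer_timesteps=2,2,2, message_passing_step_count
  # is 6. A batch whose largest graph has data_flow_steps = 15 then yields
  # ceil(15 / 6) = 3 under "data_flow_max_steps", i.e. three full passes
  # through the layer stack.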
  def RunBatch(
    self,
    epoch_type: epoch.Type,
    batch: batches.Data,
    ctx: progress.ProgressContext = progress.NullContext,
  ) -> batches.Results:
    disjoint_graph: graph_tuple.GraphTuple = batch.data.disjoint_graph

    # Batch to model-inputs. NB: torch.from_numpy() shares memory with numpy!
    # TODO(github.com/ChrisCummins/ProGraML/issues/27): Maybe we can save
    # memory copies in the training loop if we can turn the data into the
    # required types (np.int64 and np.float32) once they come off the network
    # from the database, where smaller i/o size (int32) is more important.
    with ctx.Profile(5, "Sent data to GPU"):
      vocab_ids = torch.from_numpy(disjoint_graph.node_x[:, 0]).to(
        self.dev, torch.long
      )
      selector_ids = torch.from_numpy(disjoint_graph.node_x[:, 1]).to(
        self.dev, torch.long
      )
      # We need the labels as a result on the CPU anyway, so keep a copy here
      # and save device i/o.
      cpu_labels = (
        disjoint_graph.node_y
        if disjoint_graph.has_node_y
        else disjoint_graph.graph_y
      )
      labels = torch.from_numpy(cpu_labels).to(self.dev)
      edge_lists = [
        torch.from_numpy(x).to(self.dev, torch.long)
        for x in disjoint_graph.adjacencies
      ]
      edge_positions = [
        torch.from_numpy(x).to(self.dev, torch.long)
        for x in disjoint_graph.edge_positions
      ]
      model_inputs = (
        vocab_ids,
        selector_ids,
        labels,
        edge_lists,
        edge_positions,
      )

      # Maybe fetch more inputs.
      if disjoint_graph.has_graph_y:
        assert (
          disjoint_graph.disjoint_graph_count > 1
        ), f"graph_count is {disjoint_graph.disjoint_graph_count}"
        num_graphs = torch.tensor(disjoint_graph.disjoint_graph_count).to(
          self.dev, torch.long
        )
        graph_nodes_list = torch.from_numpy(
          disjoint_graph.disjoint_nodes_list
        ).to(self.dev, torch.long)
        # TODO(https://github.com/ChrisCummins/ProGraML/issues/37): Remove
        # this line on overflow fix!
        hotfixed_graph_x = np.abs(disjoint_graph.graph_x)
        aux_in = torch.from_numpy(hotfixed_graph_x).to(
          self.dev, torch.get_default_dtype()
        )
        model_inputs = model_inputs + (num_graphs, graph_nodes_list, aux_in)

    # Enter the correct mode of the model. NB: the eval() branch must not
    # trigger during training, so it is guarded on the epoch type.
    if epoch_type == epoch.Type.TRAIN:
      if not self.model.training:
        self.model.train()
    elif self.model.training:
      self.model.eval()
      self.model.opt.zero_grad()

    outputs = self.model(*model_inputs)
    # Only the second logits-like output is used below; the first element of
    # the model outputs is discarded.
    _, accuracy, logits, correct, targets, graph_features = outputs

    loss = self.model.loss((logits, graph_features), targets)

    if epoch_type == epoch.Type.TRAIN:
      loss.backward()
      # TODO(github.com/ChrisCummins/ProGraML/issues/27): Clip gradients
      # (done). NB, pytorch clips by the norm of the gradient of the whole
      # model, while tf clips by the norm of the grad of each tensor
      # separately. Therefore we change the default from 1.0 to 6.0.
      # TODO(github.com/ChrisCummins/ProGraML/issues/27): Anyway: gradients
      # shouldn't really be clipped if not necessary?
      if self.model.config.clip_grad_norm > 0.0:
        nn.utils.clip_grad_norm_(
          self.model.parameters(), self.model.config.clip_grad_norm
        )
      self.model.opt.step()
      self.model.opt.zero_grad()

    # TODO(github.com/ChrisCummins/ProGraML/issues/27): A learning rate
    # schedule will change this value.
    learning_rate = self.model.config.lr
    # TODO(github.com/ChrisCummins/ProGraML/issues/27): Set these.
    model_converged = False
    iteration_count = 1

    loss_value = loss.item()
    assert not np.isnan(loss_value), loss

    return batches.Results.Create(
      targets=cpu_labels,
      predictions=logits.detach().cpu().numpy(),
      model_converged=model_converged,
      learning_rate=learning_rate,
      iteration_count=iteration_count,
      loss=loss_value,
    )

  def GetModelData(self) -> typing.Any:
    return {
      "model_state_dict": self.model.state_dict(),
      "optimizer_state_dict": self.model.opt.state_dict(),
    }

  def LoadModelData(self, data_to_load: typing.Any) -> None:
    self.model.load_state_dict(data_to_load["model_state_dict"])
    # Only restore the optimizer if needed; opt should be None otherwise.
    if not FLAGS.test_only:
      self.model.opt.load_state_dict(data_to_load["optimizer_state_dict"])
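# For reference, the nan_hook registered on every submodule when DEBUG is set
# could look like the following minimal sketch. This is an illustrative
# assumption: the actual nan_hook is defined elsewhere in this module, and
# only the (module, inputs, output) signature is fixed by torch's
# register_forward_hook() API.
#
#   def nan_hook(module, inputs, output):
#     outputs = output if isinstance(output, tuple) else (output,)
#     for i, out in enumerate(outputs):
#       if isinstance(out, torch.Tensor) and torch.isnan(out).any():
#         raise RuntimeError(
#           f"NaN in output {i} of {module.__class__.__name__}"
#         )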
# NOTE: A later revision of the Ggnn classifier follows; it supersedes the
# class above. Device placement is owned by GGNNModel (self.model.dev),
# MakeBatch() signals exhaustion and discards explicitly (EndOfBatches /
# EmptyBatch), and test-time unrolling is expressed as an explicit step count
# rather than an unroll factor.
class Ggnn(classifier_base.ClassifierBase):
  """A gated graph neural network."""

  def __init__(self, *args, **kwargs):
    """Constructor."""
    super(Ggnn, self).__init__(*args, **kwargs)

    # Instantiate the model.
    config = GGNNConfig(
      num_classes=self.y_dimensionality,
      has_graph_labels=self.graph_db.graph_y_dimensionality > 0,
      has_aux_input=self.graph_db.graph_x_dimensionality > 0,
    )
    inst2vec_embeddings = node_encoder.GraphNodeEncoder().embeddings_tables[0]
    inst2vec_embeddings = torch.from_numpy(
      np.array(inst2vec_embeddings, dtype=np.float32)
    )

    self.model = GGNNModel(
      config,
      pretrained_embeddings=inst2vec_embeddings,
      test_only=FLAGS.test_only,
    )
    app.Log(
      1,
      "Using device %s with dtype %s",
      self.model.dev,
      torch.get_default_dtype(),
    )
    if DEBUG:
      for submodule in self.model.modules():
        submodule.register_forward_hook(nan_hook)

  def MakeBatch(
    self,
    epoch_type: epoch.Type,
    graphs: Iterable[graph_tuple_database.GraphTuple],
    ctx: progress.ProgressContext = progress.NullContext,
  ) -> batches.Data:
    """Create a mini-batch of data from an iterator of graphs.

    Returns:
      A single batch of data for feeding into RunBatch(). A batch consists of
      a list of graph IDs and a model-defined blob of data. If the list of
      graph IDs is empty, the batch is discarded and not fed into RunBatch().
    """
    # TODO(github.com/ChrisCummins/ProGraML/issues/24): The new graph batcher
    # implementation is not well suited for reading the graph IDs, hence this
    # somewhat clumsy iterator wrapper. A neater approach would be to create
    # a graph batcher which returns a list of graphs in the batch.
    class GraphIterator(object):
      """A wrapper around a graph iterator which records graph IDs."""

      def __init__(self, graphs: Iterable[graph_tuple_database.GraphTuple]):
        self.input_graphs = graphs
        self.graphs_read: List[graph_tuple_database.GraphTuple] = []

      def __iter__(self):
        return self

      def __next__(self):
        graph: graph_tuple_database.GraphTuple = next(self.input_graphs)
        self.graphs_read.append(graph)
        return graph.tuple

    graph_iterator = GraphIterator(graphs)

    # Create a disjoint graph out of one or more input graphs.
    batcher = graph_batcher.GraphBatcher.CreateFromFlags(
      graph_iterator, ctx=ctx
    )
    try:
      disjoint_graph = next(batcher)
    except StopIteration:
      # We have run out of graphs.
      return batches.EndOfBatches()

    # Workaround for the fact that the graph batcher may read one more graph
    # than actually gets included in the batch.
    if batcher.last_graph:
      graphs = graph_iterator.graphs_read[:-1]
    else:
      graphs = graph_iterator.graphs_read

    # Discard single-graph batches during training when there are graph
    # features. This is because we use batch normalization on incoming
    # features, and batch normalization requires > 1 items to normalize.
    if (
      len(graphs) <= 1
      and epoch_type == epoch.Type.TRAIN
      and disjoint_graph.graph_x_dimensionality
    ):
      return batches.EmptyBatch()

    return batches.Data(
      graph_ids=[graph.id for graph in graphs],
      data=GgnnBatchData(disjoint_graph=disjoint_graph, graphs=graphs),
    )
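  # The three possible MakeBatch() outcomes, as a consumer might see them
  # (hypothetical driver code, sketched for illustration; the real epoch
  # loop lives in the classifier base class):
  #
  #   batch = model.MakeBatch(epoch.Type.TRAIN, graph_iter)
  #   if isinstance(batch, batches.EndOfBatches):
  #     ...  # input iterator exhausted; end the epoch.
  #   elif isinstance(batch, batches.EmptyBatch):
  #     ...  # batch discarded (e.g. single graph + batch norm); skip it.
  #   else:
  #     results = model.RunBatch(epoch.Type.TRAIN, batch)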
  def GraphReader(
    self,
    epoch_type: epoch.Type,
    graph_db: graph_tuple_database.Database,
    filters: Optional[List[Callable[[], bool]]] = None,
    limit: Optional[int] = None,
    ctx: progress.ProgressContext = progress.NullContext,
  ) -> graph_database_reader.BufferedGraphReader:
    """Construct a buffered graph reader.

    Args:
      epoch_type: The type of epoch to return a graph reader for.
      graph_db: The graph database to read graphs from.
      filters: A list of filters to impose on the graph database reader.
      limit: The maximum number of rows to read.
      ctx: A logging context.

    Returns:
      A buffered graph reader instance.
    """
    filters = filters or []

    # Only read graphs with data_flow_steps <= message_passing_step_count if
    # --limit_max_data_flow_steps is set.
    if FLAGS.limit_max_data_flow_steps and self.graph_db.has_data_flow:
      filters.append(
        lambda: graph_tuple_database.GraphTuple.data_flow_steps
        <= self.message_passing_step_count
      )

    # If we are batching by maximum node count and skipping graphs that are
    # larger than this, we can apply that filter to the SQL query now, rather
    # than reading the graphs and ignoring them later. This ensures that when
    # --max_{train,val}_per_epoch is set, the number of graphs that get used
    # matches the limit.
    if (
      FLAGS.graph_batch_node_count
      and FLAGS.max_node_count_limit_handler == "skip"
    ):
      filters.append(
        lambda: (
          graph_tuple_database.GraphTuple.node_count
          <= FLAGS.graph_batch_node_count
        )
      )

    return super(Ggnn, self).GraphReader(
      epoch_type=epoch_type,
      graph_db=graph_db,
      filters=filters,
      limit=limit,
      ctx=ctx,
    )

  @property
  def message_passing_step_count(self) -> int:
    return self.layer_timesteps.sum()

  @property
  def layer_timesteps(self) -> np.array:
    return np.array([int(x) for x in FLAGS.layer_timesteps])

  def get_unroll_steps(
    self,
    epoch_type: epoch.Type,
    batch: batches.Data,
    unroll_strategy: str,
  ) -> int:
    """Determine the number of unroll steps from the --unroll_strategy flag
    and the batch log.
    """
    # Determine the unrolling strategy.
    if unroll_strategy == "none":
      # Perform no unrolling. The inputs are processed according to
      # layer_timesteps.
      return 0
    elif unroll_strategy == "constant":
      # Unroll by a constant number of steps according to
      # test_layer_timesteps.
      return 0
    elif unroll_strategy == "data_flow_max_steps":
      max_data_flow_steps = max(
        graph.data_flow_steps for graph in batch.data.graphs
      )
      app.Log(
        3, "Determined max data flow steps to be %d", max_data_flow_steps
      )
      return max_data_flow_steps
    elif unroll_strategy == "edge_count":
      max_edge_count = max(graph.edge_count for graph in batch.data.graphs)
      app.Log(3, "Determined max edge count to be %d", max_edge_count)
      return max_edge_count
    elif unroll_strategy == "label_convergence":
      return 0
    else:
      raise app.UsageError(f"Unknown unroll strategy '{unroll_strategy}'")
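  # Example of how the strategies above map to test-time step counts
  # (illustrative flag values): with --layer_timesteps=2,2,2 the default is
  # 6 message passing steps. "data_flow_max_steps" on a batch whose largest
  # graph needs 15 data flow steps returns 15, overriding layer_timesteps at
  # test time; "none", "constant", and "label_convergence" return 0,
  # deferring to the model's own (test_)layer_timesteps or convergence check.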
  def PrepareModelInputs(
    self,
    epoch_type: epoch.Type,
    batch: batches.Data,
  ) -> Tuple[np.array, Dict[str, torch.Tensor]]:
    """RunBatch() helper method to prepare inputs to the model.

    Args:
      epoch_type: The type of epoch the model is performing.
      batch: A batch of data to prepare inputs from.

    Returns:
      A tuple of <expected outcomes, model inputs>.
    """
    disjoint_graph: graph_tuple.GraphTuple = batch.data.disjoint_graph

    # Batch to model-inputs. NB: torch.from_numpy() shares memory with numpy.
    # TODO(github.com/ChrisCummins/ProGraML/issues/27): Maybe we can save
    # memory copies in the training loop if we can turn the data into the
    # required types (np.int64 and np.float32) once they come off the network
    # from the database, where smaller i/o size (int32) is more important.
    vocab_ids = torch.from_numpy(disjoint_graph.node_x[:, 0]).to(
      self.model.dev, torch.long
    )
    selector_ids = torch.from_numpy(disjoint_graph.node_x[:, 1]).to(
      self.model.dev, torch.long
    )
    # We need the labels as a result on the CPU anyway, so keep a copy here
    # and save device i/o.
    cpu_labels = (
      disjoint_graph.node_y
      if disjoint_graph.has_node_y
      else disjoint_graph.graph_y
    )
    labels = torch.from_numpy(cpu_labels).to(self.model.dev)
    edge_lists = [
      torch.from_numpy(x).to(self.model.dev, torch.long)
      for x in disjoint_graph.adjacencies
    ]
    edge_positions = [
      torch.from_numpy(x).to(self.model.dev, torch.long)
      for x in disjoint_graph.edge_positions
    ]

    model_inputs = {
      "vocab_ids": vocab_ids,
      "selector_ids": selector_ids,
      "labels": labels,
      "edge_lists": edge_lists,
      "pos_lists": edge_positions,
    }

    # Maybe fetch more inputs.
    if disjoint_graph.has_graph_y:
      assert (
        epoch_type != epoch.Type.TRAIN
        or disjoint_graph.disjoint_graph_count > 1
      ), f"graph_count is {disjoint_graph.disjoint_graph_count}"
      num_graphs = torch.tensor(disjoint_graph.disjoint_graph_count).to(
        self.model.dev, torch.long
      )
      graph_nodes_list = torch.from_numpy(
        disjoint_graph.disjoint_nodes_list
      ).to(self.model.dev, torch.long)
      aux_in = torch.from_numpy(disjoint_graph.graph_x).to(
        self.model.dev, torch.get_default_dtype()
      )
      model_inputs.update(
        {
          "num_graphs": num_graphs,
          "graph_nodes_list": graph_nodes_list,
          "aux_in": aux_in,
        }
      )

    return cpu_labels, model_inputs
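  # Shapes involved in the inputs above, for a disjoint batch of N nodes
  # (inferred from the indexing; the exact adjacency layout is defined by
  # GraphTuple, so treat this as an orientation aid rather than a spec):
  #
  #   vocab_ids, selector_ids: [N]  (columns 0 and 1 of node_x)
  #   labels: node_y (one row per node) or graph_y (one row per graph)
  #   edge_lists / pos_lists: one tensor per edge type
  #   graph_nodes_list: maps each node to its graph in the disjoint batch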
  def RunBatch(
    self,
    epoch_type: epoch.Type,
    batch: batches.Data,
    ctx: progress.ProgressContext = progress.NullContext,
  ) -> batches.Results:
    """Process a mini-batch of data through the GGNN.

    Args:
      epoch_type: The type of epoch being run.
      batch: The batch data returned by MakeBatch().
      ctx: A logging context.

    Returns:
      A batch results instance.
    """
    cpu_labels, model_inputs = self.PrepareModelInputs(epoch_type, batch)

    # Maybe calculate manual timesteps.
    if epoch_type != epoch.Type.TRAIN and FLAGS.unroll_strategy in {
      "constant",
      "edge_count",
      "data_flow_max_steps",
      "label_convergence",
    }:
      time_steps_cpu = np.array(
        self.get_unroll_steps(epoch_type, batch, FLAGS.unroll_strategy),
        dtype=np.int64,
      )
      time_steps_gpu = torch.from_numpy(time_steps_cpu).to(self.model.dev)
    else:
      time_steps_cpu = 0
      time_steps_gpu = None

    # Run the model forward pass, entering the correct mode first.
    if epoch_type == epoch.Type.TRAIN:
      if not self.model.training:
        self.model.train()
      outputs = self.model(**model_inputs, test_time_steps=time_steps_gpu)
    else:  # not TRAIN
      if self.model.training:
        self.model.eval()
        self.model.opt.zero_grad()
      with torch.no_grad():  # don't trace the computation graph!
        outputs = self.model(**model_inputs, test_time_steps=time_steps_gpu)

    # Only the second logits-like output is used below; the first element of
    # the model outputs is discarded.
    (
      _,
      accuracy,
      logits,
      correct,
      targets,
      graph_features,
      *unroll_stats,
    ) = outputs

    loss = self.model.loss((logits, graph_features), targets)

    if epoch_type == epoch.Type.TRAIN:
      loss.backward()
      # TODO(github.com/ChrisCummins/ProGraML/issues/27): Clip gradients
      # (done). NB, pytorch clips by the norm of the gradient of the whole
      # model, while tf clips by the norm of the grad of each tensor
      # separately. Therefore we change the default from 1.0 to 6.0.
      # TODO(github.com/ChrisCummins/ProGraML/issues/27): Anyway: gradients
      # shouldn't really be clipped if not necessary?
      if self.model.config.clip_grad_norm > 0.0:
        nn.utils.clip_grad_norm_(
          self.model.parameters(), self.model.config.clip_grad_norm
        )
      self.model.opt.step()
      self.model.opt.zero_grad()

    # TODO(github.com/ChrisCummins/ProGraML/issues/27): A learning rate
    # schedule will change this value.
    learning_rate = self.model.config.lr
    model_converged = unroll_stats[1] if unroll_stats else False
    iteration_count = unroll_stats[0] if unroll_stats else time_steps_cpu

    loss_value = loss.item()
    assert not np.isnan(loss_value), loss

    return batches.Results.Create(
      targets=cpu_labels,
      predictions=logits.detach().cpu().numpy(),
      model_converged=model_converged,
      learning_rate=learning_rate,
      iteration_count=iteration_count,
      loss=loss_value,
    )

  def GetModelData(self) -> typing.Any:
    return {
      "model_state_dict": self.model.state_dict(),
      "optimizer_state_dict": self.model.opt.state_dict(),
    }

  def LoadModelData(self, data_to_load: typing.Any) -> None:
    self.model.load_state_dict(data_to_load["model_state_dict"])
    # Only restore the optimizer if needed; opt should be None otherwise.
    if not FLAGS.test_only:
      self.model.opt.load_state_dict(data_to_load["optimizer_state_dict"])
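# A sketch of the checkpoint round trip that GetModelData()/LoadModelData()
# support (hypothetical driver code and an illustrative path; the real
# checkpointing logic lives in the classifier base class / run harness):
#
#   data = model.GetModelData()
#   torch.save(data, "/tmp/ggnn_checkpoint.pt")
#   restored = torch.load("/tmp/ggnn_checkpoint.pt")
#   model.LoadModelData(restored)  # skips the optimizer when --test_only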