def MakeBatchBuilder(
    dataset_root: pathlib.Path,
    log_dir: pathlib.Path,
    epoch_type: epoch_pb2.EpochType,
    model: Lstm,
    min_graph_count=None,
    max_graph_count=None,
    seed=None,
):
    """Construct a DataflowLstmBatchBuilder for one epoch type.

    Args:
      dataset_root: Root directory of the dataflow graphs dataset.
      log_dir: Directory under which graph-loader log files are written
        (one file per epoch type, in a ``graph_loader`` subdirectory).
      epoch_type: Which epoch (e.g. train/val/test) to load graphs for.
      model: The LSTM model supplying vocabulary, padded sequence length,
        batch size, and node label dimensionality.
      min_graph_count: Optional minimum number of graphs to load.
      max_graph_count: Optional maximum number of graphs to load.
      seed: Optional seed for the graph loader.

    Returns:
      A DataflowLstmBatchBuilder wrapping a configured DataflowGraphLoader.
    """
    logfile = (
        log_dir
        / "graph_loader"
        / f"{epoch_pb2.EpochType.Name(epoch_type).lower()}.txt"
    )
    # Fix: on a fresh run the "graph_loader" subdirectory may not exist
    # yet, in which case open() below would raise FileNotFoundError.
    logfile.parent.mkdir(parents=True, exist_ok=True)
    return DataflowLstmBatchBuilder(
        graph_loader=DataflowGraphLoader(
            dataset_root,
            epoch_type=epoch_type,
            analysis=FLAGS.analysis,
            min_graph_count=min_graph_count,
            max_graph_count=max_graph_count,
            data_flow_step_max=FLAGS.max_data_flow_steps,
            require_inst2vec=True,
            # Append to logfile since we may be resuming a previous job.
            logfile=logfile.open("a"),
            seed=seed,
        ),
        vocabulary=model.vocabulary,
        padded_sequence_length=model.padded_sequence_length,
        batch_size=model.batch_size,
        node_y_dimensionality=model.node_y_dimensionality,
    )
def GraphLoader(path):
    """Build a reachability graph loader for the training epoch.

    Reads between ``FLAGS.graph_count or 1`` and ``FLAGS.graph_count``
    graphs from *path*, logging to ``graph_reader_log.txt`` inside
    *path* (any previous log is overwritten).
    """
    log_handle = open(path / "graph_reader_log.txt", "w")
    return DataflowGraphLoader(
        path=path,
        epoch_type=epoch_pb2.TRAIN,
        analysis="reachability",
        min_graph_count=FLAGS.graph_count or 1,
        max_graph_count=FLAGS.graph_count,
        logfile=log_handle,
    )
def MakeBatchBuilder(
    dataset_root: pathlib.Path,
    log_dir: pathlib.Path,
    analysis: str,
    epoch_type: epoch_pb2.EpochType,
    model: Ggnn,
    batch_size: int,
    use_cdfg: bool,
    limit_max_data_flow_steps: bool,
    min_graph_count=None,
    max_graph_count=None,
    seed=None,
):
    """Construct a DataflowGgnnBatchBuilder for one epoch type.

    Args:
      dataset_root: Root directory of the dataflow graphs dataset.
      log_dir: Directory under which graph-loader log files are written
        (one file per epoch type, in a ``graph_loader`` subdirectory).
      analysis: Name of the dataflow analysis to load graphs for.
      epoch_type: Which epoch (e.g. train/val/test) to load graphs for.
      model: The GGNN model supplying the vocabulary and message-passing
        step count.
      batch_size: Maximum node count per batch.
      use_cdfg: Whether to use CDFG graph representations.
      limit_max_data_flow_steps: If True, skip graphs requiring more
        data-flow steps than the model's message-passing step count.
      min_graph_count: Optional minimum number of graphs to load.
      max_graph_count: Optional maximum number of graphs to load.
      seed: Optional seed for the graph loader.

    Returns:
      A DataflowGgnnBatchBuilder wrapping a configured DataflowGraphLoader.
    """
    data_flow_step_max = (
        model.message_passing_step_count if limit_max_data_flow_steps else None
    )
    logfile = (
        log_dir
        / "graph_loader"
        / f"{epoch_pb2.EpochType.Name(epoch_type).lower()}.txt"
    )
    # Fix: on a fresh run the "graph_loader" subdirectory may not exist
    # yet, in which case open() below would raise FileNotFoundError.
    logfile.parent.mkdir(parents=True, exist_ok=True)
    return DataflowGgnnBatchBuilder(
        graph_loader=DataflowGraphLoader(
            dataset_root,
            epoch_type=epoch_type,
            analysis=analysis,
            min_graph_count=min_graph_count,
            max_graph_count=max_graph_count,
            data_flow_step_max=data_flow_step_max,
            # Append to logfile since we may be resuming a previous job.
            logfile=logfile.open("a"),
            seed=seed,
            use_cdfg=use_cdfg,
        ),
        vocabulary=model.vocabulary,
        # NOTE(review): batch_size is passed as max_node_size, i.e. it is
        # interpreted as a node-count budget, not a graph count — confirm
        # this is the intended semantics at the call sites.
        max_node_size=batch_size,
        use_cdfg=use_cdfg,
    )