コード例 #1
0
ファイル: train_lstm.py プロジェクト: sailfish009/ProGraML
def MakeBatchBuilder(
  dataset_root: pathlib.Path,
  log_dir: pathlib.Path,
  epoch_type: epoch_pb2.EpochType,
  model: Lstm,
  min_graph_count=None,
  max_graph_count=None,
  seed=None,
):
  """Construct a DataflowLstmBatchBuilder for a single epoch type.

  The underlying DataflowGraphLoader reads from `dataset_root`. It takes its
  analysis name and data-flow step limit from FLAGS (`FLAGS.analysis`,
  `FLAGS.max_data_flow_steps`) and is always built with
  `require_inst2vec=True`. Loader output is appended to
  `<log_dir>/graph_loader/<lowercased epoch type name>.txt`.

  Args:
    dataset_root: Dataset directory passed positionally to DataflowGraphLoader.
    log_dir: Logging directory for the run. The `graph_loader` subdirectory
      is created under it if it does not already exist.
    epoch_type: Epoch type to load. Its enum name, lowercased, names the log
      file.
    model: Supplies `vocabulary`, `padded_sequence_length`, `batch_size` and
      `node_y_dimensionality` to the batch builder.
    min_graph_count: Forwarded unchanged to DataflowGraphLoader.
    max_graph_count: Forwarded unchanged to DataflowGraphLoader.
    seed: Forwarded unchanged to DataflowGraphLoader.

  Returns:
    A DataflowLstmBatchBuilder wrapping a newly constructed DataflowGraphLoader.
  """
  logfile = (
    log_dir
    / "graph_loader"
    / f"{epoch_pb2.EpochType.Name(epoch_type).lower()}.txt"
  )
  # open() below does not create missing directories; without this a fresh
  # log_dir fails with FileNotFoundError. exist_ok keeps resumed runs working.
  logfile.parent.mkdir(parents=True, exist_ok=True)
  return DataflowLstmBatchBuilder(
    graph_loader=DataflowGraphLoader(
      dataset_root,
      epoch_type=epoch_type,
      analysis=FLAGS.analysis,
      min_graph_count=min_graph_count,
      max_graph_count=max_graph_count,
      data_flow_step_max=FLAGS.max_data_flow_steps,
      require_inst2vec=True,
      # Append to logfile since we may be resuming a previous job. The handle
      # is given to the loader and is not closed by this function.
      logfile=open(str(logfile), "a"),
      seed=seed,
    ),
    vocabulary=model.vocabulary,
    padded_sequence_length=model.padded_sequence_length,
    batch_size=model.batch_size,
    node_y_dimensionality=model.node_y_dimensionality,
  )
コード例 #2
0
def GraphLoader(path):
    """Return a DataflowGraphLoader over the TRAIN epoch for reachability.

    The graph count bounds come from ``FLAGS.graph_count``. The lower bound
    falls back to 1 whenever that flag is falsy (for example None or 0). The
    upper bound is the flag value as-is; presumably None means "no limit",
    but confirm this against DataflowGraphLoader.

    Args:
        path: Dataset directory. It must support ``/`` joining (e.g. a
            pathlib.Path), because the log file is created inside it.

    Returns:
        A DataflowGraphLoader whose log is written to
        ``<path>/graph_reader_log.txt``. That file is opened in "w" mode, so
        any previous contents are truncated on every call.
    """
    train_split = epoch_pb2.TRAIN
    requested_count = FLAGS.graph_count
    lower_bound = requested_count or 1
    log_handle = open(path / "graph_reader_log.txt", "w")
    return DataflowGraphLoader(
        path=path,
        epoch_type=train_split,
        analysis="reachability",
        min_graph_count=lower_bound,
        max_graph_count=requested_count,
        logfile=log_handle,
    )
コード例 #3
0
ファイル: ggnn.py プロジェクト: deeplearning2012/ProGraML
def MakeBatchBuilder(
  dataset_root: pathlib.Path,
  log_dir: pathlib.Path,
  analysis: str,
  epoch_type: epoch_pb2.EpochType,
  model: Ggnn,
  batch_size: int,
  use_cdfg: bool,
  limit_max_data_flow_steps: bool,
  min_graph_count=None,
  max_graph_count=None,
  seed=None,
):
  """Construct a DataflowGgnnBatchBuilder for a single epoch type.

  Loader output is appended to
  `<log_dir>/graph_loader/<lowercased epoch type name>.txt`.

  Args:
    dataset_root: Dataset directory passed positionally to DataflowGraphLoader.
    log_dir: Logging directory for the run. The `graph_loader` subdirectory
      is created under it if it does not already exist.
    analysis: Analysis name forwarded to DataflowGraphLoader.
    epoch_type: Epoch type to load. Its enum name, lowercased, names the log
      file.
    model: Supplies `vocabulary` to the batch builder. When
      `limit_max_data_flow_steps` is set, it also supplies
      `message_passing_step_count`.
    batch_size: Passed to the batch builder as `max_node_size`. This suggests
      batches are bounded by node count rather than graph count; confirm in
      DataflowGgnnBatchBuilder.
    use_cdfg: Forwarded to both the loader and the batch builder.
    limit_max_data_flow_steps: If true, the loader's `data_flow_step_max` is
      the model's message passing step count. Otherwise it is None.
    min_graph_count: Forwarded unchanged to DataflowGraphLoader.
    max_graph_count: Forwarded unchanged to DataflowGraphLoader.
    seed: Forwarded unchanged to DataflowGraphLoader.

  Returns:
    A DataflowGgnnBatchBuilder wrapping a newly constructed DataflowGraphLoader.
  """
  data_flow_step_max = (
    model.message_passing_step_count if limit_max_data_flow_steps else None
  )
  logfile = (
    log_dir
    / "graph_loader"
    / f"{epoch_pb2.EpochType.Name(epoch_type).lower()}.txt"
  )
  # open() below does not create missing directories; without this a fresh
  # log_dir fails with FileNotFoundError. exist_ok keeps resumed runs working.
  logfile.parent.mkdir(parents=True, exist_ok=True)
  return DataflowGgnnBatchBuilder(
    graph_loader=DataflowGraphLoader(
      dataset_root,
      epoch_type=epoch_type,
      analysis=analysis,
      min_graph_count=min_graph_count,
      max_graph_count=max_graph_count,
      data_flow_step_max=data_flow_step_max,
      # Append to logfile since we may be resuming a previous job. The handle
      # is given to the loader and is not closed by this function.
      logfile=open(str(logfile), "a"),
      seed=seed,
      use_cdfg=use_cdfg,
    ),
    vocabulary=model.vocabulary,
    max_node_size=batch_size,
    use_cdfg=use_cdfg,
  )