Example #1
def TestDataflowLSTM(
    path: pathlib.Path,
    log_dir: pathlib.Path,
    vocab: Dict[str, int],
):
    dataflow.PatchWarnings()
    dataflow.RecordExperimentalSetup(log_dir)

    # Verify that the training run produced the expected logging directories,
    # and create the output directory for the graph loader.
    assert (log_dir / "epochs").is_dir()
    assert (log_dir / "checkpoints").is_dir()
    (log_dir / "graph_loader").mkdir(exist_ok=True)

    # Create the model, defining the shape of the graphs that it will process.
    #
    # For these data flow experiments, our graphs contain per-node binary
    # classification targets (e.g. reachable / not-reachable).
    model = Lstm(
        vocabulary=vocab,
        test_only=True,
        node_y_dimensionality=2,
    )
    restored_epoch, checkpoint = dataflow.SelectTestCheckpoint(log_dir)
    model.RestoreCheckpoint(checkpoint)

    batches = MakeBatchBuilder(
        dataset_root=path,
        log_dir=log_dir,
        epoch_type=epoch_pb2.TEST,
        model=model,
        min_graph_count=1,
    )

    start_time = time.time()
    test_results = model.RunBatches(epoch_pb2.TEST, batches, log_prefix="Test")
    epoch = epoch_pb2.EpochList(epoch=[
        epoch_pb2.Epoch(
            walltime_seconds=time.time() - start_time,
            epoch_num=restored_epoch.epoch_num,
            test_results=test_results,
        )
    ])
    print(epoch, end="")

    epoch_path = log_dir / "epochs" / "TEST.EpochList.pbtxt"
    pbutil.ToFile(epoch, epoch_path)
    logging.info("Wrote %s", epoch_path)
Example #2
def Main():
    # NOTE(github.com/ChrisCummins/ProGraML/issues/13): F1 score computation
    # warns that it is undefined when there are missing instances from a class,
    # which is fine for our usage.
    warnings.filterwarnings("ignore", category=UndefinedMetricWarning)

    with data_directory() as path:
        Print("=== BENCHMARK 1: Loading graphs from filesystem ===")
        graph_loader = GraphLoader(path)
        graphs = ppar.ThreadedIterator(graph_loader, max_queue_size=100)
        with prof.Profile("Benchmark graph loader"):
            for _ in tqdm(graphs, unit=" graphs"):
                pass

        Print("=== BENCHMARK 2: Batch construction ===")
        model = Lstm(vocabulary=Vocab(), node_y_dimensionality=2)
        batches = BatchBuilder(model, GraphLoader(path), Vocab())
        batches = ppar.ThreadedIterator(batches, max_queue_size=100)
        cached_batches = []
        with prof.Profile("Benchmark batch construction"):
            for batch in tqdm(batches, unit=" batches"):
                cached_batches.append(batch)

        Print("=== BENCHMARK 3: Model training ===")
        model.Initialize()

        model.model.summary()

        with prof.Profile("Benchmark training (prebuilt batches)"):
            model.RunBatches(
                epoch_pb2.TRAIN,
                cached_batches[: FLAGS.train_batch_count],
                log_prefix="Train",
                total_graph_count=sum(
                    b.graph_count for b in cached_batches[: FLAGS.train_batch_count]
                ),
            )
        with prof.Profile("Benchmark training"):
            model.RunBatches(
                epoch_pb2.TRAIN,
                BatchBuilder(
                    model, GraphLoader(path), Vocab(), FLAGS.train_batch_count
                ),
                log_prefix="Train",
            )

        Print("=== BENCHMARK 4: Model inference ===")
        model = Lstm(
            vocabulary=Vocab(),
            node_y_dimensionality=2,
            test_only=True,
        )
        model.Initialize()

        with prof.Profile("Benchmark inference (prebuilt batches)"):
            model.RunBatches(
                epoch_pb2.TEST,
                cached_batches[: FLAGS.test_batch_count],
                log_prefix="Val",
                total_graph_count=sum(
                    b.graph_count for b in cached_batches[: FLAGS.test_batch_count]
                ),
            )
        with prof.Profile("Benchmark inference"):
            model.RunBatches(
                epoch_pb2.TEST,
                BatchBuilder(model, GraphLoader(path), Vocab(), FLAGS.test_batch_count),
                log_prefix="Val",
            )
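
The benchmarks above wrap iterators in ppar.ThreadedIterator so that graph
loading and batch construction overlap with consumption. As context, a minimal
sketch of that producer/consumer pattern (an illustrative stand-in, not the
actual ppar implementation):

import queue
import threading
from typing import Iterable, Iterator


class ThreadedPrefetcher:
    """Iterate over `iterable` on a worker thread through a bounded queue."""

    _SENTINEL = object()

    def __init__(self, iterable: Iterable, max_queue_size: int = 100):
        self._queue = queue.Queue(maxsize=max_queue_size)
        self._thread = threading.Thread(
            target=self._worker, args=(iterable,), daemon=True
        )
        self._thread.start()

    def _worker(self, iterable: Iterable) -> None:
        # Producer: fill the bounded queue, then signal completion.
        for item in iterable:
            self._queue.put(item)
        self._queue.put(self._SENTINEL)

    def __iter__(self) -> Iterator:
        # Consumer: drain the queue until the sentinel arrives.
        while True:
            item = self._queue.get()
            if item is self._SENTINEL:
                return
            yield item

A pipeline such as ThreadedPrefetcher(GraphLoader(path), max_queue_size=100)
would then behave like the ppar.ThreadedIterator calls in the benchmarks.
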
Example #3
def TrainDataflowLSTM(
  path: pathlib.Path,
  vocab: Dict[str, int],
  val_seed: int,
  restore_from: pathlib.Path,
) -> pathlib.Path:
  if not path.is_dir():
    raise FileNotFoundError(path)

  if restore_from:
    log_dir = restore_from
  else:
    # Create the logging directories.
    log_dir = dataflow.CreateLoggingDirectories(
      dataset_root=path,
      model_name="inst2vec",
      analysis=FLAGS.analysis,
      run_id=FLAGS.run_id,
    )

  dataflow.PatchWarnings()
  dataflow.RecordExperimentalSetup(log_dir)

  # Cumulative totals for training graph counts at each "epoch".
  train_graph_counts = [int(x) for x in FLAGS.train_graph_counts]
  train_graph_cumsums = np.array(train_graph_counts, dtype=np.int32)
  # The number of training graphs in each "epoch".
  train_graph_counts = train_graph_cumsums - np.concatenate(
    ([0], train_graph_counts[:-1])
  )
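  # For example, --train_graph_counts=10000,20000,40000 (an assumed value)
  # gives cumulative targets [10000, 20000, 40000] and per-"epoch" counts
  # [10000, 10000, 20000]: the third epoch trains on 20,000 new graphs to
  # reach 40,000 in total.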

  # Create the model, defining the shape of the graphs that it will process.
  #
  # For these data flow experiments, our graphs contain per-node binary
  # classification targets (e.g. reachable / not-reachable).
  model = Lstm(
    vocabulary=vocab,
    test_only=False,
    node_y_dimensionality=2,
  )

  if restore_from:
    # Pick up training where we left off.
    restored_epoch, checkpoint = dataflow.SelectTrainingCheckpoint(log_dir)
    # Skip the epochs that we have already done.
    # This requires that --train_graph_counts is the same as it was in the
    # run that we are resuming!
    start_epoch_step = restored_epoch.epoch_num
    start_graph_cumsum = sum(train_graph_counts[:start_epoch_step])
    train_graph_counts = train_graph_counts[start_epoch_step:]
    train_graph_cumsums = train_graph_cumsums[start_epoch_step:]
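    # For example (continuing the assumed counts above), resuming from epoch 2
    # means 20,000 graphs are already done, leaving per-epoch counts [20000]
    # and cumulative targets [40000] still to train.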
    model.RestoreCheckpoint(checkpoint)
  else:
    # Else initialize a new model.
    model.Initialize()
    start_epoch_step, start_graph_cumsum = 1, 0

  model.model.summary()

  # Create training batches and split into epochs.
  epochs = EpochBatchIterator(
    MakeBatchBuilder(
      dataset_root=path,
      log_dir=log_dir,
      epoch_type=epoch_pb2.TRAIN,
      model=model,
      seed=val_seed,
    ),
    train_graph_counts,
    start_graph_count=start_graph_cumsum,
  )

  # Build the validation batches asynchronously while training proceeds.
  val_batches = AsyncBatchBuilder(
    batch_builder=MakeBatchBuilder(
      dataset_root=path,
      log_dir=log_dir,
      epoch_type=epoch_pb2.VAL,
      model=model,
      min_graph_count=FLAGS.val_graph_count,
      max_graph_count=FLAGS.val_graph_count,
      seed=val_seed,
    )
  )

  for (
    epoch_step,
    (train_graph_count, train_graph_cumsum, train_batches),
  ) in enumerate(epochs, start=start_epoch_step):
    start_time = time.time()
    hr_graph_cumsum = f"{humanize.Commas(train_graph_cumsum)} graphs"

    train_results = model.RunBatches(
      epoch_pb2.TRAIN,
      train_batches,
      log_prefix=f"Train to {hr_graph_cumsum}",
      total_graph_count=train_graph_count,
    )
    val_results = model.RunBatches(
      epoch_pb2.VAL,
      val_batches.batches,
      log_prefix=f"Val at {hr_graph_cumsum}",
      total_graph_count=FLAGS.val_graph_count,
    )

    # Write the epoch to file as an epoch list. This may seem redundant since
    # the epoch list contains a single item, but it means that we can easily
    # concatenate a sequence of these epoch protos to produce a valid epoch
    # list using: `cat *.EpochList.pbtxt > epochs.pbtxt`
    epoch = epoch_pb2.EpochList(
      epoch=[
        epoch_pb2.Epoch(
          walltime_seconds=time.time() - start_time,
          epoch_num=epoch_step,
          train_results=train_results,
          val_results=val_results,
        )
      ]
    )
    print(epoch, end="")

    epoch_path = log_dir / "epochs" / f"{epoch_step:03d}.EpochList.pbtxt"
    pbutil.ToFile(epoch, epoch_path)
    app.Log(1, "Wrote %s", epoch_path)

    checkpoint_path = (
      log_dir / "checkpoints" / f"{epoch_step:03d}.Checkpoint.pb"
    )
    pbutil.ToFile(model.SaveCheckpoint(), checkpoint_path)

  return log_dir
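
Taken together, the training and test examples compose into a simple driver:
train (or resume) first, then evaluate against the checkpoints the run
produced. A minimal sketch, assuming path and vocab are already prepared; the
RunDataflowExperiment name and the seed value are illustrative, not taken from
the examples above:

def RunDataflowExperiment(path: pathlib.Path, vocab: Dict[str, int]) -> None:
  # TrainDataflowLSTM returns the run's log_dir, which already contains the
  # epochs/ and checkpoints/ directories that TestDataflowLSTM asserts on.
  log_dir = TrainDataflowLSTM(
    path=path,
    vocab=vocab,
    val_seed=0,  # Arbitrary validation split seed; an assumed value.
    restore_from=None,  # Start a fresh run rather than resuming.
  )
  TestDataflowLSTM(path=path, log_dir=log_dir, vocab=vocab)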