# Example No. 1
# 0
def train(data_path, master_csv_path, split_path, batch_size,
          num_training_steps, save_dir, learning_rate=1e-4, seed=42,
          log_every=100):
    """OGB Training Script.

    Trains the Haiku network built from `net_fn` with Adam and, if requested,
    pickles the final parameters to `<save_dir>/molhiv.pkl`.

    Args:
      data_path: Dataset location, forwarded to `data_utils.DataReader`.
      master_csv_path: Master CSV path, forwarded to `data_utils.DataReader`.
      split_path: Split file path, forwarded to `data_utils.DataReader`.
      batch_size: Number of graphs per batch read from the reader (one extra
        padding graph is appended to each batch before the update).
      num_training_steps: Number of optimizer updates to run.
      save_dir: Directory to write `molhiv.pkl` into. It is created (with
        parents) before training starts so an invalid path fails fast. If
        None, nothing is saved.
      learning_rate: Adam learning rate. Defaults to 1e-4.
      seed: Integer seed for the PRNG key used by `net.init`. Defaults to 42.
      log_every: Log loss/accuracy every `log_every` steps. Defaults to 100.

    Raises:
      ValueError: If `log_every` is not positive.
    """
    if log_every <= 0:
        raise ValueError('log_every must be positive, got %r' % (log_every,))

    save_path = None
    if save_dir is not None:
        # Create the output directory before the (potentially long) training
        # loop; otherwise a missing directory only surfaces when saving and
        # all training work is lost.
        save_path = pathlib.Path(save_dir, 'molhiv.pkl')
        save_path.parent.mkdir(parents=True, exist_ok=True)

    # Initialize the dataset reader.
    reader = data_utils.DataReader(data_path=data_path,
                                   master_csv_path=master_csv_path,
                                   split_path=split_path,
                                   batch_size=batch_size)
    # Repeat the dataset forever for training.
    reader.repeat()

    # Transform impure `net_fn` to pure functions with hk.transform.
    net = hk.without_apply_rng(hk.transform(net_fn))
    # Get a candidate graph to initialize the network (its label is unused).
    graph, _ = reader.get_graph_by_idx(0)

    # Initialize the network.
    logging.info('Initializing network.')
    params = net.init(jax.random.PRNGKey(seed), graph)
    # Initialize the optimizer.
    opt_init, opt_update = optax.adam(learning_rate)
    opt_state = opt_init(params)

    compute_loss_fn = functools.partial(compute_loss, net=net)
    # We jit the computation of our loss, since this is the main computation.
    # Using jax.jit means that we will use a single accelerator. If you want
    # to use more than 1 accelerator, use jax.pmap. More information can be
    # found in the jax documentation.
    # `has_aux=True`: compute_loss returns (loss, accuracy) and only the loss
    # is differentiated.
    compute_loss_fn = jax.jit(jax.value_and_grad(compute_loss_fn,
                                                 has_aux=True))

    for idx in range(num_training_steps):
        graph, label = next(reader)
        # Jax will re-jit your graphnet every time a new graph shape is
        # encountered. In the limit, this means a new compilation every
        # training step, which will result in *extremely* slow training. To
        # prevent this, pad each batch of graphs to the nearest power of two.
        # Since jax maintains a cache of compiled programs, the compilation
        # cost is amortized.
        graph = pad_graph_to_nearest_power_of_two(graph)

        # Since padding is implemented with pad_with_graphs, an extra graph
        # has been added to the batch, which means there should be an extra
        # label. The dummy label is 0; presumably compute_loss masks the
        # padding graph out — confirm against its implementation.
        label = jnp.concatenate([label, jnp.array([0])])

        (loss, acc), grad = compute_loss_fn(params, graph, label)
        updates, opt_state = opt_update(grad, opt_state, params)
        params = optax.apply_updates(params, updates)
        if idx % log_every == 0:
            logging.info('step: %s, loss: %s, acc: %s', idx, loss, acc)

    if save_path is not None:
        with save_path.open('wb') as fp:
            logging.info('Saving model to %s', save_dir)
            pickle.dump(params, fp)
    logging.info('Training finished')
# Example No. 2
# 0
def evaluate(data_path, master_csv_path, split_path, save_dir):
    """Evaluation Script.

    Loads parameters from `<save_dir>/molhiv.pkl` and reports the mean loss
    and accuracy of `compute_loss` over every batch yielded by the reader for
    `split_path`. The reader is built with batch_size=1, so each step is one
    graph, plus one padding graph.

    Args:
      data_path: Dataset location, forwarded to `data_utils.DataReader`.
      master_csv_path: Master CSV path, forwarded to `data_utils.DataReader`.
      split_path: Split file path, forwarded to `data_utils.DataReader`.
      save_dir: Directory containing the pickled parameters `molhiv.pkl`.

    Returns:
      A `(loss, accuracy)` tuple: the per-step means over the whole split.

    Raises:
      ValueError: If the reader yields no graphs, in which case the means are
        undefined.
    """
    logging.info('Evaluating OGB molhiv')
    logging.info('Dataset split: %s', split_path)
    # Initialize the dataset reader (one graph per batch, no repeat, so the
    # loop below makes a single pass).
    reader = data_utils.DataReader(data_path=data_path,
                                   master_csv_path=master_csv_path,
                                   split_path=split_path,
                                   batch_size=1)
    # Transform impure `net_fn` to pure functions with hk.transform.
    net = hk.without_apply_rng(hk.transform(net_fn))
    # NOTE: pickle.load can execute arbitrary code; only load checkpoints
    # from a trusted `save_dir`.
    with pathlib.Path(save_dir, 'molhiv.pkl').open('rb') as fp:
        params = pickle.load(fp)
    accumulated_loss = 0
    accumulated_accuracy = 0
    num_graphs = 0

    # We jit the computation of our loss, since this is the main computation.
    # Using jax.jit means that we will use a single accelerator. If you want
    # to use more than 1 accelerator, use jax.pmap. More information can be
    # found in the jax documentation.
    compute_loss_fn = jax.jit(functools.partial(compute_loss, net=net))
    for graph, label in reader:
        # Jax will re-jit your graphnet every time a new graph shape is
        # encountered. Without padding this would mean a new compilation for
        # nearly every evaluated graph. To prevent this, pad each batch of
        # graphs to the nearest power of two. Since jax maintains a cache of
        # compiled programs, the compilation cost is amortized.
        graph = pad_graph_to_nearest_power_of_two(graph)

        # Since padding is implemented with pad_with_graphs, an extra graph
        # has been added to the batch, which means there should be an extra
        # label.
        label = jnp.concatenate([label, jnp.array([0])])
        loss, acc = compute_loss_fn(params, graph, label)
        accumulated_accuracy += acc
        accumulated_loss += loss
        num_graphs += 1
        if num_graphs % 100 == 0:
            logging.info('Evaluated %s graphs', num_graphs)
    if num_graphs == 0:
        # Avoid a bare ZeroDivisionError below with a descriptive error.
        raise ValueError('No graphs to evaluate in split %s' % (split_path,))
    logging.info('Completed evaluation.')
    loss = accumulated_loss / num_graphs
    accuracy = accumulated_accuracy / num_graphs
    logging.info('Eval loss: %s, accuracy %s', loss, accuracy)
    return loss, accuracy
# Example No. 3
# 0
def train(data_path, master_csv_path, split_path, batch_size,
          num_training_steps, save_dir, learning_rate=1e-4, seed=42,
          log_every=100):
  """OGB Training Script.

  Trains a `GraphNetwork` with a Flax Adam optimizer and, if requested,
  pickles the final parameters (`optimizer.target`) to
  `<save_dir>/molhiv.pkl`.

  Args:
    data_path: Dataset location, forwarded to `data_utils.DataReader`.
    master_csv_path: Master CSV path, forwarded to `data_utils.DataReader`.
    split_path: Split file path, forwarded to `data_utils.DataReader`.
    batch_size: Number of graphs per batch read from the reader (one extra
      padding graph is appended to each batch before the update).
    num_training_steps: Number of optimizer updates to run.
    save_dir: Directory to write `molhiv.pkl` into. It is created (with
      parents) before training starts so an invalid path fails fast. If None,
      nothing is saved.
    learning_rate: Adam learning rate. Defaults to 1e-4.
    seed: Integer seed for the PRNG key used by `net.init`. Defaults to 42.
    log_every: Log loss/accuracy every `log_every` steps. Defaults to 100.

  Raises:
    ValueError: If `log_every` is not positive.
  """
  if log_every <= 0:
    raise ValueError('log_every must be positive, got %r' % (log_every,))

  save_path = None
  if save_dir is not None:
    # Create the output directory before the (potentially long) training
    # loop; otherwise a missing directory only surfaces when saving and all
    # training work is lost.
    save_path = pathlib.Path(save_dir, 'molhiv.pkl')
    save_path.parent.mkdir(parents=True, exist_ok=True)

  # Initialize the dataset reader.
  reader = data_utils.DataReader(
      data_path=data_path,
      master_csv_path=master_csv_path,
      split_path=split_path,
      batch_size=batch_size)
  # Repeat the dataset forever for training.
  reader.repeat()

  net = GraphNetwork(mlp_features=(128, 128), latent_size=128)

  # Get a candidate graph to initialize the network (its label is unused).
  graph, _ = reader.get_graph_by_idx(0)

  # Initialize the network.
  logging.info('Initializing network.')
  params = net.init(jax.random.PRNGKey(seed), graph)
  # NOTE(review): `optim` looks like the legacy `flax.optim` API, which newer
  # Flax releases replaced with optax; confirm the pinned Flax version.
  optimizer = optim.Adam(learning_rate=learning_rate).create(params)
  # Place the optimizer state on the default device.
  optimizer = jax.device_put(optimizer)

  for idx in range(num_training_steps):
    graph, label = next(reader)
    # Jax will re-jit your graphnet every time a new graph shape is
    # encountered. In the limit, this means a new compilation every training
    # step, which will result in *extremely* slow training. To prevent this,
    # pad each batch of graphs to the nearest power of two. Since jax
    # maintains a cache of compiled programs, the compilation cost is
    # amortized.
    graph = pad_graph_to_nearest_power_of_two(graph)

    # Since padding is implemented with pad_with_graphs, an extra graph has
    # been added to the batch, which means there should be an extra label.
    label = jnp.concatenate([label, jnp.array([0])])
    # train_step returns the updated optimizer and a dict of scalar metrics.
    optimizer, scalars = train_step(optimizer, graph, label, net)
    if idx % log_every == 0:
      logging.info('step: %s, loss: %s, acc: %s', idx, scalars['loss'],
                   scalars['accuracy'])

  if save_path is not None:
    with save_path.open('wb') as fp:
      logging.info('Saving model to %s', save_dir)
      pickle.dump(optimizer.target, fp)
  logging.info('Training finished')