def train(data_path, master_csv_path, split_path, batch_size,
          num_training_steps, save_dir):
    """Train the `net_fn` graph network and optionally pickle its parameters.

    Args:
      data_path: forwarded to `data_utils.DataReader` as `data_path`.
      master_csv_path: forwarded to `data_utils.DataReader` as
        `master_csv_path`.
      split_path: forwarded to `data_utils.DataReader` as `split_path`.
      batch_size: forwarded to `data_utils.DataReader` as `batch_size`.
      num_training_steps: number of optimizer updates to perform.
      save_dir: directory in which the final parameters are pickled as
        `molhiv.pkl`. Nothing is written when this is None.
    """
    dataset = data_utils.DataReader(
        data_path=data_path,
        master_csv_path=master_csv_path,
        split_path=split_path,
        batch_size=batch_size)
    # Training pulls batches with `next`, so the reader has to loop forever.
    dataset.repeat()

    # `hk.transform` turns the impure `net_fn` into pure functions.
    net = hk.without_apply_rng(hk.transform(net_fn))

    # One example graph is used to build the initial parameters.
    example_graph, _ = dataset.get_graph_by_idx(0)
    logging.info('Initializing network.')
    params = net.init(jax.random.PRNGKey(42), example_graph)

    # Adam with a fixed learning rate of 1e-4.
    optimizer = optax.adam(1e-4)
    opt_state = optimizer.init(params)

    # The loss (with its auxiliary output) and gradient are the main
    # computation, so they are jitted together. jax.jit targets a single
    # accelerator; jax.pmap is the tool for using more than one.
    loss_and_grad_fn = jax.jit(
        jax.value_and_grad(
            functools.partial(compute_loss, net=net), has_aux=True))

    for step in range(num_training_steps):
        batch_graph, batch_label = next(dataset)
        # Every new graph shape triggers a fresh jit compilation. Padding each
        # batch to the nearest power of two keeps the number of distinct
        # shapes small, so jax's compilation cache amortizes the cost.
        padded_graph = pad_graph_to_nearest_power_of_two(batch_graph)
        # The padding is implemented with pad_with_graphs, which adds one
        # extra graph to the batch; it needs a matching extra label.
        padded_label = jnp.concatenate([batch_label, jnp.array([0])])

        (loss, acc), grads = loss_and_grad_fn(
            params, padded_graph, padded_label)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

        if step % 100 == 0:
            logging.info('step: %s, loss: %s, acc: %s', step, loss, acc)

    if save_dir is not None:
        checkpoint_path = pathlib.Path(save_dir, 'molhiv.pkl')
        with checkpoint_path.open('wb') as fp:
            logging.info('Saving model to %s', save_dir)
            pickle.dump(params, fp)
    logging.info('Training finished')
def evaluate(data_path, master_csv_path, split_path, save_dir):
    """Evaluate the pickled OGB molhiv model on one dataset split.

    Loads the parameters stored at `<save_dir>/molhiv.pkl`, runs the jitted
    loss on every batch the reader yields (the reader is built with
    `batch_size=1`) and averages the per-batch loss and accuracy.

    Args:
      data_path: forwarded to `data_utils.DataReader` as `data_path`.
      master_csv_path: forwarded to `data_utils.DataReader` as
        `master_csv_path`.
      split_path: forwarded to `data_utils.DataReader` as `split_path`; also
        logged so the evaluated split is visible in the output.
      save_dir: directory containing the `molhiv.pkl` parameter file.

    Returns:
      A `(loss, accuracy)` tuple with the mean of the per-batch values.

    Raises:
      OSError: if `<save_dir>/molhiv.pkl` cannot be opened for reading.
      ZeroDivisionError: if the reader yields no batches, since the means are
        computed by dividing by the number of batches seen.
    """
    logging.info('Evaluating OGB molhiv')
    logging.info('Dataset split: %s', split_path)

    reader = data_utils.DataReader(
        data_path=data_path,
        master_csv_path=master_csv_path,
        split_path=split_path,
        batch_size=1)

    # `hk.transform` turns the impure `net_fn` into pure functions.
    net = hk.without_apply_rng(hk.transform(net_fn))

    # NOTE(review): this candidate graph is never used here (the parameters
    # are loaded from disk rather than initialized). The call is kept
    # unchanged in case the reader depends on it -- confirm and remove.
    graph, _ = reader.get_graph_by_idx(0)

    # NOTE: unpickling can execute arbitrary code; only point `save_dir` at
    # checkpoints produced by a trusted training run.
    with pathlib.Path(save_dir, 'molhiv.pkl').open('rb') as fp:
        params = pickle.load(fp)

    accumulated_loss = 0
    accumulated_accuracy = 0
    num_batches = 0

    # The loss is the main computation, so it is jitted. jax.jit targets a
    # single accelerator; jax.pmap is the tool for using more than one.
    compute_loss_fn = jax.jit(functools.partial(compute_loss, net=net))

    for graph, label in reader:
        # Every new graph shape triggers a fresh jit compilation. Padding each
        # batch to the nearest power of two keeps the number of distinct
        # shapes small, so jax's compilation cache amortizes the cost.
        graph = pad_graph_to_nearest_power_of_two(graph)
        # The padding is implemented with pad_with_graphs, which adds one
        # extra graph to the batch; it needs a matching extra label.
        label = jnp.concatenate([label, jnp.array([0])])

        loss, acc = compute_loss_fn(params, graph, label)
        accumulated_accuracy += acc
        accumulated_loss += loss
        num_batches += 1
        if num_batches % 100 == 0:
            logging.info('Evaluated %s graphs', num_batches)

    logging.info('Completed evaluation.')
    loss = accumulated_loss / num_batches
    accuracy = accumulated_accuracy / num_batches
    logging.info('Eval loss: %s, accuracy %s', loss, accuracy)
    return loss, accuracy
def train(data_path, master_csv_path, split_path, batch_size,
          num_training_steps, save_dir):
    """Train a `GraphNetwork` and optionally pickle the trained target.

    Args:
      data_path: forwarded to `data_utils.DataReader` as `data_path`.
      master_csv_path: forwarded to `data_utils.DataReader` as
        `master_csv_path`.
      split_path: forwarded to `data_utils.DataReader` as `split_path`.
      batch_size: forwarded to `data_utils.DataReader` as `batch_size`.
      num_training_steps: number of `train_step` calls to perform.
      save_dir: directory in which `optimizer.target` is pickled as
        `molhiv.pkl` after training. Nothing is written when this is None.
    """
    dataset = data_utils.DataReader(
        data_path=data_path,
        master_csv_path=master_csv_path,
        split_path=split_path,
        batch_size=batch_size)
    # Training pulls batches with `next`, so the reader has to loop forever.
    dataset.repeat()

    model = GraphNetwork(mlp_features=(128, 128), latent_size=128)

    # One example graph is used to build the initial parameters.
    example_graph, _ = dataset.get_graph_by_idx(0)
    logging.info('Initializing network.')
    initial_params = model.init(jax.random.PRNGKey(42), example_graph)

    # Adam with a fixed learning rate of 1e-4, wrapped around the parameters
    # and then passed through jax.device_put.
    adam = optim.Adam(learning_rate=1e-4)
    optimizer = jax.device_put(adam.create(initial_params))

    for step in range(num_training_steps):
        batch_graph, batch_label = next(dataset)
        # Every new graph shape triggers a fresh jit compilation. Padding each
        # batch to the nearest power of two keeps the number of distinct
        # shapes small, so jax's compilation cache amortizes the cost.
        padded_graph = pad_graph_to_nearest_power_of_two(batch_graph)
        # The padding is implemented with pad_with_graphs, which adds one
        # extra graph to the batch; it needs a matching extra label.
        padded_label = jnp.concatenate([batch_label, jnp.array([0])])

        optimizer, scalars = train_step(
            optimizer, padded_graph, padded_label, model)

        if step % 100 == 0:
            logging.info('step: %s, loss: %s, acc: %s', step,
                         scalars['loss'], scalars['accuracy'])

    if save_dir is not None:
        checkpoint_path = pathlib.Path(save_dir, 'molhiv.pkl')
        with checkpoint_path.open('wb') as fp:
            logging.info('Saving model to %s', save_dir)
            pickle.dump(optimizer.target, fp)
    logging.info('Training finished')