def _sample_batch(data, sampler, is_train):
    """Draw the next batch from ``sampler`` and wrap it in a ``BatchData``.

    Shared by the hierarchical and flat paths of ``model_forward`` so the
    batch construction lives in exactly one place.
    """
    batch_gids, sampled_gids, subgraph = sampler.sample_next_training_batch()
    return BatchData(
        batch_gids,
        data.dataset,
        is_train=is_train,
        sampled_gids=sampled_gids,
        enforce_negative_sampling=FLAGS.enforce_negative_sampling,
        unique_graphs=FLAGS.batch_unique_graphs,
        subgraph=subgraph)


def model_forward(model, data, sampler=None, is_train=True):
    """Sample the next batch and, for two-level models, prime the model.

    When both ``FLAGS.lower_level_layers`` and ``FLAGS.higher_level_layers``
    are set, this function may (re)initialise interaction-graph embeddings
    and may run the lower layers on the batch. It always leaves
    ``model.use_layers == 'higher_layers'``. The higher-layer pass itself is
    not run here; presumably the caller runs it on the returned batch
    (verify against callers).

    Args:
        model: The model. It must be callable on a ``BatchData`` and must
            expose a writable ``use_layers`` attribute (also ``init_x`` when
            embeddings are initialised).
        data: Object exposing ``.dataset``, which is passed to ``BatchData``.
        sampler: Object with ``sample_next_training_batch()`` returning
            ``(batch_gids, sampled_gids, subgraph)``. Defaults to a fresh
            ``RandomSampler(data, FLAGS.batch_size, FLAGS.sample_induced)``.
        is_train: Forwarded to ``BatchData``. It also gates the
            ``"model_init"`` embedding initialisation.

    Returns:
        The ``BatchData`` built from the sampled batch.
    """
    if sampler is None:
        sampler = RandomSampler(data, FLAGS.batch_size, FLAGS.sample_induced)

    hierarchical = FLAGS.lower_level_layers and FLAGS.higher_level_layers

    if hierarchical and "model_init" in FLAGS.init_embds and is_train:
        _get_initial_embd(data, model)
        data.dataset.init_interaction_graph_embds(device=FLAGS.device)
        # Detach before moving to host so the copy is not tracked by autograd.
        model.init_x = (data.dataset.interaction_combo_nxgraph.init_x
                        .detach().cpu().numpy())

    batch_data = _sample_batch(data, sampler, is_train)

    if hierarchical:
        if FLAGS.pair_interaction:
            model.use_layers = 'lower_layers'
            # The return value is discarded. NOTE(review): this presumably
            # relies on the model caching lower-layer state internally;
            # confirm in the model implementation.
            model(batch_data)
        model.use_layers = 'higher_layers'
    return batch_data