def view_model_param(MODEL_NAME, net_params):
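    """Instantiate MODEL_NAME with net_params and report its size.

    Prints the model architecture and returns the total number of parameters.
    """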
    model = gnn_model(MODEL_NAME, net_params)
    total_param = 0
    print("MODEL DETAILS:\n")
    print(model)
    for param in model.parameters():
        # print(param.data.size())
        total_param += np.prod(list(param.data.size()))
    print('MODEL/Total parameters:', MODEL_NAME, total_param)
    return total_param


def train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs):
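    """Run the full train/validation/test loop for MODEL_NAME on `dataset`.

    `params` holds the optimization hyper-parameters, `net_params` the network
    configuration, and `dirs` the (log dir, checkpoint dir, results file,
    config file) paths.
    """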
    t0 = time.time()
    per_epoch_time = []

    DATASET_NAME = dataset.name

    #assert net_params['self_loop'] == False, "No self-loop support for %s dataset" % DATASET_NAME

    trainset, valset, testset = dataset.train, dataset.val, dataset.test

    root_log_dir, root_ckpt_dir, write_file_name, write_config_file = dirs
    device = net_params['device']

    # Write the network and optimization hyper-parameters to the config/ folder
    with open(write_config_file + '.txt', 'w') as f:
        f.write(
            """Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n\nTotal Parameters: {}\n\n"""
            .format(DATASET_NAME, MODEL_NAME, params, net_params,
                    net_params['total_param']))

    log_dir = os.path.join(root_log_dir, "RUN_" + str(0))
    writer = SummaryWriter(log_dir=log_dir)

    # setting seeds
    random.seed(params['seed'])
    np.random.seed(params['seed'])
    torch.manual_seed(params['seed'])
    if device.type == 'cuda':
        torch.cuda.manual_seed(params['seed'])

    print("Training Graphs: ", len(trainset))
    print("Validation Graphs: ", len(valset))
    print("Test Graphs: ", len(testset))
    print("Number of Classes: ", net_params['n_classes'])

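    # build the model and move it to the selected device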
    model = gnn_model(MODEL_NAME, net_params)
    model = model.to(device)

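    # Adam optimizer with a ReduceLROnPlateau schedule on the validation loss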
    optimizer = optim.Adam(model.parameters(),
                           lr=params['init_lr'],
                           weight_decay=params['weight_decay'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=params['lr_reduce_factor'],
        patience=params['lr_schedule_patience'],
        verbose=True)

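    # per-epoch histories of losses and F1 scores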
    epoch_train_losses, epoch_val_losses = [], []
    epoch_train_f1s, epoch_val_f1s = [], []

    if MODEL_NAME in ['RingGNN', '3WLGNN']:
        # import the train/evaluate functions specific to dense WL-GNNs
        from train.train_TSP_edge_classification import train_epoch_dense as train_epoch, evaluate_network_dense as evaluate_network
        from functools import partial  # util function to pass edge_feat to collate function

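        # no batch_size here: dense WL-GNNs are loaded one graph at a time and
        # the effective batch size is applied inside train_epoch_dense below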
        train_loader = DataLoader(trainset,
                                  shuffle=True,
                                  collate_fn=partial(
                                      dataset.collate_dense_gnn,
                                      edge_feat=net_params['edge_feat']))
        val_loader = DataLoader(valset,
                                shuffle=False,
                                collate_fn=partial(
                                    dataset.collate_dense_gnn,
                                    edge_feat=net_params['edge_feat']))
        test_loader = DataLoader(testset,
                                 shuffle=False,
                                 collate_fn=partial(
                                     dataset.collate_dense_gnn,
                                     edge_feat=net_params['edge_feat']))

    else:
        # import the train/evaluate functions for all other (sparse) GNNs
        from train.train_TSP_edge_classification import train_epoch_sparse as train_epoch, evaluate_network_sparse as evaluate_network

        train_loader = DataLoader(trainset,
                                  batch_size=params['batch_size'],
                                  shuffle=True,
                                  collate_fn=dataset.collate)
        val_loader = DataLoader(valset,
                                batch_size=params['batch_size'],
                                shuffle=False,
                                collate_fn=dataset.collate)
        test_loader = DataLoader(testset,
                                 batch_size=params['batch_size'],
                                 shuffle=False,
                                 collate_fn=dataset.collate)

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        with tqdm(range(params['epochs'])) as t:
            for epoch in t:

                t.set_description('Epoch %d' % epoch)

                start = time.time()

                if MODEL_NAME in ['RingGNN', '3WLGNN']:
                    # dense WL-GNNs use a dedicated batch-training function
                    epoch_train_loss, epoch_train_f1, optimizer = train_epoch(
                        model, optimizer, device, train_loader, epoch,
                        params['batch_size'])
                else:  # all other models share the common sparse train function
                    epoch_train_loss, epoch_train_f1, optimizer = train_epoch(
                        model, optimizer, device, train_loader, epoch)

                epoch_val_loss, epoch_val_f1 = evaluate_network(
                    model, device, val_loader, epoch)
                _, epoch_test_f1 = evaluate_network(model, device, test_loader,
                                                    epoch)

                epoch_train_losses.append(epoch_train_loss)
                epoch_val_losses.append(epoch_val_loss)
                epoch_train_f1s.append(epoch_train_f1)
                epoch_val_f1s.append(epoch_val_f1)

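                # log per-epoch metrics and the current learning rate to TensorBoard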
                writer.add_scalar('train/_loss', epoch_train_loss, epoch)
                writer.add_scalar('val/_loss', epoch_val_loss, epoch)
                writer.add_scalar('train/_f1', epoch_train_f1, epoch)
                writer.add_scalar('val/_f1', epoch_val_f1, epoch)
                writer.add_scalar('test/_f1', epoch_test_f1, epoch)
                writer.add_scalar('learning_rate',
                                  optimizer.param_groups[0]['lr'], epoch)

                t.set_postfix(time=time.time() - start,
                              lr=optimizer.param_groups[0]['lr'],
                              train_loss=epoch_train_loss,
                              val_loss=epoch_val_loss,
                              train_f1=epoch_train_f1,
                              val_f1=epoch_val_f1,
                              test_f1=epoch_test_f1)

                per_epoch_time.append(time.time() - start)

                # Saving checkpoint
                ckpt_dir = os.path.join(root_ckpt_dir, "RUN_")
                if not os.path.exists(ckpt_dir):
                    os.makedirs(ckpt_dir)
                torch.save(model.state_dict(),
                           '{}.pkl'.format(ckpt_dir + "/epoch_" + str(epoch)))

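                # keep only the checkpoints of the current and previous epoch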
                files = glob.glob(ckpt_dir + '/*.pkl')
                for file in files:
                    epoch_nb = file.split('_')[-1]
                    epoch_nb = int(epoch_nb.split('.')[0])
                    if epoch_nb < epoch - 1:
                        os.remove(file)

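                # step the plateau scheduler with the validation loss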
                scheduler.step(epoch_val_loss)

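                # stop if the LR has decayed below the configured minimum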
                if optimizer.param_groups[0]['lr'] < params['min_lr']:
                    print("\n!! LR BELOW MIN LR SET. Stopping training.")
                    break

                # Stop training after params['max_time'] hours
                if time.time() - t0 > params['max_time'] * 3600:
                    print('-' * 89)
                    print(
                        "Max training time of {:.2f} hours elapsed, stopping."
                        .format(params['max_time']))
                    break

    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early because of KeyboardInterrupt')

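    # evaluate the final model state on the test and train sets
    # (`epoch` holds the index of the last completed training epoch)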
    _, test_f1 = evaluate_network(model, device, test_loader, epoch)
    _, train_f1 = evaluate_network(model, device, train_loader, epoch)
    print("Test F1: {:.4f}".format(test_f1))
    print("Train F1: {:.4f}".format(train_f1))
    print("Convergence Time (Epochs): {:.4f}".format(epoch))
    print("TOTAL TIME TAKEN: {:.4f}s".format(time.time() - t0))
    print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time)))

    writer.close()

    # Write the results to the out_dir/results folder
    with open(write_file_name + '.txt', 'w') as f:
        f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n
    FINAL RESULTS\nTEST F1: {:.4f}\nTRAIN F1: {:.4f}\n\n
    Convergence Time (Epochs): {:.4f}\nTotal Time Taken: {:.4f}hrs\nAverage Time Per Epoch: {:.4f}s\n\n\n"""
                .format(DATASET_NAME, MODEL_NAME, params, net_params, model,
                        net_params['total_param'], np.mean(np.array(test_f1)),
                        np.mean(np.array(train_f1)), epoch,
                        (time.time() - t0) / 3600, np.mean(per_epoch_time)))