def view_model_param(MODEL_NAME, net_params):
    """Build the model once and print/return its total parameter count."""
    model = gnn_model(MODEL_NAME, net_params)
    total_param = 0
    print("MODEL DETAILS:\n")
    for param in model.parameters():
        total_param += np.prod(list(param.data.size()))
    print('MODEL/Total parameters:', MODEL_NAME, total_param)
    return total_param
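
# Note: assuming all parameters are dense tensors, an equivalent idiomatic
# one-liner for the count above would be:
#   total_param = sum(p.numel() for p in model.parameters())
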
def train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs):

    start0 = time.time()
    per_epoch_time = []

    DATASET_NAME = dataset.name

    if MODEL_NAME in ['GCN', 'GAT', 'SGC']:
        if net_params['self_loop']:
            print("[!] Adding graph self-loops (central node trick).")
            dataset._add_self_loops()

    root_log_dir, root_ckpt_dir, write_file_name, write_config_file = dirs
    device = net_params['device']

    print('Random seed:', params['seed'])
    random.seed(params['seed'])
    torch.manual_seed(params['seed'])
    # `device` may be a string or a torch.device, so compare its string form
    if str(device).startswith('cuda'):
        torch.cuda.manual_seed(params['seed'])

    np.random.seed(params['seed'])
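    # For stricter run-to-run reproducibility on GPU one could also set the
    # cuDNN flags below (assumption: the speed trade-off is acceptable):
    #   torch.backends.cudnn.deterministic = True
    #   torch.backends.cudnn.benchmark = False
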
    # Random 60/20/20 train/val/test split over the nodes
    num_nodes = dataset.train_mask.size(0)
    index = torch.tensor(np.random.permutation(num_nodes))
    train_index = index[:int(num_nodes * 0.6)]
    val_index = index[int(num_nodes * 0.6):int(num_nodes * 0.8)]
    test_index = index[int(num_nodes * 0.8):]

    train_mask = index_to_mask(train_index, size=num_nodes)
    val_mask = index_to_mask(val_index, size=num_nodes)
    test_mask = index_to_mask(test_index, size=num_nodes)
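
    # `index_to_mask` (defined elsewhere in this repo) is assumed to build a
    # boolean node mask from an index tensor, roughly:
    #   def index_to_mask(index, size):
    #       mask = torch.zeros(size, dtype=torch.bool)
    #       mask[index] = True
    #       return mask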

    train_mask = train_mask.to(device)
    val_mask = val_mask.to(device)
    test_mask = test_mask.to(device)

    labels = dataset.labels.to(device)

    # Write network and optimization hyper-parameters in folder config/
    with open(write_config_file + '.txt', 'w') as f:
        f.write(
            """Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n\nTotal Parameters: {}\n\n"""
            .format(DATASET_NAME, MODEL_NAME, params, net_params,
                    net_params['total_param']))

    log_dir = os.path.join(root_log_dir, "RUN_" + str(0))
    writer = SummaryWriter(log_dir=log_dir)
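    # Scalars logged through `writer` below can be inspected with TensorBoard,
    # e.g. `tensorboard --logdir <root_log_dir>` (placeholder path).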

    print("Training Nodes: ", train_mask.int().sum().item())
    print("Validation Nodes: ", val_mask.int().sum().item())
    print("Test Nodes: ", test_mask.int().sum().item())
    print("Number of Classes: ", net_params['n_classes'])

    model = gnn_model(MODEL_NAME, net_params)
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(),
                           lr=params['init_lr'],
                           weight_decay=params['weight_decay'])
    # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
    #                                                 factor=params['lr_reduce_factor'],
    #                                                 patience=params['lr_schedule_patience'],
    #                                                 verbose=True)

    epoch_train_losses, epoch_val_losses = [], []
    epoch_train_accs, epoch_val_accs = [], []

    graph = dataset.graph
    nfeat = graph.ndata['feat'].to(device)
    efeat = graph.edata['feat'].to(device)
    norm_n = dataset.norm_n.to(device)
    norm_e = dataset.norm_e.to(device)

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        with tqdm(range(params['epochs'])) as t:
            best_val_acc = 0
            for epoch in t:

                t.set_description('Epoch %d' % epoch)

                start = time.time()

                epoch_train_loss, epoch_train_acc, optimizer = train_epoch(
                    model, optimizer, device, graph, nfeat, efeat, norm_n,
                    norm_e, train_mask, labels, epoch)

                epoch_val_loss, epoch_val_acc = evaluate_network(
                    model, graph, nfeat, efeat, norm_n, norm_e, val_mask,
                    labels, epoch)

                epoch_train_losses.append(epoch_train_loss)
                epoch_val_losses.append(epoch_val_loss)
                epoch_train_accs.append(epoch_train_acc)
                epoch_val_accs.append(epoch_val_acc)

                writer.add_scalar('train/_loss', epoch_train_loss, epoch)
                writer.add_scalar('val/_loss', epoch_val_loss, epoch)
                writer.add_scalar('train/_acc', epoch_train_acc, epoch)
                writer.add_scalar('val/_acc', epoch_val_acc, epoch)
                writer.add_scalar('learning_rate',
                                  optimizer.param_groups[0]['lr'], epoch)

                _, epoch_test_acc = evaluate_network(model, graph, nfeat,
                                                     efeat, norm_n, norm_e,
                                                     test_mask, labels, epoch)
                t.set_postfix(time=time.time() - start,
                              lr=optimizer.param_groups[0]['lr'],
                              train_loss=epoch_train_loss,
                              val_loss=epoch_val_loss,
                              train_acc=epoch_train_acc,
                              val_acc=epoch_val_acc,
                              test_acc=epoch_test_acc)

                per_epoch_time.append(time.time() - start)

                # Save a checkpoint every epoch, plus the best-on-validation model
                ckpt_dir = os.path.join(root_ckpt_dir, "RUN_" + str(0))
                os.makedirs(ckpt_dir, exist_ok=True)
                torch.save(model.state_dict(),
                           os.path.join(ckpt_dir, 'epoch_{}.pkl'.format(epoch)))
                if best_val_acc < epoch_val_acc:
                    best_val_acc = epoch_val_acc
                    torch.save(model.state_dict(),
                               os.path.join(ckpt_dir, 'best.pkl'))

                # Prune old checkpoints: keep best.pkl and the two most recent epochs
                files = glob.glob(os.path.join(ckpt_dir, '*.pkl'))
                for file in files:
                    if file.endswith('best.pkl'):
                        continue
                    epoch_nb = int(
                        os.path.basename(file).split('_')[-1].split('.')[0])
                    if epoch_nb < epoch - 1:
                        os.remove(file)

                #scheduler.step(epoch_val_loss)

                # Clamp the learning rate at the configured floor (only
                # meaningful if the LR scheduler above is re-enabled)
                if optimizer.param_groups[0]['lr'] < params['min_lr']:
                    optimizer.param_groups[0]['lr'] = params['min_lr']

                # Stop training after params['max_time'] hours
                if time.time() - start0 > params['max_time'] * 3600:
                    print('-' * 89)
                    print("Max training time of {:.2f} hours reached, stopping."
                          .format(params['max_time']))
                    break

    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early because of KeyboardInterrupt')

    # Reload the best-on-validation checkpoint for final evaluation
    model.load_state_dict(torch.load(os.path.join(ckpt_dir, 'best.pkl')))
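    # Assumption: evaluation runs on the same `device` as training; when
    # loading the checkpoint elsewhere, pass map_location=device to torch.load.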
    _, test_acc = evaluate_network(model, graph, nfeat, efeat, norm_n, norm_e,
                                   test_mask, labels, epoch)
    _, val_acc = evaluate_network(model, graph, nfeat, efeat, norm_n, norm_e,
                                  val_mask, labels, epoch)
    _, train_acc = evaluate_network(model, graph, nfeat, efeat, norm_n, norm_e,
                                    train_mask, labels, epoch)
    print("Test Accuracy: {:.4f}".format(test_acc))
    print("Train Accuracy: {:.4f}".format(train_acc))
    print("TOTAL TIME TAKEN: {:.4f}s".format(time.time() - start0))
    print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time)))

    writer.close()
    """
        Write the results in out_dir/results folder
    """
    result_str = (
        "Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\n"
        "Total Parameters: {}\n\n"
        "FINAL RESULTS\nTEST ACCURACY: {:.4f}\nTRAIN ACCURACY: {:.4f}\n\n"
        "Total Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\n"
    ).format(DATASET_NAME, MODEL_NAME, params, net_params, model,
             net_params['total_param'], test_acc, train_acc,
             (time.time() - start0) / 3600, np.mean(per_epoch_time))

    with open(write_file_name + '.txt', 'w') as f:
        f.write(result_str)

    # Optionally e-mail the results; skip silently if the gmail helper is absent
    try:
        from gmail import send
        subject = 'Result for Dataset: {}, Model: {}'.format(
            DATASET_NAME, MODEL_NAME)
        send(subject, result_str)
    except Exception:
        pass

    return val_acc, test_acc
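
# A typical invocation, sketched as an assumption (the actual caller lives in
# the repo's main script): `view_model_param` fills net_params['total_param'],
# and `dirs` is the 4-tuple of output paths unpacked in train_val_pipeline:
#   net_params['total_param'] = view_model_param(MODEL_NAME, net_params)
#   dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file
#   val_acc, test_acc = train_val_pipeline(MODEL_NAME, dataset, params,
#                                          net_params, dirs)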