コード例 #1
0
def view_model_param(MODEL_NAME, net_params):
    """Build the model named ``MODEL_NAME`` and report its parameter count.

    Args:
        MODEL_NAME: Model identifier forwarded to ``gnn_model``.
        net_params: Network hyper-parameters forwarded to ``gnn_model``.

    Returns:
        int: Total number of elements over every tensor yielded by
        ``model.parameters()``.
    """
    model = gnn_model(MODEL_NAME, net_params)
    print("MODEL DETAILS:\n")
    # numel() yields an exact Python int per tensor. The previous
    # np.prod(list(size)) evaluated to the float 1.0 for 0-dim parameters
    # (empty shape), which silently turned the whole total into a float.
    total_param = sum(param.numel() for param in model.parameters())
    print('MODEL/Total parameters:', MODEL_NAME, total_param)
    return total_param
コード例 #2
0
def view_model_param(MODEL_NAME, net_params, verbose=False):
    """Build the model named ``MODEL_NAME`` and report its parameter count.

    Args:
        MODEL_NAME: Model identifier forwarded to ``gnn_model``.
        net_params: Network hyper-parameters forwarded to ``gnn_model``.
        verbose: When true, additionally print ``net_params`` and the model
            object itself.

    Returns:
        int: Total number of elements over every tensor yielded by
        ``model.parameters()``.
    """
    model = gnn_model(MODEL_NAME, net_params)
    print("MODEL DETAILS:\n")
    # numel() yields an exact Python int per tensor. The previous
    # np.prod(list(size)) evaluated to the float 1.0 for 0-dim parameters
    # (empty shape), which silently turned the whole total into a float.
    total_param = sum(param.numel() for param in model.parameters())
    print('MODEL/Total parameters:', MODEL_NAME, total_param)

    if verbose:
        print('\n== Net Params:')
        print(net_params)
        print('\n== Model Structure:')
        print(model)

    return total_param
コード例 #3
0
def train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs):
    """Train a model, evaluate it each epoch, and write logs and results.

    The function optionally preprocesses ``dataset`` (Laplacian positional
    encodings, full-graph conversion, WL positional encodings), writes the
    configuration to disk, trains with Adam + ReduceLROnPlateau, logs scalars
    through ``SummaryWriter``, keeps only the most recent checkpoints, and
    finally writes a results file.

    Args:
        MODEL_NAME: Model identifier forwarded to ``gnn_model``.
        dataset: Object exposing ``name``, ``train``, ``val``, ``test``,
            ``collate`` and the ``_add_*`` / ``_make_full_graph`` helpers
            called below.
        params: Optimisation settings. Keys read here: ``seed``, ``init_lr``,
            ``weight_decay``, ``lr_reduce_factor``, ``lr_schedule_patience``,
            ``batch_size``, ``epochs``, ``min_lr``, ``max_time`` (hours).
        net_params: Network settings. Keys read here: ``lap_pos_enc``,
            ``pos_enc_dim``, ``full_graph``, ``wl_pos_enc``, ``device``,
            ``total_param``, ``n_classes``. ``device`` must expose ``.type``
            (presumably a ``torch.device`` -- confirm against the caller).
        dirs: 4-tuple ``(root_log_dir, root_ckpt_dir, write_file_name,
            write_config_file)``; ``'.txt'`` is appended to the two file names.

    Returns:
        None. Results are printed and written to disk.
    """
    start0 = time.time()
    per_epoch_time = []

    DATASET_NAME = dataset.name

    if net_params['lap_pos_enc']:
        st = time.time()
        print("[!] Adding Laplacian positional encoding.")
        dataset._add_laplacian_positional_encodings(net_params['pos_enc_dim'])
        print('Time LapPE:', time.time() - st)

    if net_params['full_graph']:
        print("[!] Converting the given graphs to full graphs..")
        dataset._make_full_graph()
        # Measured from pipeline start, so it includes any LapPE time above.
        print('Time taken to convert to full graphs:', time.time() - start0)

    if net_params['wl_pos_enc']:
        st = time.time()
        print("[!] Adding WL positional encoding.")
        dataset._add_wl_positional_encodings()
        print('Time WL PE:', time.time() - st)

    trainset, valset, testset = dataset.train, dataset.val, dataset.test

    root_log_dir, root_ckpt_dir, write_file_name, write_config_file = dirs
    device = net_params['device']

    # Write network and optimization hyper-parameters in folder config/
    with open(write_config_file + '.txt', 'w') as f:
        f.write(
            """Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n\nTotal Parameters: {}\n\n"""
            .format(DATASET_NAME, MODEL_NAME, params, net_params,
                    net_params['total_param']))

    log_dir = os.path.join(root_log_dir, "RUN_" + str(0))
    writer = SummaryWriter(log_dir=log_dir)

    # setting seeds
    random.seed(params['seed'])
    np.random.seed(params['seed'])
    torch.manual_seed(params['seed'])
    if device.type == 'cuda':
        torch.cuda.manual_seed(params['seed'])

    print("Training Graphs: ", len(trainset))
    print("Validation Graphs: ", len(valset))
    print("Test Graphs: ", len(testset))
    print("Number of Classes: ", net_params['n_classes'])

    model = gnn_model(MODEL_NAME, net_params)
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(),
                           lr=params['init_lr'],
                           weight_decay=params['weight_decay'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=params['lr_reduce_factor'],
        patience=params['lr_schedule_patience'],
        verbose=True)

    epoch_train_losses, epoch_val_losses = [], []
    epoch_train_accs, epoch_val_accs = [], []

    # import train and evaluate functions
    from train.train_SBMs_node_classification import train_epoch, evaluate_network

    train_loader = DataLoader(trainset,
                              batch_size=params['batch_size'],
                              shuffle=True,
                              collate_fn=dataset.collate)
    val_loader = DataLoader(valset,
                            batch_size=params['batch_size'],
                            shuffle=False,
                            collate_fn=dataset.collate)
    test_loader = DataLoader(testset,
                             batch_size=params['batch_size'],
                             shuffle=False,
                             collate_fn=dataset.collate)

    ckpt_dir = os.path.join(root_ckpt_dir, "RUN_")

    # Bound up front so the final evaluation and report still work when the
    # loop body never runs (e.g. params['epochs'] == 0); previously that case
    # ended in a NameError.
    epoch = 0

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        with tqdm(range(params['epochs'])) as t:
            for epoch in t:

                t.set_description('Epoch %d' % epoch)

                start = time.time()

                epoch_train_loss, epoch_train_acc, optimizer = train_epoch(
                    model, optimizer, device, train_loader, epoch)

                epoch_val_loss, epoch_val_acc = evaluate_network(
                    model, device, val_loader, epoch)
                _, epoch_test_acc = evaluate_network(model, device,
                                                     test_loader, epoch)

                epoch_train_losses.append(epoch_train_loss)
                epoch_val_losses.append(epoch_val_loss)
                epoch_train_accs.append(epoch_train_acc)
                epoch_val_accs.append(epoch_val_acc)

                writer.add_scalar('train/_loss', epoch_train_loss, epoch)
                writer.add_scalar('val/_loss', epoch_val_loss, epoch)
                writer.add_scalar('train/_acc', epoch_train_acc, epoch)
                writer.add_scalar('val/_acc', epoch_val_acc, epoch)
                writer.add_scalar('test/_acc', epoch_test_acc, epoch)
                writer.add_scalar('learning_rate',
                                  optimizer.param_groups[0]['lr'], epoch)

                t.set_postfix(time=time.time() - start,
                              lr=optimizer.param_groups[0]['lr'],
                              train_loss=epoch_train_loss,
                              val_loss=epoch_val_loss,
                              train_acc=epoch_train_acc,
                              val_acc=epoch_val_acc,
                              test_acc=epoch_test_acc)

                per_epoch_time.append(time.time() - start)

                # Saving checkpoint. exist_ok avoids the check-then-create
                # race of a separate os.path.exists() test.
                os.makedirs(ckpt_dir, exist_ok=True)
                torch.save(model.state_dict(),
                           '{}.pkl'.format(ckpt_dir + "/epoch_" + str(epoch)))

                # Keep only the checkpoints of the current and previous epoch.
                for file in glob.glob(ckpt_dir + '/*.pkl'):
                    epoch_nb = file.split('_')[-1]
                    epoch_nb = int(epoch_nb.split('.')[0])
                    if epoch_nb < epoch - 1:
                        os.remove(file)

                scheduler.step(epoch_val_loss)

                if optimizer.param_groups[0]['lr'] < params['min_lr']:
                    print("\n!! LR SMALLER OR EQUAL TO MIN LR THRESHOLD.")
                    break

                # Stop training after params['max_time'] hours
                if time.time() - start0 > params['max_time'] * 3600:
                    print('-' * 89)
                    print(
                        "Max_time for training elapsed {:.2f} hours, so stopping"
                        .format(params['max_time']))
                    break

    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early because of KeyboardInterrupt')

    # np.mean([]) yields nan but also emits a RuntimeWarning; report nan
    # directly when no epoch completed.
    avg_epoch_time = (np.mean(per_epoch_time)
                      if per_epoch_time else float('nan'))

    _, test_acc = evaluate_network(model, device, test_loader, epoch)
    _, train_acc = evaluate_network(model, device, train_loader, epoch)
    print("Test Accuracy: {:.4f}".format(test_acc))
    print("Train Accuracy: {:.4f}".format(train_acc))
    print("Convergence Time (Epochs): {:.4f}".format(epoch))
    print("TOTAL TIME TAKEN: {:.4f}s".format(time.time() - start0))
    print("AVG TIME PER EPOCH: {:.4f}s".format(avg_epoch_time))

    writer.close()

    # Write the results in out_dir/results folder
    with open(write_file_name + '.txt', 'w') as f:
        f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n
    FINAL RESULTS\nTEST ACCURACY: {:.4f}\nTRAIN ACCURACY: {:.4f}\n\n
    Convergence Time (Epochs): {:.4f}\nTotal Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\n"""\
          .format(DATASET_NAME, MODEL_NAME, params, net_params, model, net_params['total_param'],
                  test_acc, train_acc, epoch, (time.time()-start0)/3600, avg_epoch_time))
コード例 #4
0
def train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs):
    """Train a model (DGL or PyG data path), evaluate it, and write results.

    Args:
        MODEL_NAME: Model identifier forwarded to ``gnn_model``. ``'GCN'`` /
            ``'GAT'`` may trigger self-loop insertion, ``'GatedGCN_pyg'`` /
            ``'ResGatedGCN_pyg'`` may trigger positional encodings, and
            ``'RingGNN'`` / ``'3WLGNN'`` select the dense train/eval functions.
        dataset: Object exposing ``name``, ``train``, ``val``, ``test``,
            ``collate``, ``collate_dense_gnn`` and the ``_add_*`` helpers
            called below.
        params: Optimisation settings. Keys read here: ``seed``, ``init_lr``,
            ``weight_decay``, ``lr_reduce_factor``, ``lr_schedule_patience``,
            ``batch_size``, ``epochs``, ``min_lr``, ``max_time`` (hours) and
            ``framework`` (``'pyg'`` selects ``DataLoaderpyg``).
        net_params: Network settings. Keys read here: ``self_loop``,
            ``pos_enc``, ``pos_enc_dim``, ``device``, ``total_param``,
            ``n_classes``. ``device`` must expose ``.type``.
        dirs: 4-tuple ``(root_log_dir, root_ckpt_dir, write_file_name,
            write_config_file)``; ``'.txt'`` is appended to the two file names.

    Returns:
        None. Results are printed and written to disk.
    """
    start0 = time.time()
    per_epoch_time = []

    DATASET_NAME = dataset.name

    if MODEL_NAME in ['GCN', 'GAT']:
        if net_params['self_loop']:
            print(
                "[!] Adding graph self-loops for GCN/GAT models (central node trick)."
            )
            dataset._add_self_loops()

    if MODEL_NAME in ['GatedGCN_pyg', 'ResGatedGCN_pyg']:
        if net_params['pos_enc']:
            print("[!] Adding graph positional encoding.")
            dataset._add_positional_encodings(net_params['pos_enc_dim'])
            print('Time PE:', time.time() - start0)

    trainset, valset, testset = dataset.train, dataset.val, dataset.test
    root_log_dir, root_ckpt_dir, write_file_name, write_config_file = dirs
    device = net_params['device']

    # Write network and optimization hyper-parameters in folder config/
    with open(write_config_file + '.txt', 'w') as f:
        f.write(
            """Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n\nTotal Parameters: {}\n\n"""
            .format(DATASET_NAME, MODEL_NAME, params, net_params,
                    net_params['total_param']))

    log_dir = os.path.join(root_log_dir, "RUN_" + str(0))
    writer = SummaryWriter(log_dir=log_dir)

    # setting seeds
    random.seed(params['seed'])
    np.random.seed(params['seed'])
    torch.manual_seed(params['seed'])
    if device.type == 'cuda':
        torch.cuda.manual_seed(params['seed'])

    print("Training Graphs: ", len(trainset))
    print("Validation Graphs: ", len(valset))
    print("Test Graphs: ", len(testset))
    print("Number of Classes: ", net_params['n_classes'])

    model = gnn_model(MODEL_NAME, net_params)
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(),
                           lr=params['init_lr'],
                           weight_decay=params['weight_decay'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=params['lr_reduce_factor'],
        patience=params['lr_schedule_patience'],
        verbose=True)

    epoch_train_losses, epoch_val_losses = [], []
    epoch_train_accs, epoch_val_accs = [], []

    is_dense_model = MODEL_NAME in ['RingGNN', '3WLGNN']

    if is_dense_model:
        # import train functions specific for WL-GNNs
        from train.train_SBMs_node_classification import train_epoch_dense as train_epoch, evaluate_network_dense as evaluate_network

        # No batch_size is passed here, so DataLoader's default applies.
        train_loader = DataLoader(trainset,
                                  shuffle=True,
                                  collate_fn=dataset.collate_dense_gnn)
        val_loader = DataLoader(valset,
                                shuffle=False,
                                collate_fn=dataset.collate_dense_gnn)
        test_loader = DataLoader(testset,
                                 shuffle=False,
                                 collate_fn=dataset.collate_dense_gnn)

    else:
        # import train functions for all other GCNs
        from train.train_SBMs_node_classification import train_epoch_sparse as train_epoch, evaluate_network_sparse as evaluate_network

        if params['framework'] == 'pyg':
            train_loader = DataLoaderpyg(trainset,
                                         batch_size=params['batch_size'],
                                         shuffle=True)
            val_loader = DataLoaderpyg(valset,
                                       batch_size=params['batch_size'],
                                       shuffle=False)
            test_loader = DataLoaderpyg(testset,
                                        batch_size=params['batch_size'],
                                        shuffle=False)
        else:
            train_loader = DataLoader(trainset,
                                      batch_size=params['batch_size'],
                                      shuffle=True,
                                      collate_fn=dataset.collate)
            # Evaluation loaders must not shuffle: the original code used
            # shuffle=True here, unlike the pyg branch above, which made
            # evaluation order (and RNG consumption) nondeterministic.
            val_loader = DataLoader(valset,
                                    batch_size=params['batch_size'],
                                    shuffle=False,
                                    collate_fn=dataset.collate)
            test_loader = DataLoader(testset,
                                     batch_size=params['batch_size'],
                                     shuffle=False,
                                     collate_fn=dataset.collate)

    ckpt_dir = os.path.join(root_ckpt_dir, "RUN_")

    # Bound up front so the final evaluation and report still work when the
    # loop body never runs (e.g. params['epochs'] == 0); previously that case
    # ended in a NameError.
    epoch = 0

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        with tqdm(range(params['epochs']), ncols=0) as t:
            for epoch in t:

                t.set_description('Epoch %d' % epoch)

                start = time.time()

                if is_dense_model:
                    # different batch training function for dense GNNs
                    epoch_train_loss, epoch_train_acc, optimizer = train_epoch(
                        model, optimizer, device, train_loader, epoch)
                else:  # for all other models common train function
                    epoch_train_loss, epoch_train_acc, optimizer = train_epoch(
                        model, optimizer, device, train_loader, epoch,
                        params['framework'])

                # NOTE(review): the framework argument is passed to the
                # evaluate function for dense models too, although the dense
                # train function is called without it -- confirm that
                # evaluate_network_dense accepts this extra argument.
                epoch_val_loss, epoch_val_acc = evaluate_network(
                    model, device, val_loader, epoch, params['framework'])
                _, epoch_test_acc = evaluate_network(model, device,
                                                     test_loader, epoch,
                                                     params['framework'])

                epoch_train_losses.append(epoch_train_loss)
                epoch_val_losses.append(epoch_val_loss)
                epoch_train_accs.append(epoch_train_acc)
                epoch_val_accs.append(epoch_val_acc)

                writer.add_scalar('train/_loss', epoch_train_loss, epoch)
                writer.add_scalar('val/_loss', epoch_val_loss, epoch)
                writer.add_scalar('train/_acc', epoch_train_acc, epoch)
                writer.add_scalar('val/_acc', epoch_val_acc, epoch)
                writer.add_scalar('test/_acc', epoch_test_acc, epoch)
                writer.add_scalar('learning_rate',
                                  optimizer.param_groups[0]['lr'], epoch)

                t.set_postfix(time=time.time() - start,
                              lr=optimizer.param_groups[0]['lr'],
                              train_loss=epoch_train_loss,
                              val_loss=epoch_val_loss,
                              train_acc=epoch_train_acc,
                              val_acc=epoch_val_acc,
                              test_acc=epoch_test_acc)

                per_epoch_time.append(time.time() - start)

                # Checkpoint directory is still created, but writing the
                # checkpoint itself is disabled in this variant; to re-enable:
                # torch.save(model.state_dict(), '{}.pkl'.format(ckpt_dir + "/epoch_" + str(epoch)))
                os.makedirs(ckpt_dir, exist_ok=True)

                # Remove any checkpoint older than the previous epoch.
                for file in glob.glob(ckpt_dir + '/*.pkl'):
                    epoch_nb = file.split('_')[-1]
                    epoch_nb = int(epoch_nb.split('.')[0])
                    if epoch_nb < epoch - 1:
                        os.remove(file)

                scheduler.step(epoch_val_loss)

                if optimizer.param_groups[0]['lr'] < params['min_lr']:
                    print("\n!! LR SMALLER OR EQUAL TO MIN LR THRESHOLD.")
                    break

                # Stop training after params['max_time'] hours
                if time.time() - start0 > params['max_time'] * 3600:
                    print('-' * 89)
                    print(
                        "Max_time for training elapsed {:.2f} hours, so stopping"
                        .format(params['max_time']))
                    break

    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early because of KeyboardInterrupt')

    # np.mean([]) yields nan but also emits a RuntimeWarning; report nan
    # directly when no epoch completed.
    avg_epoch_time = (np.mean(per_epoch_time)
                      if per_epoch_time else float('nan'))

    _, test_acc = evaluate_network(model, device, test_loader, epoch,
                                   params['framework'])
    _, val_acc = evaluate_network(model, device, val_loader, epoch,
                                  params['framework'])
    _, train_acc = evaluate_network(model, device, train_loader, epoch,
                                    params['framework'])
    print("Test Accuracy: {:.4f}".format(test_acc))
    print("Val Accuracy: {:.4f}".format(val_acc))
    print("Train Accuracy: {:.4f}".format(train_acc))
    print("Convergence Time (Epochs): {:.4f}".format(epoch))
    print("TOTAL TIME TAKEN: {:.4f}s".format(time.time() - start0))
    print("AVG TIME PER EPOCH: {:.4f}s".format(avg_epoch_time))

    writer.close()

    # Write the results in out_dir/results folder
    with open(write_file_name + '.txt', 'w') as f:
        f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n
    FINAL RESULTS\nTEST ACCURACY: {:.4f}\nval ACCURACY: {:.4f}\nTRAIN ACCURACY: {:.4f}\n\n
    Convergence Time (Epochs): {:.4f}\nTotal Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\n"""\
          .format(DATASET_NAME, MODEL_NAME, params, net_params, model, net_params['total_param'],
                  test_acc, val_acc, train_acc, epoch, (time.time()-start0)/3600, avg_epoch_time))
コード例 #5
0
def train_val_pipeline(MODEL_NAME, DATASET_NAME, params, net_params, args):
    """Fine-tune (or train from scratch) a model and report its accuracy.

    Loads the dataset by name, derives ``in_dim`` / ``n_classes`` from the
    first training sample, optionally initialises the model from a
    contrastive pre-training checkpoint, then trains with Adam +
    ReduceLROnPlateau until the epoch budget is used up or the learning rate
    falls below ``params['min_lr']``.

    Args:
        MODEL_NAME: Model identifier forwarded to ``gnn_model``.
        DATASET_NAME: Dataset identifier forwarded to ``LoadData``; also a
            path component of the checkpoint directory.
        params: Optimisation settings. Keys read here: ``seed``, ``init_lr``,
            ``weight_decay``, ``lr_reduce_factor``, ``lr_schedule_patience``,
            ``batch_size``, ``epochs``, ``min_lr``.
        net_params: Network settings; ``device`` (and ``self_loop`` for
            GCN/GAT) are read, and ``in_dim``, ``n_classes`` and
            ``total_param`` are written in place.
        args: Namespace exposing ``load_model`` (truthy to load a checkpoint),
            ``aug`` (index into the augmentation-type list) and ``head``.

    Returns:
        tuple: ``(train_acc, test_acc)`` from the final evaluation.

    Raises:
        FileNotFoundError: If ``args.load_model`` is set but no ``*.pkl``
            checkpoint exists in the expected directory.
    """
    # setting seeds
    random.seed(params['seed'])
    np.random.seed(params['seed'])
    torch.manual_seed(params['seed'])
    device = net_params['device']
    # The old `device == 'cuda'` test missed strings such as 'cuda:0' and
    # torch.device objects; comparing the string form covers all of them.
    if str(device).startswith('cuda'):
        torch.cuda.manual_seed(params['seed'])

    dataset = LoadData(DATASET_NAME)
    trainset, valset, testset = dataset.train, dataset.val, dataset.test

    # Number of distinct feature rows / label values in the FIRST training
    # sample only. presumably every sample covers the full value range --
    # verify against the dataset.
    net_params['in_dim'] = torch.unique(dataset.train[0][0].ndata['feat'],
                                        dim=0).size(
                                            0)  # node_dim (feat is an integer)
    net_params['n_classes'] = torch.unique(dataset.train[0][1], dim=0).size(0)
    net_params['total_param'] = view_model_param(MODEL_NAME, net_params)

    load_model = args.load_model
    aug_type_list = [
        'drop_nodes', 'drop_add_edges', 'noise', 'mask', 'subgraph'
    ]

    if MODEL_NAME in ['GCN', 'GAT']:
        if net_params['self_loop']:
            print(
                "[!] Adding graph self-loops for GCN/GAT models (central node trick)."
            )
            dataset._add_self_loops()

    print('-' * 40 + "Finetune Option" + '-' * 40)
    print('SEED:           [{}]'.format(params['seed']))
    print("Data  Name:     [{}]".format(DATASET_NAME))
    print("Model Name:     [{}]".format(MODEL_NAME))
    print("Training Graphs:[{}]".format(len(trainset)))
    print("Valid Graphs:   [{}]".format(len(valset)))
    print("Test Graphs:    [{}]".format(len(testset)))
    print("Number Classes: [{}]".format(net_params['n_classes']))
    print("Learning rate:  [{}]".format(params['init_lr']))
    print('-' * 40 + "Contrastive Option" + '-' * 40)
    print("Load model:     [{}]".format(load_model))
    print("Aug Type:       [{}]".format(aug_type_list[args.aug]))
    print("Projection head:[{}]".format(args.head))
    print('-' * 100)

    model = gnn_model(MODEL_NAME, net_params)

    if load_model:
        output_path = './001_contrastive_models'
        save_model_dir0 = os.path.join(output_path, DATASET_NAME)
        save_model_dir1 = os.path.join(save_model_dir0,
                                       aug_type_list[args.aug])

        if args.head:
            save_model_dir1 += "_head"
        else:
            save_model_dir1 += "_no_head"
        save_model_dir2 = os.path.join(save_model_dir1, MODEL_NAME)

        # glob order is OS-dependent; sort so that the checkpoint picked is
        # deterministic (lexicographically last file name).
        load_file_name = sorted(glob.glob(save_model_dir2 + '/*.pkl'))
        if not load_file_name:
            raise FileNotFoundError(
                'No pre-trained checkpoint (*.pkl) found in {}'.format(
                    save_model_dir2))

        # The model is still on the CPU here, so map the tensors there; this
        # also lets a GPU-saved checkpoint load on a CPU-only machine.
        # NOTE(review): torch.load unpickles the file -- only load trusted
        # checkpoints.
        checkpoint = torch.load(load_file_name[-1], map_location='cpu')
        model_dict = model.state_dict()

        # Drop checkpoint entries the model does not have.
        state_dict = {
            k: v
            for k, v in checkpoint.items() if k in model_dict
        }
        model.load_state_dict(state_dict)
        print('Success load pre-trained model!: [{}]'.format(
            load_file_name[-1]))
    else:
        print('No model load!: Test baseline! ')

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(),
                           lr=params['init_lr'],
                           weight_decay=params['weight_decay'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=params['lr_reduce_factor'],
        patience=params['lr_schedule_patience'],
        verbose=True)

    train_loader = DataLoader(trainset,
                              batch_size=params['batch_size'],
                              shuffle=True,
                              collate_fn=dataset.collate)
    val_loader = DataLoader(valset,
                            batch_size=params['batch_size'],
                            shuffle=False,
                            collate_fn=dataset.collate)
    test_loader = DataLoader(testset,
                             batch_size=params['batch_size'],
                             shuffle=False,
                             collate_fn=dataset.collate)

    # Bound up front so the final evaluation still works when the loop body
    # never runs (params['epochs'] == 0); previously that was a NameError.
    epoch = 0

    for epoch in range(params['epochs']):

        epoch_train_loss, epoch_train_acc, optimizer = train_epoch(
            model, optimizer, device, train_loader, epoch)
        epoch_val_loss, epoch_val_acc = evaluate_network(
            model, device, val_loader, epoch)
        # The test set used to be evaluated twice per epoch with the first
        # result discarded; a single pass is enough.
        _, epoch_test_acc = evaluate_network(model, device, test_loader, epoch)

        print('-' * 80)
        print(
            time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' | ' +
            "Epoch [{:>2d}]  Test Acc: [{:.4f}]".format(
                epoch + 1, epoch_test_acc))
        print('-' * 80)

        scheduler.step(epoch_val_loss)

        if optimizer.param_groups[0]['lr'] < params['min_lr']:
            print("\n!! LR SMALLER OR EQUAL TO MIN LR THRESHOLD.")
            break

    _, test_acc = evaluate_network(model, device, test_loader, epoch)
    _, train_acc = evaluate_network(model, device, train_loader, epoch)
    print("Test Accuracy: {:.4f}".format(test_acc))
    print("Train Accuracy: {:.4f}".format(train_acc))
    return train_acc, test_acc
コード例 #6
0
# Loader over the test split; shuffle is left at the DataLoader default and
# batches are assembled by the dataset's own collate function.
test_loader = DataLoader(dataset.test,
                         batch_size=args.batch_size,
                         collate_fn=dataset.collate,
                         num_workers=5)

# Train one network per configuration in `net_params` (a mapping from a
# configuration key to that configuration's network-parameter dict).
for key, _net_param in net_params.items():
    print(f'Starting {args.net}{key} on {args.dataset}')
    time_start = time.perf_counter()

    # Work on a copy so the shared configuration dict is not mutated.
    net_param = deepcopy(_net_param)
    # Count of distinct feature rows / label values in the FIRST training
    # sample only. presumably every sample covers the full value range --
    # verify against the dataset.
    net_param['in_dim'] = torch.unique(dataset.train[0][0].ndata['feat'],
                                       dim=0).size(0)
    net_param['n_classes'] = torch.unique(dataset.train[0][1], dim=0).size(0)
    # NOTE(review): the device is hard-coded to the first GPU here and via
    # net.cuda() below, so this script requires CUDA.
    net_param['device'] = 'cuda:0'

    net = gnn_model(args.net, net_param)
    net.cuda()

    optimizer = optim.Adam(net.parameters(),
                           lr=params['init_lr'],
                           weight_decay=params['weight_decay'])
    # mode='min': the monitored quantity is expected to decrease.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=params['lr_reduce_factor'],
        patience=params['lr_schedule_patience'])

    # Progress bar over epochs; the description shows a 1-based epoch number.
    with tqdm(range(params['epochs'])) as epochs:
        for e in epochs:
            epochs.set_description(f'Epoch #{e+1}')
コード例 #7
0
def train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs):
    """Train a model, evaluate it each epoch, write results, and e-mail them.

    Args:
        MODEL_NAME: Model identifier forwarded to ``gnn_model``; ``'GCN'`` /
            ``'GAT'`` may trigger self-loop insertion.
        dataset: Object exposing ``name``, ``train``, ``val``, ``test``,
            ``collate`` and ``_add_self_loops``.
        params: Optimisation settings. Keys read here: ``seed``, ``init_lr``,
            ``weight_decay``, ``lr_reduce_factor``, ``lr_schedule_patience``,
            ``batch_size``, ``epochs``, ``min_lr``, ``max_time`` (hours).
        net_params: Network settings. Keys read here: ``self_loop`` (GCN/GAT
            only), ``device``, ``total_param``, ``n_classes``, ``L``,
            ``my_layer``.
        dirs: 4-tuple ``(root_log_dir, root_ckpt_dir, write_file_name,
            write_config_file)``; ``'.txt'`` is appended to the two file names.

    Returns:
        None. Results are printed, written to disk and, when the optional
        ``gmail`` module is importable, sent by e-mail on a best-effort basis.
    """
    start0 = time.time()
    per_epoch_time = []

    DATASET_NAME = dataset.name

    if MODEL_NAME in ['GCN', 'GAT']:
        if net_params['self_loop']:
            print("[!] Adding graph self-loops for GCN/GAT models (central node trick).")
            dataset._add_self_loops()

    trainset, valset, testset = dataset.train, dataset.val, dataset.test

    root_log_dir, root_ckpt_dir, write_file_name, write_config_file = dirs
    device = net_params['device']

    # Write network and optimization hyper-parameters in folder config/
    with open(write_config_file + '.txt', 'w') as f:
        f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n\nTotal Parameters: {}\n\n"""                .format(DATASET_NAME, MODEL_NAME, params, net_params, net_params['total_param']))

    log_dir = os.path.join(root_log_dir, "RUN_" + str(0))
    writer = SummaryWriter(log_dir=log_dir)
    print("Log Dir: %s" % log_dir)

    # setting seeds
    random.seed(params['seed'])
    np.random.seed(params['seed'])
    torch.manual_seed(params['seed'])
    # The old `device == 'cuda'` test missed strings such as 'cuda:0' and
    # torch.device objects; comparing the string form covers all of them.
    if str(device).startswith('cuda'):
        torch.cuda.manual_seed(params['seed'])

    print("Training Graphs: ", len(trainset))
    print("Validation Graphs: ", len(valset))
    print("Test Graphs: ", len(testset))
    print("Number of Classes: ", net_params['n_classes'])
    print("Number of Layers: %s" % net_params['L'])
    print("Enable My Layer: %s" % net_params['my_layer'])

    model = gnn_model(MODEL_NAME, net_params)
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=params['init_lr'], weight_decay=params['weight_decay'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                     factor=params['lr_reduce_factor'],
                                                     patience=params['lr_schedule_patience'],
                                                     verbose=True)

    epoch_train_losses, epoch_val_losses, epoch_test_losses = [], [], []
    epoch_train_accs, epoch_val_accs, epoch_test_accs = [], [], []

    train_loader = DataLoader(trainset, batch_size=params['batch_size'], shuffle=True, collate_fn=dataset.collate)
    val_loader = DataLoader(valset, batch_size=params['batch_size'], shuffle=False, collate_fn=dataset.collate)
    test_loader = DataLoader(testset, batch_size=params['batch_size'], shuffle=False, collate_fn=dataset.collate)

    start_time_str = time.strftime('%Hh%Mm%Ss on %b %d %Y')

    ckpt_dir = os.path.join(root_ckpt_dir, "RUN_")

    # Bound up front so the final evaluation still works when the loop body
    # never runs (e.g. params['epochs'] == 0); previously that case ended in
    # a NameError.
    epoch = 0

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        with tqdm(range(params['epochs'])) as t:
            for epoch in t:

                t.set_description('Epoch %d' % epoch)

                start = time.time()

                epoch_train_loss, epoch_train_acc, optimizer = train_epoch(model, optimizer, device, train_loader, epoch)
                epoch_val_loss, epoch_val_acc = evaluate_network(model, device, val_loader, epoch)
                epoch_test_loss, epoch_test_acc = evaluate_network(model, device, test_loader, epoch)

                epoch_train_losses.append(epoch_train_loss)
                epoch_val_losses.append(epoch_val_loss)
                epoch_test_losses.append(epoch_test_loss)
                epoch_train_accs.append(epoch_train_acc)
                epoch_val_accs.append(epoch_val_acc)
                epoch_test_accs.append(epoch_test_acc)

                writer.add_scalar('train/_loss', epoch_train_loss, epoch)
                writer.add_scalar('val/_loss', epoch_val_loss, epoch)
                writer.add_scalar('test/_loss', epoch_test_loss, epoch)
                writer.add_scalar('train/_acc', epoch_train_acc, epoch)
                writer.add_scalar('val/_acc', epoch_val_acc, epoch)
                writer.add_scalar('test/_acc', epoch_test_acc, epoch)
                writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)

                t.set_postfix(time=time.time()-start, lr=optimizer.param_groups[0]['lr'],
                              train_loss=epoch_train_loss, val_loss=epoch_val_loss,
                              train_acc=epoch_train_acc, val_acc=epoch_val_acc,
                              test_acc=epoch_test_acc)

                per_epoch_time.append(time.time()-start)

                # Saving checkpoint. exist_ok avoids the check-then-create
                # race of a separate os.path.exists() test.
                os.makedirs(ckpt_dir, exist_ok=True)
                torch.save(model.state_dict(), '{}.pkl'.format(ckpt_dir + "/epoch_" + str(epoch)))

                # Keep only the checkpoints of the current and previous epoch.
                for file in glob.glob(ckpt_dir + '/*.pkl'):
                    epoch_nb = file.split('_')[-1]
                    epoch_nb = int(epoch_nb.split('.')[0])
                    if epoch_nb < epoch-1:
                        os.remove(file)

                scheduler.step(epoch_val_loss)

                if optimizer.param_groups[0]['lr'] < params['min_lr']:
                    print("\n!! LR SMALLER OR EQUAL TO MIN LR THRESHOLD.")
                    break

                # Stop training after params['max_time'] hours
                if time.time()-start0 > params['max_time']*3600:
                    print('-' * 89)
                    print("Max_time for training elapsed {:.2f} hours, so stopping".format(params['max_time']))
                    break

    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early because of KeyboardInterrupt')

    # np.mean([]) yields nan but also emits a RuntimeWarning; report nan
    # directly when no epoch completed.
    avg_epoch_time = np.mean(per_epoch_time) if per_epoch_time else float('nan')

    _, test_acc = evaluate_network(model, device, test_loader, epoch)
    _, train_acc = evaluate_network(model, device, train_loader, epoch)
    print("Test Accuracy: {:.4f}".format(test_acc))
    print("Train Accuracy: {:.4f}".format(train_acc))
    print("TOTAL TIME TAKEN: {:.4f}s".format(time.time()-start0))
    print("AVG TIME PER EPOCH: {:.4f}s".format(avg_epoch_time))

    writer.close()

    end_time_str = time.strftime('%Hh%Mm%Ss on %b %d %Y')

    # Write the results in out_dir/results folder
    result_txt = """
Log Dir: {}\nStart Time: {}\nEnd Time: {}\n\n
Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n
FINAL RESULTS\nTEST ACCURACY: {:.4f}\nTRAIN ACCURACY: {:.4f}\n\n
Total Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\n"""\
        .format(log_dir, start_time_str, end_time_str,
                DATASET_NAME, MODEL_NAME, params, net_params, model, net_params['total_param'],
                test_acc, train_acc, (time.time()-start0)/3600, avg_epoch_time)

    with open(write_file_name + '.txt', 'w') as f:
        print("Writing results to %s" % write_file_name + '.txt')
        f.write(result_txt)

    # Send results to gmail. This stays best-effort: a missing optional
    # `gmail` module is skipped silently and any other failure is reported
    # rather than raised. Unlike the previous bare `except:`, this no longer
    # swallows KeyboardInterrupt/SystemExit.
    try:
        from gmail import send
        subject = 'Result for Dataset: {}, Model: {}'.format(DATASET_NAME, MODEL_NAME)
        body = """Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n
    FINAL RESULTS\nTEST ACCURACY: {:.4f}\nTRAIN ACCURACY: {:.4f}\n\n
    Total Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\n"""\
          .format(DATASET_NAME, MODEL_NAME, params, net_params, model, net_params['total_param'],
                  test_acc, train_acc, (time.time()-start0)/3600, avg_epoch_time)
        send(subject, body)
    except ImportError:
        pass
    except Exception as exc:
        print("[!] Could not send result e-mail: {}".format(exc))