def view_model_param(MODEL_NAME, net_params):
    model = gnn_model(MODEL_NAME, net_params)
    total_param = 0
    print("MODEL DETAILS:\n")
    print(model)
    for param in model.parameters():
        # print(param.data.size())
        total_param += np.prod(list(param.data.size()))
    print('MODEL/Total parameters:', MODEL_NAME, total_param)
    return total_param
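
# Usage sketch (illustrative; the surrounding driver code is not shown in this
# section): train_val_pipeline below reads net_params['total_param'] when it
# writes the config file, so a driver would typically record the count first:
#
#     net_params['total_param'] = view_model_param(MODEL_NAME, net_params)
#
# MODEL_NAME is a string key understood by gnn_model (e.g. 'GatedGCN'); the exact
# set of required net_params keys depends on the chosen model and is assumed to
# come from a parsed JSON config.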
def train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs):
    t0 = time.time()
    per_epoch_time = []

    DATASET_NAME = dataset.name

    #assert net_params['self_loop'] == False, "No self-loop support for %s dataset" % DATASET_NAME

    trainset, valset, testset = dataset.train, dataset.val, dataset.test

    root_log_dir, root_ckpt_dir, write_file_name, write_config_file = dirs
    device = net_params['device']

    # Write the network and optimization hyper-parameters in folder config/
    with open(write_config_file + '.txt', 'w') as f:
        f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n\nTotal Parameters: {}\n\n"""
                .format(DATASET_NAME, MODEL_NAME, params, net_params, net_params['total_param']))

    log_dir = os.path.join(root_log_dir, "RUN_" + str(0))
    writer = SummaryWriter(log_dir=log_dir)

    # setting seeds
    random.seed(params['seed'])
    np.random.seed(params['seed'])
    torch.manual_seed(params['seed'])
    if device.type == 'cuda':
        torch.cuda.manual_seed(params['seed'])

    print("Training Graphs: ", len(trainset))
    print("Validation Graphs: ", len(valset))
    print("Test Graphs: ", len(testset))
    print("Number of Classes: ", net_params['n_classes'])

    model = gnn_model(MODEL_NAME, net_params)
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=params['init_lr'],
                           weight_decay=params['weight_decay'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                     factor=params['lr_reduce_factor'],
                                                     patience=params['lr_schedule_patience'],
                                                     verbose=True)

    epoch_train_losses, epoch_val_losses = [], []
    epoch_train_f1s, epoch_val_f1s = [], []

    if MODEL_NAME in ['RingGNN', '3WLGNN']:
        # import train functions specific for WL-GNNs
        from train.train_TSP_edge_classification import train_epoch_dense as train_epoch, \
            evaluate_network_dense as evaluate_network
        from functools import partial  # util function to pass edge_feat to collate function

        train_loader = DataLoader(trainset, shuffle=True,
                                  collate_fn=partial(dataset.collate_dense_gnn,
                                                     edge_feat=net_params['edge_feat']))
        val_loader = DataLoader(valset, shuffle=False,
                                collate_fn=partial(dataset.collate_dense_gnn,
                                                   edge_feat=net_params['edge_feat']))
        test_loader = DataLoader(testset, shuffle=False,
                                 collate_fn=partial(dataset.collate_dense_gnn,
                                                    edge_feat=net_params['edge_feat']))
    else:
        # import train functions for all other GNNs
        from train.train_TSP_edge_classification import train_epoch_sparse as train_epoch, \
            evaluate_network_sparse as evaluate_network

        train_loader = DataLoader(trainset, batch_size=params['batch_size'], shuffle=True,
                                  collate_fn=dataset.collate)
        val_loader = DataLoader(valset, batch_size=params['batch_size'], shuffle=False,
                                collate_fn=dataset.collate)
        test_loader = DataLoader(testset, batch_size=params['batch_size'], shuffle=False,
                                 collate_fn=dataset.collate)
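
    # NOTE: RingGNN and 3WLGNN operate on dense inputs, so their loaders rely on
    # collate_dense_gnn with PyTorch's DataLoader default of one sample per step;
    # dense mini-batching is instead handled inside train_epoch_dense, which is why
    # it is passed params['batch_size'] in the loop below. All other models batch
    # graph samples directly in the DataLoader via dataset.collate.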
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        with tqdm(range(params['epochs'])) as t:
            for epoch in t:

                t.set_description('Epoch %d' % epoch)

                start = time.time()

                if MODEL_NAME in ['RingGNN', '3WLGNN']:  # different batch training function for dense GNNs
                    epoch_train_loss, epoch_train_f1, optimizer = train_epoch(
                        model, optimizer, device, train_loader, epoch, params['batch_size'])
                else:  # common train function for all other models
                    epoch_train_loss, epoch_train_f1, optimizer = train_epoch(
                        model, optimizer, device, train_loader, epoch)

                epoch_val_loss, epoch_val_f1 = evaluate_network(model, device, val_loader, epoch)
                _, epoch_test_f1 = evaluate_network(model, device, test_loader, epoch)

                epoch_train_losses.append(epoch_train_loss)
                epoch_val_losses.append(epoch_val_loss)
                epoch_train_f1s.append(epoch_train_f1)
                epoch_val_f1s.append(epoch_val_f1)

                writer.add_scalar('train/_loss', epoch_train_loss, epoch)
                writer.add_scalar('val/_loss', epoch_val_loss, epoch)
                writer.add_scalar('train/_f1', epoch_train_f1, epoch)
                writer.add_scalar('val/_f1', epoch_val_f1, epoch)
                writer.add_scalar('test/_f1', epoch_test_f1, epoch)
                writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)

                t.set_postfix(time=time.time() - start, lr=optimizer.param_groups[0]['lr'],
                              train_loss=epoch_train_loss, val_loss=epoch_val_loss,
                              train_f1=epoch_train_f1, val_f1=epoch_val_f1,
                              test_f1=epoch_test_f1)

                per_epoch_time.append(time.time() - start)

                # Saving checkpoint
                ckpt_dir = os.path.join(root_ckpt_dir, "RUN_")
                if not os.path.exists(ckpt_dir):
                    os.makedirs(ckpt_dir)
                torch.save(model.state_dict(), '{}.pkl'.format(ckpt_dir + "/epoch_" + str(epoch)))

                # keep only the two most recent checkpoints on disk
                files = glob.glob(ckpt_dir + '/*.pkl')
                for file in files:
                    epoch_nb = file.split('_')[-1]
                    epoch_nb = int(epoch_nb.split('.')[0])
                    if epoch_nb < epoch - 1:
                        os.remove(file)

                scheduler.step(epoch_val_loss)

                if optimizer.param_groups[0]['lr'] < params['min_lr']:
                    print("\n!! LR EQUAL TO MIN LR SET.")
                    break

                # Stop training after params['max_time'] hours
                if time.time() - t0 > params['max_time'] * 3600:
                    print('-' * 89)
                    print("Max_time for training elapsed {:.2f} hours, so stopping".format(params['max_time']))
                    break

    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early because of KeyboardInterrupt')

    _, test_f1 = evaluate_network(model, device, test_loader, epoch)
    _, train_f1 = evaluate_network(model, device, train_loader, epoch)
    print("Test F1: {:.4f}".format(test_f1))
    print("Train F1: {:.4f}".format(train_f1))
    print("Convergence Time (Epochs): {:.4f}".format(epoch))
    print("TOTAL TIME TAKEN: {:.4f}s".format(time.time() - t0))
    print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time)))

    writer.close()

    """
        Write the results in out_dir/results folder
    """
    with open(write_file_name + '.txt', 'w') as f:
        f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n
    FINAL RESULTS\nTEST F1: {:.4f}\nTRAIN F1: {:.4f}\n\n
    Convergence Time (Epochs): {:.4f}\nTotal Time Taken: {:.4f}hrs\nAverage Time Per Epoch: {:.4f}s\n\n\n"""
                .format(DATASET_NAME, MODEL_NAME, params, net_params, model, net_params['total_param'],
                        np.mean(np.array(test_f1)), np.mean(np.array(train_f1)), epoch,
                        (time.time() - t0) / 3600, np.mean(per_epoch_time)))
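
# Driver sketch (illustrative only; the real entry point is not part of this
# section). It assumes a LoadData helper and config-parsed params/net_params
# dicts, which are hypothetical names here, wired together the way the two
# functions above consume them:
#
#     dataset = LoadData(DATASET_NAME)  # e.g. 'TSP'
#     net_params['total_param'] = view_model_param(MODEL_NAME, net_params)
#     dirs = root_log_dir, root_ckpt_dir, write_file_name, write_config_file
#     train_val_pipeline(MODEL_NAME, dataset, params, net_params, dirs)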