def main():
    """Train, or only evaluate, the ego-graph Transformer node classifier.

    Steps, in order:

    1. Parse command-line arguments and seed NumPy / PyTorch.
    2. Start a Weights & Biases run named after the hyper-parameters.
    3. Load the dataset and the pre-computed ego-graph / ``conds`` arrays
       from ``data/``, pad or truncate every ego-graph to ``--ego_size``
       and route it to the train / valid / test split of its first node.
    4. Build the data loaders and the ``TransformerModel``.
    5. If ``--load_path`` is given: load that checkpoint, print
       train / valid / test accuracy and return.
    6. Otherwise train for up to ``--epochs`` epochs, evaluating from
       epoch 10 on every ``--log_steps`` epochs, checkpointing the best
       validation model to ``saved/<exp_name>.pt`` and stopping after
       ``--early_stopping`` evaluations without improvement.

    Returns:
        None.
    """
    parser = argparse.ArgumentParser(description='OGBN (GNN)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--project', type=str, default='lcgnn')
    parser.add_argument('--dataset', type=str, default='flickr')
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--num_layers', type=int, default=4)
    parser.add_argument('--num_heads', type=int, default=2)
    parser.add_argument('--ego_size', type=int, default=64)
    parser.add_argument('--hidden_size', type=int, default=64)
    parser.add_argument('--input_dropout', type=float, default=0.2)
    parser.add_argument('--hidden_dropout', type=float, default=0.4)
    parser.add_argument('--weight_decay', type=float, default=0.005)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--early_stopping', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--eval_batch_size', type=int, default=2048)
    parser.add_argument('--layer_norm', type=int, default=0)
    parser.add_argument('--src_scale', type=int, default=0)
    parser.add_argument('--num_workers', type=int, default=4,
                        help='number of workers')
    parser.add_argument('--pe_type', type=int, default=0)
    parser.add_argument('--mask', type=int, default=0)
    parser.add_argument('--mlp', type=int, default=0)
    parser.add_argument("--optimizer", type=str, default='adamw',
                        choices=['adam', 'adamw'], help="optimizer")
    parser.add_argument("--scheduler", type=str, default='noam',
                        choices=['noam', 'linear'], help="scheduler")
    parser.add_argument("--method", type=str, default='acl',
                        choices=['acl', 'l1reg'],
                        help="method for local clustering")
    parser.add_argument('--warmup', type=int, default=10000)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--load_path', type=str, default='')
    parser.add_argument('--exp_name', type=str, default='')
    args = parser.parse_args()
    print(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Short keys that get_exp_name turns into the experiment name.
    para_dic = {
        'nl': args.num_layers,
        'nh': args.num_heads,
        'es': args.ego_size,
        'hs': args.hidden_size,
        'id': args.input_dropout,
        'hd': args.hidden_dropout,
        'bs': args.batch_size,
        'pe': args.pe_type,
        'op': args.optimizer,
        'lr': args.lr,
        'wd': args.weight_decay,
        'ln': args.layer_norm,
        'sc': args.src_scale,
        'sd': args.seed,
        'md': args.method,
        'warm': args.warmup,
        'mask': args.mask,
    }
    exp_name = get_exp_name(args.dataset, para_dic, args.exp_name)
    # Runs that differ only by seed share one wandb run name.
    wandb_name = exp_name.replace('_sd' + str(args.seed), '')
    wandb.init(name=wandb_name, project=args.project)
    wandb.config.update(args)

    device = torch.device(
        f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')

    # ------------------------------------------------------------------
    # Dataset and split indices
    # ------------------------------------------------------------------
    if args.dataset == 'papers100M':
        dataset = MyNodePropPredDataset(name=args.dataset)
    elif args.dataset in ['flickr', 'reddit', 'yelp', 'amazon']:
        dataset = SAINTDataset(name=args.dataset)
    else:
        dataset = PygNodePropPredDataset(name=f'ogbn-{args.dataset}')

    split_idx = dataset.get_idx_split()
    # Sets give O(1) membership tests in the routing loop below.
    train_idx = set(split_idx['train'].cpu().numpy())
    valid_idx = set(split_idx['valid'].cpu().numpy())
    test_idx = set(split_idx['test'].cpu().numpy())

    # ------------------------------------------------------------------
    # Pre-computed ego-graphs
    # ------------------------------------------------------------------
    if args.method != "acl":
        graphs_file = (f'data/{args.dataset}-lc-{args.method}'
                       f'-ego-graphs-{args.ego_size}.npy')
        conds_file = (f'data/{args.dataset}-lc-{args.method}'
                      f'-conds-{args.ego_size}.npy')
    else:
        # The 'acl' files are read at a size of at least 64 (256 for
        # 'products' unless --ego_size < 64) and truncated to --ego_size.
        tmp_ego_size = 256 if args.dataset == 'products' else args.ego_size
        if args.ego_size < 64:
            tmp_ego_size = 64
        graphs_file = f'data/{args.dataset}-lc-ego-graphs-{tmp_ego_size}.npy'
        conds_file = f'data/{args.dataset}-lc-conds-{tmp_ego_size}.npy'
    # NOTE(security): allow_pickle=True executes pickled payloads; only load
    # .npy files produced by a trusted preprocessing run.
    ego_graphs_unpadded = np.load(graphs_file, allow_pickle=True)
    conds_unpadded = np.load(conds_file, allow_pickle=True)

    ego_graphs_train, ego_graphs_valid, ego_graphs_test = [], [], []
    cut_train, cut_valid, cut_test = [], [], []
    for i, x in enumerate(ego_graphs_unpadded):
        idx = x[0]  # the first entry decides which split the graph goes to
        conds = conds_unpadded[i]
        assert len(x) == len(conds)
        if len(x) > args.ego_size:
            x = x[:args.ego_size]
            # Slice into a local instead of writing back into the loaded
            # array: the write-back fails with a broadcast error when the
            # file holds a regular 2-D array rather than an object array.
            conds = conds[:args.ego_size]

        # Fixed-length node list, padded with -1.
        ego_graph = np.full(args.ego_size, -1, dtype=np.int32)
        ego_graph[:len(x)] = x

        # 1.0 for every position up to and including the minimum of conds
        # (presumably conductance scores -- confirm against preprocessing).
        cut_position = np.argmin(conds)
        cut = np.zeros(args.ego_size, dtype=np.float32)
        cut[:cut_position + 1] = 1.0

        if idx in train_idx:
            ego_graphs_train.append(ego_graph)
            cut_train.append(cut)
        elif idx in valid_idx:
            ego_graphs_valid.append(ego_graph)
            cut_valid.append(cut)
        elif idx in test_idx:
            ego_graphs_test.append(ego_graph)
            cut_test.append(cut)
        else:
            print(f"{idx} not in train/valid/test idx")

    ego_graphs_train = torch.LongTensor(ego_graphs_train)
    ego_graphs_valid = torch.LongTensor(ego_graphs_valid)
    ego_graphs_test = torch.LongTensor(ego_graphs_test)
    cut_train = torch.FloatTensor(cut_train)
    cut_valid = torch.FloatTensor(cut_valid)
    cut_test = torch.FloatTensor(cut_test)

    # ------------------------------------------------------------------
    # Optional embeddings selected by --pe_type, and adjacency mask
    # ------------------------------------------------------------------
    pe = None
    if args.pe_type == 1:
        pe = torch.load(f'data/{args.dataset}-embedding-{args.hidden_size}.pt')
    elif args.pe_type == 2:
        pe = np.fromfile("data/paper100m.pro",
                         dtype=np.float32).reshape(-1, 128)
        pe = torch.FloatTensor(pe)
        if args.hidden_size < 128:
            pe = pe[:, :args.hidden_size]

    data = dataset[0]
    if len(data.y.shape) == 1:
        data.y = data.y.unsqueeze(1)

    adj = None
    if args.mask:
        # The stored array is inverted before being handed to the dataset.
        adj = torch.BoolTensor(~np.load(
            f'data/{args.dataset}-ego-graphs-adj-{args.ego_size}.npy'))

    num_classes = dataset.num_classes

    # ------------------------------------------------------------------
    # Data loaders
    # ------------------------------------------------------------------
    train_dataset = NodeClassificationDataset(
        data.x, data.y, ego_graphs_train, pe, args, num_classes, adj,
        cut_train)
    valid_dataset = NodeClassificationDataset(
        data.x, data.y, ego_graphs_valid, pe, args, num_classes, adj,
        cut_valid)
    test_dataset = NodeClassificationDataset(
        data.x, data.y, ego_graphs_test, pe, args, num_classes, adj,
        cut_test)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              collate_fn=batcher(train_dataset),
                              pin_memory=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=args.eval_batch_size,
                              shuffle=False,
                              num_workers=args.num_workers,
                              collate_fn=batcher(valid_dataset),
                              pin_memory=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.eval_batch_size,
                             shuffle=False,
                             num_workers=args.num_workers,
                             collate_fn=batcher(test_dataset),
                             pin_memory=True)

    # ------------------------------------------------------------------
    # Model
    # ------------------------------------------------------------------
    # The model is built with one more input channel than data.x has.
    raw_model = TransformerModel(data.x.size(1) + 1,
                                 args.hidden_size,
                                 args.num_heads,
                                 args.hidden_size,
                                 args.num_layers,
                                 num_classes,
                                 args.input_dropout,
                                 args.hidden_dropout,
                                 layer_norm=args.layer_norm,
                                 src_scale=args.src_scale,
                                 mlp=args.mlp).to(device)
    wandb.watch(raw_model, log='all')

    # raw_model stays the unwrapped module; it is used for init_weights and
    # for checkpoints so saved keys never carry the 'module.' prefix.
    model = raw_model
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        model = nn.DataParallel(raw_model)

    pytorch_total_params = sum(p.numel() for p in model.parameters()
                               if p.requires_grad)
    print('model parameters:', pytorch_total_params)

    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs('saved', exist_ok=True)

    raw_model.init_weights()

    # ------------------------------------------------------------------
    # Evaluation-only mode
    # ------------------------------------------------------------------
    if args.load_path:
        # Map onto the device actually in use so a GPU-saved checkpoint also
        # loads on a CPU-only host (previously hard-coded to 'cuda:0').
        # NOTE(security): torch.load unpickles; only load trusted files.
        state_dict = torch.load(args.load_path, map_location=device)
        # This script saves un-prefixed keys. Load into the DataParallel
        # wrapper only when the checkpoint was saved from a wrapper (every
        # key prefixed with 'module.'); otherwise use the bare module.
        saved_from_wrapper = bool(state_dict) and all(
            key.startswith('module.') for key in state_dict)
        if model is not raw_model and saved_from_wrapper:
            model.load_state_dict(state_dict)
        else:
            raw_model.load_state_dict(state_dict)

        valid_acc, _ = test(model, valid_loader, device, args)
        cor_train_acc, _ = test(model, train_loader, device, args)
        cor_test_acc, _ = test(model, test_loader, device, args)
        train_output = f'Train: {100 * cor_train_acc:.2f}%, '
        valid_output = f'Valid: {100 * valid_acc:.2f}% '
        test_output = f'Test: {100 * cor_test_acc:.2f}%'
        print(train_output + valid_output + test_output)
        return

    # ------------------------------------------------------------------
    # Optimizer and warm-up schedule
    # ------------------------------------------------------------------
    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    elif args.optimizer == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=args.lr,
                                      weight_decay=args.weight_decay)
    else:
        raise NotImplementedError

    if args.warmup > 0:
        if args.scheduler == 'noam':
            optimizer = NoamOptim(
                optimizer,
                args.hidden_size if args.hidden_size > 0 else data.x.size(1),
                n_warmup_steps=args.warmup)
        elif args.scheduler == 'linear':
            optimizer = LinearOptim(
                optimizer,
                n_warmup_steps=args.warmup,
                n_training_steps=args.epochs * len(train_loader),
                init_lr=args.lr)

    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------
    best_val_acc = 0
    cor_test_acc = 0
    # Bound up front: it is logged after every evaluation but only assigned
    # when validation accuracy improves, which may not happen on the first
    # evaluation.
    cor_test_loss = None
    patience = 0
    ckpt_path = 'saved/' + exp_name + '.pt'

    for epoch in range(1, 1 + args.epochs):
        loss = train(model, train_loader, device, optimizer, args)

        valid_output = test_output = ''
        if epoch >= 10 and epoch % args.log_steps == 0:
            valid_acc, valid_loss = test(model, valid_loader, device, args)
            valid_output = f'Valid: {100 * valid_acc:.2f}% '

            if valid_acc > best_val_acc:
                best_val_acc = valid_acc
                # Test metrics are refreshed only at a new validation best.
                cor_test_acc, cor_test_loss = test(model, test_loader,
                                                   device, args)
                test_output = f'Test: {100 * cor_test_acc:.2f}%'
                patience = 0
                try:
                    torch.save(raw_model.state_dict(), ckpt_path)
                    wandb.save(ckpt_path)
                except FileNotFoundError as e:
                    # Best effort: a missing directory must not abort training.
                    print(e)
            else:
                patience += 1
                if patience >= args.early_stopping:
                    print('Early stopping...')
                    break

            metrics = {
                'Train Loss': loss,
                'Valid Acc': valid_acc,
                'best_val_acc': best_val_acc,
                'cor_test_acc': cor_test_acc,
                'LR': get_lr(optimizer),
                'Valid Loss': valid_loss,
            }
            if cor_test_loss is not None:
                metrics['cor_test_loss'] = cor_test_loss
            wandb.log(metrics)
        else:
            wandb.log({'Train Loss': loss, 'LR': get_lr(optimizer)})

        print(f'Epoch: {epoch:02d}, '
              f'Loss: {loss:.4f}, ' + valid_output + test_output)
# Model hyper-parameters.
ntokens = len(input_vocab)  # the size of vocabulary
nclstokens = 4  # D0, D1, S0, S1
ntagtokens = 1  # Binary classification O or D
emsize = 512  # embedding dimension
nhid = 512  # the dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 6  # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 8  # the number of heads in the multiheadattention models
dropout = 0.1  # the dropout value

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transformer = TransformerModel(ntokens, nclstokens, ntagtokens, emsize, nhead,
                               nhid, nlayers, dropout)
if args.model:
    # The model is still on the CPU at this point, so map the checkpoint to
    # the CPU: without map_location a GPU-saved checkpoint cannot be restored
    # on a CPU-only host, although this script supports the CPU fallback.
    transformer.load_state_dict(torch.load(args.model, map_location="cpu"))
else:
    # save random init model
    torch.save(transformer.state_dict(), "init.mdl")
# Wrap for multi-GPU data parallelism, then move to the selected device.
model = nn.DataParallel(transformer).to(device)

######################################################################
# Run the model
# -------------
#

cls_criterion = nn.CrossEntropyLoss()
lr = float(args.lr)  # learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Multiply the learning rate by 0.995 on every scheduler step.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.995)

import time