def setup_data_loaders(data_folder, batch_size=32, run_test=False, num_workers=3):
    """Sets up the dataloaders."""
    train_dataset = ZINC(data_folder, subset=True, split='train')
    val_dataset = ZINC(data_folder, subset=True, split='val')
    if run_test:
        test_dataset = ZINC(data_folder, subset=True, split='test')

    train_loader = DataLoader(train_dataset, batch_size, shuffle=True,
                              num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size, shuffle=False,
                            num_workers=num_workers)
    if run_test:
        test_loader = DataLoader(test_dataset, batch_size, shuffle=False,
                                 num_workers=num_workers)
    else:
        test_loader = None
    return train_loader, val_loader, test_loader
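# Usage sketch (assumes the standard torch_geometric imports used by the
# function above; 'data/ZINC' is a placeholder path, not from the source):
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader

train_loader, val_loader, test_loader = setup_data_loaders(
    'data/ZINC', batch_size=64, run_test=True)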
def prepare_data(self):
    path = osp.join(
        osp.dirname(osp.realpath(__file__)), "..", "..", "data", self.NAME
    )
    self.train_dataset = ZINC(path, subset=True, split="train")
    self.val_dataset = ZINC(path, subset=True, split="val")
    self.test_dataset = ZINC(path, subset=True, split="test")
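# `prepare_data` matches the PyTorch Lightning hook of the same name. A minimal
# sketch (an assumption, not shown in the source) of the companion dataloader
# hooks that would live on the same class:
def train_dataloader(self):
    return DataLoader(self.train_dataset, batch_size=32, shuffle=True)

def val_dataloader(self):
    return DataLoader(self.val_dataset, batch_size=32, shuffle=False)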
def load_split_data(config):
    train_data = ZINC(DATA_PATH, subset=True, split='train',
                      transform=feat_transform, pre_transform=preproc)
    val_data = ZINC(DATA_PATH, subset=True, split='val',
                    transform=feat_transform, pre_transform=preproc)

    train_iter = CRaWlLoader(train_data, shuffle=True,
                             batch_size=config['batch_size'], num_workers=4)
    val_iter = CRaWlLoader(val_data, batch_size=100, num_workers=4)
    return train_iter, val_iter
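# Usage sketch: only 'batch_size' is read from the config here; DATA_PATH,
# feat_transform, preproc and CRaWlLoader are module-level names assumed to be
# defined elsewhere in the surrounding CRaWl codebase.
train_iter, val_iter = load_split_data({'batch_size': 32})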
def zinc_data(root, batch_size=128):
    data = dict()
    for split in ["train", "val", "test"]:
        data[split] = DataLoader(
            ZINC(root, subset=True, split=split),
            batch_size=batch_size,
            shuffle=(split == "train"),
        )
    return data
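# Usage sketch ('data/ZINC' is a placeholder path): the dict keys mirror the
# split names, so all three loaders are built and consumed symmetrically.
loaders = zinc_data('data/ZINC')
for batch in loaders['train']:
    # each `batch` is a torch_geometric.data.Batch with x, edge_index,
    # edge_attr, y and a `batch` assignment vector
    break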
def get_dataset(dataset_name):
    """Retrieves the dataset corresponding to the given name."""
    path = join('dataset', dataset_name)
    if dataset_name == 'reddit':
        dataset = Reddit(path)
    elif dataset_name == 'flickr':
        dataset = Flickr(path)
    elif dataset_name == 'zinc':
        dataset = ZINC(root='dataset', subset=True, split='train')
    elif dataset_name == 'QM9':
        dataset = QM9(root='dataset')
    elif dataset_name == 'github':
        dataset = GitHub(path)
    elif dataset_name == 'ppi':
        dataset = PPI(path)
    elif dataset_name in ['amazon_comp', 'amazon_photo']:
        dataset = (Amazon(path, "Computers", T.NormalizeFeatures())
                   if dataset_name == 'amazon_comp'
                   else Amazon(path, "Photo", T.NormalizeFeatures()))
        data = dataset.data
        idx_train, idx_test = train_test_split(list(range(data.x.shape[0])),
                                               test_size=0.4, random_state=42)
        idx_val, idx_test = train_test_split(idx_test, test_size=0.5,
                                             random_state=42)
        data.train_mask = torch.tensor(idx_train)
        data.val_mask = torch.tensor(idx_val)
        data.test_mask = torch.tensor(idx_test)
        dataset.data = data
    elif dataset_name in ["Cora", "CiteSeer", "PubMed"]:
        dataset = Planetoid(path, name=dataset_name, split="public",
                            transform=T.NormalizeFeatures())
    else:
        raise NotImplementedError
    return dataset
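# Usage sketch: note that for 'zinc' this returns only the training split, so
# callers that need the val/test splits must load them separately.
zinc_train = get_dataset('zinc')
cora = get_dataset('Cora')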
from torch_geometric.nn.inits import zeros

argparser = argparse.ArgumentParser("multi-gpu training")
argparser.add_argument('--epochs', type=int, default=300)
argparser.add_argument('--hidden', type=int, default=100)
argparser.add_argument('--emb', type=int, default=100)
argparser.add_argument('--layers', type=int, default=4)
argparser.add_argument('--lr', type=float, default=0.001)
argparser.add_argument('--dropout', type=float, default=0.0)
argparser.add_argument('--rank', type=int, default=100)
argparser.add_argument('--batch', type=int, default=1000)
args = argparser.parse_args()
# args.hidden, args.rank = args.emb, args.emb

train_dataset = ZINC(osp.join('torch_geometric_data', 'zinc'), subset=True, split='train')
val_dataset = ZINC(osp.join('torch_geometric_data', 'zinc'), subset=True, split='val')
test_dataset = ZINC(osp.join('torch_geometric_data', 'zinc'), subset=True, split='test')

train_loader = DataLoader(train_dataset, batch_size=args.batch, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=args.batch)
test_loader = DataLoader(test_dataset, batch_size=args.batch)


class graph_cp_pooling(torch.nn.Module):
    def __init__(self, out_feats):
        super(graph_cp_pooling, self).__init__()
        self.w = torch.nn.Linear(out_feats + 1, out_feats)
        self.reset_parameters()

    def reset_parameters(self):
import os.path as osp

import torch
import torch.nn.functional as F
from torch.nn import Embedding, Linear, ModuleList, ReLU, Sequential
from torch.optim.lr_scheduler import ReduceLROnPlateau

from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import BatchNorm, PNAConv, global_add_pool
from torch_geometric.utils import degree

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ZINC')
train_dataset = ZINC(path, subset=True, split='train')
val_dataset = ZINC(path, subset=True, split='val')
test_dataset = ZINC(path, subset=True, split='test')

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128)
test_loader = DataLoader(test_dataset, batch_size=128)

# Compute in-degree histogram over training data.
deg = torch.zeros(5, dtype=torch.long)
for data in train_dataset:
    d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
    deg += torch.bincount(d, minlength=deg.numel())


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
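# A minimal sketch of how the `deg` histogram is consumed by a PNA layer
# (channel sizes here are illustrative; the real Net definition continues past
# the excerpt above). PNAConv needs the degree histogram to scale messages:
conv = PNAConv(in_channels=75, out_channels=75,
               aggregators=['mean', 'min', 'max', 'std'],
               scalers=['identity', 'amplification', 'attenuation'],
               deg=deg, edge_dim=50, towers=5, pre_layers=1, post_layers=1)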
from model import Net

parser = argparse.ArgumentParser()
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--hidden_channels', type=int, default=128)
parser.add_argument('--num_layers', type=int, default=3)
parser.add_argument('--dropout', type=float, default=0.0)
parser.add_argument('--epochs', type=int, default=300)
parser.add_argument('--no_inter_message_passing', action='store_true')
args = parser.parse_args()
print(args)

root = 'data/ZINC'
transform = JunctionTree()

train_dataset = ZINC(root, subset=True, split='train', pre_transform=transform)
val_dataset = ZINC(root, subset=True, split='val', pre_transform=transform)
test_dataset = ZINC(root, subset=True, split='test', pre_transform=transform)

train_loader = DataLoader(train_dataset, 128, shuffle=True, num_workers=12)
val_loader = DataLoader(val_dataset, 1000, shuffle=False, num_workers=12)
test_loader = DataLoader(test_dataset, 1000, shuffle=False, num_workers=12)

device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
model = Net(hidden_channels=args.hidden_channels, out_channels=1,
            num_layers=args.num_layers, dropout=args.dropout,
            inter_message_passing=not args.no_inter_message_passing).to(device)
                             verbose=True)
lr_limit = args.lr_limit

if args.load_model:
    model = torch.load(savedir + model_name + '.pkl')

pytorch_total_params = sum(p.numel() for p in model.parameters())
print("Total number of parameters", pytorch_total_params)

loss_fct = nn.L1Loss(reduction='sum')

# Load the data
batch_size = args.batch_size
transform = OneHotNodeEdgeFeatures(model_config['num_input_features'] - 1,
                                   model_config['num_edge_features'])
train_data = ZINC(rootdir, subset=args.subset, split='train', pre_transform=transform)
val_data = ZINC(rootdir, subset=args.subset, split='val', pre_transform=transform)
test_data = ZINC(rootdir, subset=args.subset, split='test', pre_transform=transform)

train_loader = DataLoader(train_data, batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size, shuffle=False)

print("Starting to train")
for epoch in range(args.epochs):
    if args.load_model:
        break
    epoch_start = time.time()
    tr_loss = train()
    current_lr = optimizer.param_groups[0]["lr"]
    if current_lr < lr_limit:
def __init__(self, f1_alpha, f2_alpha, f3_alpha):
    # Note: without subset=True this loads the full ~250k-molecule ZINC dataset.
    dataset = ZINC('data/ZINC')
    super(ZINCSampler, self).__init__(dataset, f1_alpha, f2_alpha, f3_alpha)
def load_zinc(root):
    dataset = ZINC(root, subset=True)
    batch = Batch.from_data_list([dataset[i] for i in range(10000)])
    return batch
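# Usage sketch: the ZINC subset's train split holds exactly 10,000 graphs, so
# this collates the entire split into a single Batch (assumes
# `from torch_geometric.data import Batch`; 'data/ZINC' is a placeholder path).
batch = load_zinc('data/ZINC')
print(batch.num_graphs)  # 10000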
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='GNN baselines on ogbg-mol* data with PyTorch Geometric')
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--gnn', type=str, default='gin-virtual',
                        help='GNN: gin, gin-virtual, gcn, or gcn-virtual '
                             '(default: gin-virtual)')
    parser.add_argument('--drop_ratio', type=float, default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument('--num_layer', type=int, default=5,
                        help='number of GNN message passing layers (default: 5)')
    parser.add_argument('--emb_dim', type=int, default=512,
                        help='dimensionality of hidden units in GNNs (default: 512)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num_workers', type=int, default=0,
                        help='number of workers (default: 0)')
    parser.add_argument('--dataset', type=str, default="molhiv",
                        help='dataset name (default: molhiv)')
    parser.add_argument('--rank', type=int, default=512,
                        help='dimensionality of rank units in GNNs (default: 512)')
    parser.add_argument('--filename', type=str, default="",
                        help='filename to output result (default: )')
    parser.add_argument('--lr', type=float, default=0.003)
    parser.add_argument('--wd', type=float, default=5e-5,
                        help='Weight decay (L2 loss on parameters).')
    args = parser.parse_args()

    device = torch.device("cuda:" + str(args.device)) \
        if torch.cuda.is_available() else torch.device("cpu")

    train_dataset = ZINC(os.path.join('torch_geometric_data', 'zinc'),
                         subset=True, split='train')
    val_dataset = ZINC(os.path.join('torch_geometric_data', 'zinc'),
                       subset=True, split='val')
    test_dataset = ZINC(os.path.join('torch_geometric_data', 'zinc'),
                        subset=True, split='test')

    n_classes = 1
    in_feat = train_dataset[0].x.shape[1]

    # Convert the PyG graphs to DGL graphs for the DGL-based GNN below.
    train_graphs = [to_dgl(g) for g in train_dataset]
    val_graphs = [to_dgl(g) for g in val_dataset]
    test_graphs = [to_dgl(g) for g in test_dataset]

    ### automatic evaluator. takes dataset name as input
    train_loader = DataLoader(train_graphs, batch_size=args.batch_size,
                              shuffle=True, num_workers=args.num_workers,
                              collate_fn=collate_dgl)
    valid_loader = DataLoader(val_graphs, batch_size=args.batch_size,
                              shuffle=False, num_workers=args.num_workers,
                              collate_fn=collate_dgl)
    test_loader = DataLoader(test_graphs, batch_size=args.batch_size,
                             shuffle=False, num_workers=args.num_workers,
                             collate_fn=collate_dgl)

    model = GNN(num_tasks=n_classes, in_dim=in_feat, num_layer=args.num_layer,
                emb_dim=args.emb_dim, rank=args.rank,
                drop_ratio=args.drop_ratio).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

    valid_curve = []
    test_curve = []
    train_curve = []

    for epoch in range(1, args.epochs + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        train(model, device, train_loader, optimizer)

        print('Evaluating...')
        train_perf = eval(model, device, train_loader)
        valid_perf = eval(model, device, valid_loader)
        test_perf = eval(model, device, test_loader)

        print({'Train': train_perf, 'Validation': valid_perf, 'Test': test_perf})

        train_curve.append(train_perf)
        valid_curve.append(valid_perf)
        test_curve.append(test_perf)

    # ZINC regression is scored by MAE, so the best epoch minimizes the metric.
    best_val_epoch = np.argmin(np.array(valid_curve))
    best_train = min(train_curve)

    print('Finished training!')
    print('Best validation score: {}'.format(valid_curve[best_val_epoch]))
    print('Test score: {}'.format(test_curve[best_val_epoch]))

    if not args.filename == '':
        torch.save({'Val': valid_curve[best_val_epoch],
                    'Test': test_curve[best_val_epoch],
                    'Train': train_curve[best_val_epoch],
                    'BestTrain': best_train}, args.filename)
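# A minimal sketch of the `train`/`eval` helpers called above, assuming an
# L1/MAE objective for ZINC regression and that `collate_dgl` yields
# (batched_graph, labels) pairs; the real helpers are defined elsewhere in
# this repo, so names and shapes here are assumptions.
def train(model, device, loader, optimizer):
    model.train()
    for bg, labels in loader:
        bg, labels = bg.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = torch.nn.functional.l1_loss(model(bg).view(-1), labels.view(-1))
        loss.backward()
        optimizer.step()

@torch.no_grad()
def eval(model, device, loader):
    model.eval()
    errors = []
    for bg, labels in loader:
        bg, labels = bg.to(device), labels.to(device)
        errors.append((model(bg).view(-1) - labels.view(-1)).abs().cpu())
    return torch.cat(errors).mean().item()  # MAE; lower is better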
                    type=str, default='test', choices={'test', 'val'},
                    help="split to evaluate on")
args = parser.parse_args()

torch.manual_seed(args.seed)
np.random.seed(args.seed)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

print('Loading Graphs...')
data = ZINC(DATA_PATH, subset=True, split=args.split,
            transform=feat_transform, pre_transform=preproc)
iter = CRaWlLoader(data, batch_size=50, num_workers=4)

mean_list, std_list = [], []
model_list = sorted(list(glob(args.model_dir)))
for model_dir in model_list:
    print(f'Evaluating {model_dir}...')
    model = CRaWl.load(model_dir)
    mean, std = test(model, iter, repeats=args.reps, steps=args.steps)
    mean_list.append(mean)
    std_list.append(std)

print(
def _zinc(self):
    dataset = ZINC('data/ZINC', transform=ZINCTransformer())
    # Standardize the regression targets in-place; the std is returned so
    # predictions can be mapped back to the original scale.
    mean = dataset.data.y.mean()
    std = dataset.data.y.std()
    dataset.data.y = (dataset.data.y - mean) / std
    # 28 atom types and 4 bond types in ZINC.
    return dataset, std.item(), 28, 4
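# Usage note (illustrative names): since y was standardized, multiply a model's
# normalized MAE by the returned std to report error on the original scale:
#   dataset, y_std, num_atom_types, num_bond_types = self._zinc()
#   mae_original = mae_normalized * y_std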
def main():
    args = get_parser()
    # Some argparse arguments arrive as strings and are parsed into bools here.
    naive_encoder = not str2bool(args.full_encoder)
    pin_memory = str2bool(args.pin_memory)
    use_bias = str2bool(args.bias)
    downstream_bn = str(args.d_bn)
    same_dropout = str2bool(args.same_dropout)
    mlp_mp = str2bool(args.mlp_mp)
    subset_data = str2bool(args.subset_data)
    phm_dim = args.phm_dim
    learn_phm = str2bool(args.learn_phm)

    base_dir = "zinc/"
    if not os.path.exists(base_dir):
        os.makedirs(base_dir)
    if base_dir not in args.save_dir:
        args.save_dir = os.path.join(base_dir, args.save_dir)
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    set_logging(save_dir=args.save_dir)
    logging.info(f"Creating log directory at {args.save_dir}.")
    with open(os.path.join(args.save_dir, "params.json"), 'w') as fp:
        json.dump(args.__dict__, fp)

    mp_layers = [int(item) for item in args.mp_units.split(',')]
    downstream_layers = [int(item) for item in args.d_units.split(',')]
    mp_dropout = [float(item) for item in args.dropout_mpnn.split(',')]
    dn_dropout = [float(item) for item in args.dropout_dn.split(',')]
    logging.info(f'Initialising model with {mp_layers} hidden units with dropout {mp_dropout} '
                 f'and downstream units: {downstream_layers} with dropout {dn_dropout}.')

    if args.pooling == "globalsum":
        logging.info("Using GlobalSum Pooling")
    else:
        logging.info("Using SoftAttention Pooling")
    logging.info(f"Using Adam optimizer with weight_decay ({args.weightdecay}) and regularization "
                 f"norm ({args.regularization})")
    logging.info(f"Weight init: {args.w_init} \n Contribution init: {args.c_init}")

    # data
    path = osp.join(osp.dirname(osp.realpath(__file__)), 'dataset', 'ZINC')
    train_data = ZINC(path, subset=subset_data, split='train')
    valid_data = ZINC(path, subset=subset_data, split='val')
    test_data = ZINC(path, subset=subset_data, split='test')
    evaluator = Evaluator()
    train_loader = DataLoader(train_data, batch_size=args.batch_size, drop_last=False,
                              shuffle=True, num_workers=args.nworkers,
                              pin_memory=pin_memory)
    valid_loader = DataLoader(valid_data, batch_size=args.batch_size, drop_last=False,
                              shuffle=False, num_workers=args.nworkers,
                              pin_memory=pin_memory)
    test_loader = DataLoader(test_data, batch_size=args.batch_size, drop_last=False,
                             shuffle=False, num_workers=args.nworkers,
                             pin_memory=pin_memory)

    # transform = RemoveIsolatedNodes()
    transform = None
    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    # device = "cpu"

    # for the hypercomplex model
    unique_phm = str2bool(args.unique_phm)
    if unique_phm:
        phm_rule = get_multiplication_matrices(phm_dim=args.phm_dim, type="phm")
        phm_rule = torch.nn.ParameterList(
            [torch.nn.Parameter(a, requires_grad=learn_phm) for a in phm_rule]
        )
    else:
        phm_rule = None

    # https://github.com/graphdeeplearning/benchmarking-gnns/blob/master/data/molecules/prepare_molecules_ZINC_full.ipynb
    FULL_ATOM_FEATURE_DIMS = [28]
    FULL_BOND_FEATURE_DIMS = [4]

    if args.aggr_msg == "pna" or args.aggr_node == "pna":
        # If PNA is used, compute the in-degree histogram over the training data.
        deg = torch.zeros(5, dtype=torch.long)
        for data in train_data:
            d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
            deg += torch.bincount(d, minlength=deg.numel())
    else:
        deg = None

    aggr_kwargs = {"aggregators": ['mean', 'min', 'max', 'std'],
                   "scalers": ['identity', 'amplification', 'attenuation'],
                   "deg": deg,
                   "post_layers": 1,
                   # this key is for directional message-passing layers
                   "msg_scalers": str2bool(args.msg_scale),
"initial_beta": 1.0, # Softmax "learn_beta": True } if "quaternion" in args.type: if args.aggr_msg == "pna" or args.aggr_msg == "pna": logging.info("PNA not implemented for quaternion models.") raise NotImplementedError if args.type == "undirectional-quaternion-sc-add": logging.info("Using Quaternion Undirectional MPNN with Skip Connection through Addition") model = UQ_SC_ADD(atom_input_dims=FULL_ATOM_FEATURE_DIMS, atom_encoded_dim=args.input_embed_dim, bond_input_dims=FULL_BOND_FEATURE_DIMS, naive_encoder=naive_encoder, mp_layers=mp_layers, dropout_mpnn=mp_dropout, init=args.w_init, same_dropout=same_dropout, norm_mp=args.mp_norm, add_self_loops=True, msg_aggr=args.aggr_msg, node_aggr=args.aggr_node, mlp=mlp_mp, pooling=args.pooling, activation=args.activation, real_trafo=args.real_trafo, downstream_layers=downstream_layers, target_dim=1, dropout_dn=dn_dropout, norm_dn=downstream_bn, msg_encoder=args.msg_encoder, **aggr_kwargs) elif args.type == "undirectional-quaternion-sc-cat": logging.info("Using Quaternion Undirectional MPNN with Skip Connection through Concatenation") model = UQ_SC_CAT(atom_input_dims=FULL_ATOM_FEATURE_DIMS, atom_encoded_dim=args.input_embed_dim, bond_input_dims=FULL_BOND_FEATURE_DIMS, naive_encoder=naive_encoder, mp_layers=mp_layers, dropout_mpnn=mp_dropout, init=args.w_init, same_dropout=same_dropout, norm_mp=args.mp_norm, add_self_loops=True, msg_aggr=args.aggr_msg, node_aggr=args.aggr_node, mlp=mlp_mp, pooling=args.pooling, activation=args.activation, real_trafo=args.real_trafo, downstream_layers=downstream_layers, target_dim=1, dropout_dn=dn_dropout, norm_dn=downstream_bn, msg_encoder=args.msg_encoder, **aggr_kwargs) elif args.type == "undirectional-phm-sc-add": logging.info("Using PHM Undirectional MPNN with Skip Connection through Addition") model = UPH_SC_ADD(phm_dim=phm_dim, learn_phm=learn_phm, phm_rule=phm_rule, atom_input_dims=FULL_ATOM_FEATURE_DIMS, atom_encoded_dim=args.input_embed_dim, bond_input_dims=FULL_BOND_FEATURE_DIMS, naive_encoder=naive_encoder, mp_layers=mp_layers, dropout_mpnn=mp_dropout, w_init=args.w_init, c_init=args.c_init, same_dropout=same_dropout, norm_mp=args.mp_norm, add_self_loops=True, msg_aggr=args.aggr_msg, node_aggr=args.aggr_node, mlp=mlp_mp, pooling=args.pooling, activation=args.activation, real_trafo=args.real_trafo, downstream_layers=downstream_layers, target_dim=1, dropout_dn=dn_dropout, norm_dn=downstream_bn, msg_encoder=args.msg_encoder, sc_type=args.sc_type, **aggr_kwargs) elif args.type == "undirectional-phm-sc-cat": logging.info("Using PHM Undirectional MPNN with Skip Connection through Concatenation") model = UPH_SC_CAT(phm_dim=phm_dim, learn_phm=learn_phm, phm_rule=phm_rule, atom_input_dims=FULL_ATOM_FEATURE_DIMS, atom_encoded_dim=args.input_embed_dim, bond_input_dims=FULL_BOND_FEATURE_DIMS, naive_encoder=naive_encoder, mp_layers=mp_layers, dropout_mpnn=mp_dropout, w_init=args.w_init, c_init=args.c_init, same_dropout=same_dropout, norm_mp=args.mp_norm, add_self_loops=True, msg_aggr=args.aggr_msg, node_aggr=args.aggr_node, mlp=mlp_mp, pooling=args.pooling, activation=args.activation, real_trafo=args.real_trafo, downstream_layers=downstream_layers, target_dim=1, dropout_dn=dn_dropout, norm_dn=downstream_bn, msg_encoder=args.msg_encoder, **aggr_kwargs) else: raise ModuleNotFoundError logging.info(f"Model consists of {model.get_number_of_params_()} trainable parameters") # do runs test_best_epoch_metrics_arr = [] test_last_epoch_metrics_arr = [] val_metrics_arr = [] t0 = time.time() for i in range(1, args.n_runs + 1): 
        ogb_bestEpoch_test_metrics, ogb_lastEpoch_test_metric, ogb_val_metrics = do_run(
            i, model, args, transform, train_loader, valid_loader, test_loader,
            device, evaluator, t0)
        test_best_epoch_metrics_arr.append(ogb_bestEpoch_test_metrics)
        test_last_epoch_metrics_arr.append(ogb_lastEpoch_test_metric)
        val_metrics_arr.append(ogb_val_metrics)

    logging.info(f"Performance of model across {args.n_runs} runs:")
    test_bestEpoch_perf = torch.tensor(test_best_epoch_metrics_arr)
    test_lastEpoch_perf = torch.tensor(test_last_epoch_metrics_arr)
    valid_perf = torch.tensor(val_metrics_arr)
    logging.info('===========================')
    logging.info(f'Final Test (best val-epoch) '
                 f'"{evaluator.eval_metric}": {test_bestEpoch_perf.mean():.4f} ± {test_bestEpoch_perf.std():.4f}')
    logging.info(f'Final Test (last-epoch) '
                 f'"{evaluator.eval_metric}": {test_lastEpoch_perf.mean():.4f} ± {test_lastEpoch_perf.std():.4f}')
    logging.info(f'Final (best) Valid "{evaluator.eval_metric}": {valid_perf.mean():.4f} ± {valid_perf.std():.4f}')
def load_dataset(args):
    # automatic data loading and splitting
    transform = add_zeros if args.dataset == 'ogbg-ppa' else None
    cls_criterion = get_loss_function(args.dataset)
    idx2word_mapper = None

    if args.dataset == 'mnist':
        train_data = MNISTSuperpixels(root='dataset', train=True, transform=T.Polar())
        dataset = train_data
        dataset.name = 'mnist'
        dataset.eval_metric = 'acc'
        validation_data = []
        test_data = MNISTSuperpixels(root='dataset', train=False, transform=T.Polar())
        train_data = list(train_data)
        test_data = list(test_data)

    elif args.dataset == 'QM9':
        # Contains 19 targets. Use only the first 12 (0-11).
        QM9_VALIDATION_START = 110000
        QM9_VALIDATION_END = 120000
        dataset = QM9(root='dataset',
                      transform=ExtractTargetTransform(args.target)).shuffle()
        dataset.name = 'QM9'
        dataset.eval_metric = 'mae'
        train_data = dataset[:QM9_VALIDATION_START]
        validation_data = dataset[QM9_VALIDATION_START:QM9_VALIDATION_END]
        test_data = dataset[QM9_VALIDATION_END:]
        train_data = list(train_data)
        validation_data = list(validation_data)
        test_data = list(test_data)

    elif args.dataset == 'zinc':
        train_data = ZINC(root='dataset', subset=True, split='train')
        dataset = train_data
        dataset.name = 'zinc'
        validation_data = ZINC(root='dataset', subset=True, split='val')
        test_data = ZINC(root='dataset', subset=True, split='test')
        dataset.eval_metric = 'mae'
        train_data = list(train_data)
        validation_data = list(validation_data)
        test_data = list(test_data)

    elif args.dataset in ['ogbg-molhiv', 'ogbg-molpcba', 'ogbg-ppa', 'ogbg-code2']:
        dataset = PygGraphPropPredDataset(name=args.dataset, transform=transform)

        if args.dataset == 'ogbg-code2':
            seq_len_list = np.array([len(seq) for seq in dataset.data.y])
            max_seq_len = args.max_seq_len
            num_less_or_equal_to_max = np.sum(
                seq_len_list <= args.max_seq_len) / len(seq_len_list)
            print(f'Target sequence less or equal to {max_seq_len} is '
                  f'{num_less_or_equal_to_max:.2%}.')

        split_idx = dataset.get_idx_split()

        # The following is only used in the evaluation of the ogbg-code classifier.
        if args.dataset == 'ogbg-code2':
            vocab2idx, idx2vocab = get_vocab_mapping(
                [dataset.data.y[i] for i in split_idx['train']], args.num_vocab)
            # specific transformations for the ogbg-code dataset
            dataset.transform = transforms.Compose([
                augment_edge,
                lambda data: encode_y_to_arr(data, vocab2idx, args.max_seq_len)
            ])
            idx2word_mapper = partial(decode_arr_to_seq, idx2vocab=idx2vocab)

        train_data = list(dataset[split_idx["train"]])
        validation_data = list(dataset[split_idx["valid"]])
        test_data = list(dataset[split_idx["test"]])

    return dataset, train_data, validation_data, test_data, cls_criterion, idx2word_mapper
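# Usage sketch (illustrative): the 'zinc' branch only reads `args.dataset`, so
# a bare namespace suffices; other branches need more fields (e.g. `target`,
# `max_seq_len`, `num_vocab`).
from argparse import Namespace

dataset, train_data, validation_data, test_data, criterion, mapper = \
    load_dataset(Namespace(dataset='zinc'))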