import random
import numpy as np

# NOTE(review): get_engine, torch, Quora and sequence_collate_fn are not bound
# anywhere in the visible source; presumably they are imported earlier in the
# file -- confirm.
engine = get_engine()

# Seed torch, numpy and the stdlib RNG with the same fixed value.
seed = 4269666
for seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    seed_rng(seed)

# Constant definition
device = torch.device("cuda:2")

# The problem comes from the count vectorizer, which drops some words

print("Load Dataset")
dataset = Quora.torch_dataset()
# Index the Quora dataclasses by their `_id` attribute.
dataclasses = {entry._id: entry for entry in Quora.dataclasses()}


def embedding_collate_decorator(collate_fn):
    """Return a wrapper around *collate_fn* that forwards its five outputs.

    The wrapper calls ``collate_fn(batch)``, unpacks the result into exactly
    five values and returns them as a tuple in the same order. It currently
    applies no transformation; the unpacking only enforces the arity (a
    result with a different number of items raises ``ValueError``).
    """
    def wrapper(batch):
        collated = collate_fn(batch)
        x, y, id_, qrels, seq_lens = collated
        return (x, y, id_, qrels, seq_lens)

    return wrapper


collate_fn = embedding_collate_decorator(sequence_collate_fn)

# Split sizes: 70% of the dataset for training and 15% for validation,
# both truncated to whole numbers of samples.
train_len = int(0.7 * len(dataset))
val_len = int(0.15 * len(dataset))
import sys
import os
from os import path

# Make the sibling "src" directory (one level above this script's directory)
# importable before pulling in the project modules below.
script_dir = path.dirname(path.realpath(__file__))
libpath = path.normpath(path.join(script_dir, os.pardir, "src"))
sys.path.append(libpath)

import pickle as pkl
import torch
import data
from datasets import Quora, Robust2004

# Temporarily expose the `data` module under the name "dataset" while the
# stored objects are loaded.
# NOTE(review): presumably the stored files reference classes through a module
# called `dataset` (an older name of `data`) -- confirm.
sys.modules["dataset"] = data
quora_dc = Quora.dataclasses()
quora_torch = Quora.torch_dataset()
rb_dc = Robust2004.dataclasses()
rb_torch = Robust2004.torch_dataset()
# Remove the alias again before anything is written back out.
sys.modules.pop("dataset")

# Write the dataclass collections back to their pickle files.
for dc_path, dc_objects in (
    (Quora.dataclasses_path, quora_dc),
    (Robust2004.dataclasses_path, rb_dc),
):
    with open(dc_path, "wb") as f:
        pkl.dump(dc_objects, f)

# Write the torch datasets back to their files.
for torch_objects, torch_path in (
    (quora_torch, Quora.torch_path),
    (rb_torch, Robust2004.torch_path),
):
    torch.save(torch_objects, torch_path)