Example #1
import random

import numpy as np
import torch

# get_engine() is not defined in this snippet; it is assumed to come from the
# surrounding project.
engine = get_engine()

seed = 4269666
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
# Pin the run to a fixed GPU
device = torch.device("cuda:2")

# The problem comes from the count vectorizer, which drops some words
print("Load Dataset")
# Quora is assumed to come from the project's datasets module (the import
# appears further down in this snippet)
dataset = Quora.torch_dataset()
dataclasses = Quora.dataclasses()
dataclasses = {q._id: q for q in dataclasses}  # index dataclasses by their id


def embedding_collate_decorator(collate_fn):
    # Wraps the base collate function; currently the wrapper just forwards
    # the batch (x, y, id_, qrels, seq_lens) unchanged.
    def wrapper(batch):
        x, y, id_, qrels, seq_lens = collate_fn(batch)
        return x, y, id_, qrels, seq_lens

    return wrapper


# sequence_collate_fn is assumed to be provided by the project's data utilities.
collate_fn = embedding_collate_decorator(sequence_collate_fn)

train_len, val_len = int(0.7 * len(dataset)), int(0.15 * len(dataset))
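# The split itself is not shown in the original snippet. The lines below are a
# minimal sketch (not from the source), assuming torch.utils.data.random_split
# is used, with the remaining ~15% kept as a test set and an arbitrary batch
# size of 64.
from torch.utils.data import DataLoader, random_split

test_len = len(dataset) - train_len - val_len
train_set, val_set, test_set = random_split(
    dataset, [train_len, val_len, test_len],
    generator=torch.Generator().manual_seed(seed))
train_loader = DataLoader(train_set, batch_size=64, shuffle=True,
                          collate_fn=collate_fn)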
import sys
import os
from os import path

# Make the project's ../src directory importable
libpath = path.normpath(
    path.join(path.dirname(path.realpath(__file__)), os.pardir, "src"))
sys.path.append(libpath)

import pickle as pkl
import torch

import data
from datasets import Quora, Robust2004

sys.modules["dataset"] = data

quora_dc = Quora.dataclasses()
quora_torch = Quora.torch_dataset()
rb_dc = Robust2004.dataclasses()
rb_torch = Robust2004.torch_dataset()

del sys.modules["dataset"]

# Re-save the freshly loaded artifacts (presumably so that future loads no
# longer need the "dataset" alias).
with open(Quora.dataclasses_path, "wb") as f:
    pkl.dump(quora_dc, f)

with open(Robust2004.dataclasses_path, "wb") as f:
    pkl.dump(rb_dc, f)

torch.save(quora_torch, Quora.torch_path)
torch.save(rb_torch, Robust2004.torch_path)
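
# Sanity check (sketch, not in the original): the pickle written above should
# load back cleanly in the same process.
with open(Quora.dataclasses_path, "rb") as f:
    pkl.load(f)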