Esempio n. 1
0
def make_data_loader(opt, *args):
    """Build the generation data loader selected by ``opt.dataset``.

    Args:
        opt: Configuration object. Only its ``dataset`` attribute is read
            here; the whole object is forwarded to the loader constructor.
        *args: Extra positional arguments passed through unchanged to the
            selected ``GenerationDataLoader``.

    Returns:
        The object built by ``atomic_data.GenerationDataLoader`` for
        "atomic", ``conceptnet_data.GenerationDataLoader`` for "conceptnet",
        or ``motiv_data.GenerationDataLoader`` for "motiv_sent".

    Raises:
        ValueError: If ``opt.dataset`` is not one of the supported names.
    """
    dataset = opt.dataset
    if dataset == "atomic":
        return atomic_data.GenerationDataLoader(opt, *args)
    if dataset == "conceptnet":
        return conceptnet_data.GenerationDataLoader(opt, *args)
    if dataset == "motiv_sent":
        print("load motiv_sent data")
        return motiv_data.GenerationDataLoader(opt, *args)
    # Fail loudly here instead of implicitly returning None, which would
    # only surface later, far from the misconfigured option.
    raise ValueError(
        f"Unsupported dataset {dataset!r}; expected one of "
        "'atomic', 'conceptnet', 'motiv_sent'"
    )
Esempio n. 2
0
def make_data_loader(opt, *args):
    """Build the generation data loader selected by ``opt.dataset``.

    Args:
        opt: Configuration object. Only its ``dataset`` attribute is read
            here; the whole object is forwarded to the loader constructor.
        *args: Extra positional arguments passed through unchanged to the
            selected ``GenerationDataLoader``.

    Returns:
        The object built by ``atomic_data.GenerationDataLoader`` for
        "atomic" or ``conceptnet_data.GenerationDataLoader`` for
        "conceptnet".

    Raises:
        ValueError: If ``opt.dataset`` is not one of the supported names.
    """
    dataset = opt.dataset
    if dataset == "atomic":
        return atomic_data.GenerationDataLoader(opt, *args)
    if dataset == "conceptnet":
        return conceptnet_data.GenerationDataLoader(opt, *args)
    # Fail loudly here instead of implicitly returning None, which would
    # only surface later, far from the misconfigured option.
    raise ValueError(
        "Unsupported dataset {!r}; expected one of "
        "'atomic', 'conceptnet'".format(dataset)
    )
Esempio n. 3
0
    'RelatedTo', 'SymbolOf', 'UsedFor'
]

# Extra tokens appended after the BPE vocabulary: the start/end markers
# from ``data`` followed by one "<Relation>" tag per name in ``relations``.
special = [data.start_token, data.end_token] + [
    "<{}>".format(relation) for relation in relations
]

# Paths to the BPE encoder table and BPE vocabulary file.
encoder_path = "model/encoder_bpe_40000.json"
bpe_path = "model/vocab_40000.bpe"
text_encoder = TextEncoder(encoder_path, bpe_path)

# Give each special token the next free id, writing both directions so
# encoder (token -> id) and decoder (id -> token) stay in sync.
for special_token in special:
    next_id = len(text_encoder.encoder)
    text_encoder.decoder[next_id] = special_token
    text_encoder.encoder[special_token] = next_id

# Load the raw ConceptNet files and tensorize them with the extended encoder.
data_loader = cdata.GenerationDataLoader(opt)
data_loader.load_data("data/conceptnet/")
data_loader.make_tensors(text_encoder, special, test=False)

# Copy max_r onto the data config before the output name is derived from
# opt.data below (presumably so it is reflected in the name — confirm in
# utils.make_name_string).
opt.data.maxr = data_loader.max_r

save_path = "data/conceptnet/processed/generation"
save_name = os.path.join(
    save_path, "{}.pickle".format(utils.make_name_string(opt.data))
)
utils.mkpath(save_path)

print("Data Loader will be saved to {}".format(save_name))
torch.save(data_loader, save_name)