Example #1
def load_data():
    print('loading data...\n')
    data = pickle.load(open(config.data + 'data.pkl', 'rb'))
    data['train']['length'] = int(data['train']['length'] * opt.scale)

    trainset = utils.BiDataset(data['train'], char=config.char)
    validset = utils.BiDataset(data['valid'], char=config.char)

    src_vocab = data['dict']['src']
    tgt_vocab = data['dict']['tgt']
    config.src_vocab_size = src_vocab.size()
    config.tgt_vocab_size = tgt_vocab.size()

    trainloader = torch.utils.data.DataLoader(dataset=trainset,
                                              batch_size=config.batch_size,
                                              shuffle=True,
                                              num_workers=0,
                                              collate_fn=utils.padding)
    if hasattr(config, 'valid_batch_size'):
        valid_batch_size = config.valid_batch_size
    else:
        valid_batch_size = config.batch_size
    validloader = torch.utils.data.DataLoader(dataset=validset,
                                              batch_size=valid_batch_size,
                                              shuffle=False,
                                              num_workers=0,
                                              collate_fn=utils.padding)

    return {'trainset': trainset, 'validset': validset,
            'trainloader': trainloader, 'validloader': validloader,
            'src_vocab': src_vocab, 'tgt_vocab': tgt_vocab}
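
Every example on this page hands collate_fn=utils.padding to the DataLoader, so each batch is padded to the length of its longest sequence. The real utils.padding is not shown here; purely as a hedged sketch of what such a collate function typically does (the name padding_sketch and the (src, tgt) sample layout are assumptions, not the actual implementation):

import torch

def padding_sketch(batch):
    # batch: list of (src, tgt) pairs of token-id lists (assumed layout)
    srcs, tgts = zip(*batch)
    src_len = torch.tensor([len(s) for s in srcs])
    tgt_len = torch.tensor([len(t) for t in tgts])
    # allocate zero-filled tensors sized to the longest sequence in the batch
    src_pad = torch.zeros(len(srcs), int(src_len.max()), dtype=torch.long)
    tgt_pad = torch.zeros(len(tgts), int(tgt_len.max()), dtype=torch.long)
    for i, (s, t) in enumerate(zip(srcs, tgts)):
        src_pad[i, :len(s)] = torch.tensor(s, dtype=torch.long)
        tgt_pad[i, :len(t)] = torch.tensor(t, dtype=torch.long)
    return src_pad, src_len, tgt_pad, tgt_len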
Example #2
def prepare_mydataloaders(opt, device):
    batch_size = opt.batch_size
    data = pickle.load(open(opt.data_pkl, 'rb'))

    opt.max_token_seq_len = 140
    opt.src_pad_idx = data['dict']['src'].labelToIdx[Constants.PAD_WORD]
    opt.trg_pad_idx = data['dict']['tgt'].labelToIdx[Constants.PAD_WORD]

    opt.src_vocab_size = len(data['dict']['src'].labelToIdx)
    opt.trg_vocab_size = len(data['dict']['tgt'].labelToIdx)

    # ========= Preparing Model =========#
    if opt.embs_share_weight:
        assert data['dict']['src'].labelToIdx == data['dict']['tgt'].labelToIdx, \
            'To share word embeddings, the src/trg word2idx tables must be the same.'

    trainset = utils.BiDataset(data['train'])
    validset = utils.BiDataset(data['valid'])
    testset = utils.BiDataset(data['test'])
    trainloader = torch.utils.data.DataLoader(dataset=trainset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=0,
                                              collate_fn=utils.padding)
    validloader = torch.utils.data.DataLoader(dataset=validset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=0,
                                              collate_fn=utils.padding)
    testloader = torch.utils.data.DataLoader(dataset=testset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             collate_fn=utils.padding)
    return trainloader, validloader, testloader
Example #3
def load_data(config):
    """
    load data.
    update "data" due to the saved path in the pickle file
    :return: a dict with data and vocabulary
    """
    print("loading data...\n")
    data = pickle.load(open(config.data + "data.pkl", "rb"))
    # Rewrite the stored file paths, since the paths saved in the pickle may be stale.
    data["train"]["length"] = int(data["train"]["length"] * config.scale)
    data["train"]["srcF"] = os.path.join(config.data, "train.src.id")
    data["train"]["original_srcF"] = os.path.join(config.data, "train.src.str")
    data["train"]["tgtF"] = os.path.join(config.data, "train.tgt.id")
    data["train"]["original_tgtF"] = os.path.join(config.data, "train.tgt.str")
    data["test"]["srcF"] = os.path.join(config.data, "test.src.id")
    data["test"]["original_srcF"] = os.path.join(config.data, "test.src.str")
    data["test"]["tgtF"] = os.path.join(config.data, "test.tgt.id")
    data["test"]["original_tgtF"] = os.path.join(config.data, "test.tgt.str")

    train_set = utils.BiDataset(data["train"], char=config.char)
    valid_set = utils.BiDataset(data["test"], char=config.char)

    src_vocab = data["dict"]["src"]
    tgt_vocab = data["dict"]["tgt"]
    config.src_vocab_size = src_vocab.size()
    config.tgt_vocab_size = tgt_vocab.size()

    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=0,
        collate_fn=utils.padding,
    )
    if hasattr(config, "valid_batch_size"):
        valid_batch_size = config.valid_batch_size
    else:
        valid_batch_size = config.batch_size
    valid_loader = torch.utils.data.DataLoader(
        dataset=valid_set,
        batch_size=valid_batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=utils.padding,
    )
    return {
        "train_set": train_set,
        "valid_set": valid_set,
        "train_loader": train_loader,
        "valid_loader": valid_loader,
        "src_vocab": src_vocab,
        "tgt_vocab": tgt_vocab,
    }
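
Style note: the hasattr fallback for valid_batch_size (used in Example #1 and here) can be written in one line with getattr; this is just an equivalent phrasing, not a change in behaviour:

valid_batch_size = getattr(config, "valid_batch_size", config.batch_size)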
Example #4
def load_data():
    print('loading data...\n')
    data = pickle.load(open(opt.data + 'data.pkl', 'rb'))
    src_vocab = data['dict']['src']
    tgt_vocab = data['dict']['tgt']
    config.src_vocab_size = src_vocab.size()
    config.tgt_vocab_size = tgt_vocab.size()
    testset = utils.BiDataset(data['test'], char=config.char)
    testloader = torch.utils.data.DataLoader(dataset=testset,
                                             batch_size=opt.batch_size,
                                             shuffle=False,
                                             num_workers=0,
                                             collate_fn=utils.padding)
    return {
        'testset': testset,
        'testloader': testloader,
        'src_vocab': src_vocab,
        'tgt_vocab': tgt_vocab
    }
Example #5
def prepare_mydataloaders(opt, device):
    data = pickle.load(open(opt.data_pkl, 'rb'))

    opt.max_token_seq_len = 140
    opt.src_pad_idx = data['dict']['src'].labelToIdx[Constants.PAD_WORD]
    opt.trg_pad_idx = data['dict']['tgt'].labelToIdx[Constants.PAD_WORD]
    opt.trg_bos_idx = data['dict']['tgt'].labelToIdx[Constants.BOS_WORD]
    opt.trg_eos_idx = data['dict']['tgt'].labelToIdx[Constants.EOS_WORD]
    opt.unk_idx = 1
    opt.src_vocab_size = len(data['dict']['src'].labelToIdx)
    opt.trg_vocab_size = len(data['dict']['tgt'].labelToIdx)
    # ========= Preparing Model =========#
    # if opt.embs_share_weight:
    #     assert data['dict']['src'].labelToIdx == data['dict']['tgt'].labelToIdx, \
    #         'To sharing word embedding the src/trg word2idx table shall be the same.'
    testset = utils.BiDataset(data['test'])
    testloader = torch.utils.data.DataLoader(dataset=testset,
                                             batch_size=1,
                                             shuffle=False,
                                             collate_fn=utils.padding)
    return data['dict']['tgt'], testloader
Example #6
def load_data():
    """
    作用:加载数据
    参数:无
    返回值:无
    """
    print('loading data...\n')
    # Load the preprocessed data (train/valid/test splits and the vocabularies)
    data = pickle.load(open(config.data + 'data.pkl', 'rb'))
    # scale is the fraction of the training set to use (default 1)
    data['train']['length'] = int(data['train']['length'] * opt.scale)

    # Build the training and validation dataset objects
    trainset = utils.BiDataset(data['train'], char=config.char)
    validset = utils.BiDataset(data['valid'], char=config.char)

    if opt.pointer:
        trainset = utils.BiDataset(data['train'],
                                   char=config.char,
                                   copy=opt.pointer)
        validset = utils.BiDataset(data['valid'],
                                   char=config.char,
                                   copy=opt.pointer)

    # Get the vocabularies and record their sizes (no changes needed here)
    src_vocab = data['dict']['src']
    tgt_vocab = data['dict']['tgt']
    config.src_vocab_size = src_vocab.size()
    config.tgt_vocab_size = tgt_vocab.size()

    # Build batches; sequences in each batch are padded with 0 up to the longest sequence in that batch
    trainloader = torch.utils.data.DataLoader(dataset=trainset,
                                              batch_size=config.batch_size,
                                              shuffle=True,
                                              num_workers=0,
                                              collate_fn=utils.padding)
    # If no separate validation batch size is configured, fall back to the training batch size
    if hasattr(config, 'valid_batch_size'):
        valid_batch_size = config.valid_batch_size
    else:
        valid_batch_size = config.batch_size
    validloader = torch.utils.data.DataLoader(dataset=validset,
                                              batch_size=valid_batch_size,
                                              shuffle=True,
                                              num_workers=0,
                                              collate_fn=utils.padding)
    if opt.pointer:
        trainloader = torch.utils.data.DataLoader(dataset=trainset,
                                                  batch_size=config.batch_size,
                                                  shuffle=True,
                                                  num_workers=0,
                                                  collate_fn=utils.padding)
        validloader = torch.utils.data.DataLoader(dataset=validset,
                                                  batch_size=valid_batch_size,
                                                  shuffle=True,
                                                  num_workers=0,
                                                  collate_fn=utils.padding)
    return {
        'trainset': trainset,
        'validset': validset,
        'trainloader': trainloader,
        'validloader': validloader,
        'src_vocab': src_vocab,
        'tgt_vocab': tgt_vocab
    }
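
Design note on the pointer branch above: the datasets and loaders are built twice with identical arguments except for copy=opt.pointer. Assuming utils.BiDataset also accepts copy=False (not verified here), the duplication could be collapsed into a single construction, for example:

trainset = utils.BiDataset(data['train'], char=config.char, copy=opt.pointer)
validset = utils.BiDataset(data['valid'], char=config.char, copy=opt.pointer)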
Example #7
def load_eval_data(config):
    """
    load data.
    update "data" due to the saved path in the pickle file
    :return: a dict with data and vocabulary
    """
    print("loading data...\n")
    data = pickle.load(open(config.data + "data.pkl", "rb"))
    # Rewrite the stored file paths, since the paths saved in the pickle may be stale.
    print(data.keys())
    data["test"]["srcF"] = os.path.join('core/dataloading/', "src.id")
    data["test"]["original_srcF"] = os.path.join('core/dataloading/',
                                                 "src.str")
    data["test"]["tgtF"] = os.path.join('core/dataloading/', "tgt.id")
    data["test"]["original_tgtF"] = os.path.join('core/dataloading/',
                                                 "tgt.str")
    data["test"]["length"] = 16
    if config.knowledge:
        train_set = utils.BiKnowledgeDataset(
            os.path.join(config.data, 'train.supporting_facts'),
            infos=data['train'],
            char=config.char)
        valid_set = utils.BiKnowledgeDataset(
            os.path.join(config.data, 'test.supporting_facts'),
            infos=data['test'],
            char=config.char)
    else:
        train_set = utils.BiDataset(data["train"], char=config.char)
        valid_set = utils.BiDataset(data["test"], char=config.char)

    src_vocab = data["dict"]["src"]
    tgt_vocab = data["dict"]["tgt"]
    config.src_vocab_size = src_vocab.size()
    config.tgt_vocab_size = tgt_vocab.size()

    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        collate_fn=utils.knowledge_padding
        if config.knowledge else utils.padding,
    )
    if hasattr(config, "valid_batch_size"):
        valid_batch_size = config.valid_batch_size
    else:
        valid_batch_size = config.batch_size
    valid_loader = torch.utils.data.DataLoader(
        dataset=valid_set,
        batch_size=valid_batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        collate_fn=utils.knowledge_padding
        if config.knowledge else utils.padding,
    )
    print(data["test"]['length'])
    return {
        "train_set": train_set,
        "valid_set": valid_set,
        "train_loader": train_loader,
        "valid_loader": valid_loader,
        "src_vocab": src_vocab,
        "tgt_vocab": tgt_vocab,
    }
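
For orientation, a minimal usage sketch of the dict returned by load_data / load_eval_data; the exact contents of each batch are determined by utils.padding (or utils.knowledge_padding) and are not assumed beyond that:

bundle = load_eval_data(config)
print(bundle["src_vocab"].size(), bundle["tgt_vocab"].size())
for batch in bundle["valid_loader"]:
    # each batch comes from the collate_fn configured above
    break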