Code example #1
def create_dataset(opt, SRC, TRG):

    print("creating dataset and iterator... ")

    raw_data = {
        'src': [line for line in opt.src_data],
        'trg': [line for line in opt.trg_data]
    }
    df = pd.DataFrame(raw_data, columns=["src", "trg"])

    mask = (df['src'].str.count(' ') <
            opt.max_strlen) & (df['trg'].str.count(' ') < opt.max_strlen)
    df = df.loc[mask]

    df.to_csv("translate_transformer_temp.csv", index=False)

    data_fields = [('src', SRC), ('trg', TRG)]
    train = data.TabularDataset('./translate_transformer_temp.csv',
                                format='csv',
                                fields=data_fields)

    train_iter = MyIterator(train,
                            batch_size=opt.batchsize,
                            device=opt.device,  # use the configured device rather than hard-coding 'cuda'
                            repeat=False,
                            sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=batch_size_fn,
                            train=True,
                            shuffle=True)

    #os.remove('translate_transformer_temp.csv')

    if opt.load_weights is None:
        SRC.build_vocab(train)
        TRG.build_vocab(train)
        if opt.checkpoint > 0:
            try:
                os.mkdir("weights")
            except FileExistsError:
                print(
                    "weights folder already exists, run program with -load_weights weights to load them"
                )
                quit()
            pickle.dump(SRC, open('weights/SRC.pkl', 'wb'))
            pickle.dump(TRG, open('weights/TRG.pkl', 'wb'))

    opt.src_pad = SRC.vocab.stoi['<pad>']
    opt.trg_pad = TRG.vocab.stoi['<pad>']

    opt.train_len = get_len(train_iter)

    return train_iter
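
All five examples rely on a MyIterator class and a batch_size_fn helper defined elsewhere in their repositories. For reference, a typical implementation, adapted from the token-count batching scheme popularized by the Annotated Transformer, looks roughly like the sketch below; treat it as an illustration rather than the exact code these projects ship.

from torchtext import data

class MyIterator(data.Iterator):
    """Groups examples of similar length to minimise padding per batch."""
    def create_batches(self):
        if self.train:
            def pool(d, random_shuffler):
                # Sort within large chunks of 100 * batch_size examples,
                # then shuffle the resulting batches.
                for p in data.batch(d, self.batch_size * 100):
                    p_batch = data.batch(sorted(p, key=self.sort_key),
                                         self.batch_size, self.batch_size_fn)
                    for b in random_shuffler(list(p_batch)):
                        yield b
            self.batches = pool(self.data(), self.random_shuffler)
        else:
            self.batches = []
            for b in data.batch(self.data(), self.batch_size,
                                self.batch_size_fn):
                self.batches.append(sorted(b, key=self.sort_key))

def batch_size_fn(new, count, sofar):
    # Track the longest source/target seen so far in the batch and return
    # the padded token count for the batch under construction.
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch = 0
        max_tgt_in_batch = 0
    max_src_in_batch = max(max_src_in_batch, len(new.src))
    max_tgt_in_batch = max(max_tgt_in_batch, len(new.trg) + 2)
    return max(count * max_src_in_batch, count * max_tgt_in_batch)

With a batch_size_fn like this, batch_size=opt.batchsize above is a budget of tokens per batch (including padding), not a number of sentences.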
Code example #2
	def create_dataset(self):

		print('Creating dataset and iterator')

		raw_data = {'src': [line for line in self.english_data], 'trg': [line for line in self.french_data]}
		df = pd.DataFrame(raw_data, columns=['src', 'trg'])

		mask = (df['src'].str.count(' ') < self.max_length) & (df['trg'].str.count(' ') < self.max_length)
		df = df.loc[mask]

		df.to_csv('data/translate_sentences.csv', index=False)

		data_fields = [('src', self.SRC), ('trg', self.TRG)]
		train = data.TabularDataset(path='data/translate_sentences.csv', format='csv', fields=data_fields)

		# The original snippet breaks off mid-call here. The remaining arguments
		# mirror the other examples; `self.batch_size` and `self.device` are
		# assumed (hypothetical) attributes of this class.
		train_iter = MyIterator(train, batch_size=self.batch_size, device=self.device,
								repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
								batch_size_fn=batch_size_fn, train=True, shuffle=True)

		return train_iter
Code example #3
def create_dataset(opt, SRC, TRG):
    print("creating dataset and iterator... ")
    raw_data = {
        'src': [line for line in opt.src_data],
        'trg': [line for line in opt.trg_data]
    }
    # print("raw_data",raw_data)
    # 此处开始制作 一个 csv 文件
    df = pd.DataFrame(raw_data, columns=["src", "trg"])
    mask = (df['src'].str.count(' ') <
            opt.max_strlen) & (df['trg'].str.count(' ') < opt.max_strlen)
    # print("mask",len(mask),mask)
    df = df.loc[mask]
    # print("df.loc[mask]",df.loc[mask])
    df.to_csv("translate_transformer_temp.csv", index=False)
    data_fields = [('src', SRC), ('trg', TRG)]
    train = data.TabularDataset('./translate_transformer_temp.csv',
                                format='csv',
                                fields=data_fields)
    train_iter = MyIterator(train,
                            batch_size=opt.batchsize,
                            device=opt.device,
                            repeat=False,
                            sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=batch_size_fn,
                            train=True,
                            shuffle=True)
    # os.remove('translate_transformer_temp.csv')  # delete the temporary CSV here
    if not opt.premodels or os.path.exists(opt.load_weights + "/" +
                                           opt.premodels_path):  # load weights
        SRC.build_vocab(train)  # build the vocabularies from the training data
        TRG.build_vocab(train)
        if opt.checkpoint > 0:
            if not os.path.exists(opt.load_weights):
                os.mkdir(opt.load_weights)
            else:
                print(
                    "weights folder already exists, run program with -load_weights weights to load them"
                )
            pickle.dump(SRC, open(config.weights + '/SRC.pkl', 'wb'))
            pickle.dump(TRG, open(config.weights + '/TRG.pkl', 'wb'))
    opt.src_pad = SRC.vocab.stoi['<pad>']
    opt.trg_pad = TRG.vocab.stoi['<pad>']
    return train_iter
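
For context, the SRC and TRG arguments passed into create_dataset are torchtext Field objects. A minimal, hypothetical setup for calling the function might look like the following; the whitespace tokenizer and the <sos>/<eos> special tokens are assumptions, not taken from the snippets above.

from torchtext import data

SRC = data.Field(lower=True, tokenize=str.split)
TRG = data.Field(lower=True, tokenize=str.split,
                 init_token='<sos>', eos_token='<eos>')

train_iter = create_dataset(opt, SRC, TRG)

for batch in train_iter:
    # Field defaults to sequence-first tensors, so transpose to
    # (batch, seq_len) before feeding a batch-first model.
    src = batch.src.transpose(0, 1)
    trg = batch.trg.transpose(0, 1)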
Code example #4
def create_dataset(opt, SRC, TRG):

    print("creating dataset and iterator...")

    raw_data = {
        'src': [line for line in opt.src_data],
        'trg': [line for line in opt.trg_data]
    }
    df = pd.DataFrame(raw_data, columns=["src", "trg"])

    mask = (df['src'].str.count(' ') <
            opt.max_strlen) & (df['trg'].str.count(' ') < opt.max_strlen)
    df = df.loc[mask]

    df.to_csv("translate_transformer_temp.csv", index=False)

    data_fields = [('src', SRC), ('trg', TRG)]
    train = data.TabularDataset('./translate_transformer_temp.csv',
                                format='csv',
                                fields=data_fields)

    train_iter = MyIterator(train,
                            batch_size=opt.batchsize,
                            device=opt.device,
                            repeat=False,
                            sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=batch_size_fn,
                            train=True,
                            shuffle=True)

    os.remove('translate_transformer_temp.csv')

    if opt.load_weights is None:
        SRC.build_vocab(train)
        TRG.build_vocab(train)

    opt.src_pad = SRC.vocab.stoi['<pad>']
    opt.trg_pad = TRG.vocab.stoi['<pad>']

    opt.train_len = get_len(train_iter)

    return train_iter
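
get_len, used to set opt.train_len in examples #1, #4 and #5, is also defined outside these snippets. Because MyIterator sizes batches dynamically by token count, the batch count is only known by iterating once. A minimal sketch, not necessarily the exact helper these repositories use:

def get_len(train):
    # Exhaust the iterator once and count the batches it yields.
    count = 0
    for _ in train:
        count += 1
    return count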
Code example #5
File: Process.py  Project: c-col/Transformer
def create_dataset(opt, SRC, TRG):
    print("creating dataset and iterator... ")

    if opt.task == 'toy_task' or opt.task == 'e_snli_o':
        # Load in validation data
        f_in = open(opt.data_path + '/val_in.txt', 'r', encoding='utf-8')
        f_out = open(opt.data_path + '/val_out.txt', 'r', encoding='utf-8')
        in_ = [x.replace('\n', '') for x in f_in.readlines()]
        out_ = [x.replace('\n', '') for x in f_out.readlines()]

        raw_data = {'src': in_, 'trg': out_}
        df = pd.DataFrame(raw_data, columns=["src", "trg"])

        mask = (df['src'].str.count(' ') <
                opt.max_strlen) & (df['trg'].str.count(' ') < opt.max_strlen)
        df = df.loc[mask]

        df.to_csv("translate_transformer_temp.csv", index=False)
        data_fields = [('src', SRC), ('trg', TRG)]
        val = data.TabularDataset('./translate_transformer_temp.csv',
                                  format='csv',
                                  fields=data_fields,
                                  skip_header=True)
        os.remove('translate_transformer_temp.csv')

        val_iter = MyIterator(val,
                              batch_size=opt.batchsize,
                              repeat=False,
                              sort_key=lambda x: (len(x.src), len(x.trg)),
                              train=False,
                              shuffle=False)
    elif opt.task == 'e_snli_r':
        # Load in validation data
        f_in = open(opt.data_path + '/val_in.txt', 'r', encoding='utf-8')
        f_out = open(opt.data_path + '/val_out.txt', 'r', encoding='utf-8')
        if opt.label_path is None:
            raise AssertionError(
                'Need to provide a path to label data for validation checks')

        f_label = open(opt.label_path + '/val_out.txt', 'r', encoding='utf-8')

        in_ = [x.replace('\n', '') for x in f_in.readlines()]
        out_ = [x.replace('\n', '') for x in f_out.readlines()]
        labels_ = [x.replace('\n', '') for x in f_label.readlines()]
        out1, out2, out3 = [], [], []
        for o in out_:
            split = o.split(' @@SEP@@ ')
            out1.append(split[0])
            out2.append(split[1])
            out3.append(split[2])

        raw_data = {
            'src': in_,
            'trg1': out1,
            'trg2': out2,
            'trg3': out3,
            'labels': labels_
        }
        df = pd.DataFrame(raw_data,
                          columns=["src", "trg1", "trg2", "trg3", "labels"])

        mask = (df['src'].str.count(' ') < opt.max_strlen) & (df['trg1'].str.count(' ') < opt.max_strlen) & \
               (df['trg2'].str.count(' ') < opt.max_strlen) & (df['trg3'].str.count(' ') < opt.max_strlen)
        df = df.loc[mask]

        df.to_csv("translate_transformer_temp.csv", index=False)
        data_fields = [('src', SRC), ('trg1', TRG), ('trg2', TRG),
                       ('trg3', TRG), ('label', opt.classifier_TRG)]
        val = data.TabularDataset('./translate_transformer_temp.csv',
                                  format='csv',
                                  fields=data_fields,
                                  skip_header=True)
        os.remove('translate_transformer_temp.csv')

        val_iter = MyIterator(
            val,
            batch_size=opt.batchsize,
            repeat=False,
            sort_key=lambda x:
            (len(x.src), len(x.trg1), len(x.trg2), len(x.trg3)),
            train=False,
            shuffle=False)

    else:
        # cos_e
        raise NotImplementedError(
            "No implementation provided in process.py for cos-e (yet)")

    ##### TRAIN DATA #####
    raw_data = {
        'src': [line for line in opt.src_data],
        'trg': [line for line in opt.trg_data]
    }
    df = pd.DataFrame(raw_data, columns=["src", "trg"])

    mask = (df['src'].str.count(' ') <
            opt.max_strlen) & (df['trg'].str.count(' ') < opt.max_strlen)
    df = df.loc[mask]

    df.to_csv("translate_transformer_temp.csv", index=False)
    data_fields = [('src', SRC), ('trg', TRG)]
    train = data.TabularDataset('./translate_transformer_temp.csv',
                                format='csv',
                                fields=data_fields,
                                skip_header=True)
    print('desired batch size', opt.batchsize)

    train_iter = MyIterator(
        train,
        batch_size=opt.batchsize,  # device=opt.device,
        repeat=False,
        sort_key=lambda x: (len(x.src), len(x.trg)),
        train=True,
        shuffle=True)
    os.remove('translate_transformer_temp.csv')

    if opt.load_weights is None:
        if opt.checkpoint > 0:
            try:
                os.mkdir("weights")
            except FileExistsError:
                print(
                    "weights folder already exists, run program with -load_weights weights to load them"
                )
                quit()
            pickle.dump(SRC, open('weights/SRC.pkl', 'wb'))
            pickle.dump(TRG, open('weights/TRG.pkl', 'wb'))

    opt.src_pad = SRC.vocab.stoi['<pad>']
    opt.trg_pad = TRG.vocab.stoi['<pad>']

    opt.train_len = get_len(train_iter)
    print('number of train batches:', opt.train_len)
    print('number of val batches:', get_len(val_iter))
    return train_iter, val_iter
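
Note that all five examples target the old torchtext Field/TabularDataset/Iterator API. In torchtext 0.9 through 0.11 that API was moved under torchtext.legacy, and it was removed entirely in 0.12, so on recent installs the import must be adjusted (or torchtext pinned to an older release):

# torchtext <= 0.8
from torchtext import data

# torchtext 0.9 - 0.11
from torchtext.legacy import data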