Example #1
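Builds a TorchLoaderIter over a small dataset whose samples are (dict, dict) pairs, using a custom collate_fn to assemble batches of size 2.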
    def testTensorLoaderIter(self):
        class FakeData:
            def __init__(self, return_dict=True):
                self.x = [[1, 2, 3], [4, 5, 6]]
                self.return_dict = return_dict

            def __len__(self):
                return len(self.x)

            def __getitem__(self, i):
                x = self.x[i]
                y = 0
                if self.return_dict:
                    return {'x': x}, {'y': y}
                return x, y

        data1 = FakeData()

        def collate_fn(ins_list):
            xs = [ins[0]['x'] for ins in ins_list]
            ys = [ins[1]['y'] for ins in ins_list]
            return {'x': xs}, {'y': ys}

        dataiter = TorchLoaderIter(data1, collate_fn=collate_fn, batch_size=2)
        for x, y in dataiter:
            print(x, y)
Example #2
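Iterates a TorchLoaderIter built without an explicit collate_fn on a dataset that returns (dict, dict) samples, then checks that constructing one on a dataset that returns plain tuples raises an exception.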
    def testTensorLoaderIter(self):
        class FakeData:
            def __init__(self, return_dict=True):
                self.x = [[1, 2, 3], [4, 5, 6]]
                self.return_dict = return_dict

            def __len__(self):
                return len(self.x)

            def __getitem__(self, i):
                x = self.x[i]
                y = 0
                if self.return_dict:
                    return {'x': x}, {'y': y}
                return x, y

        data1 = FakeData()
        dataiter = TorchLoaderIter(data1, batch_size=2)
        for x, y in dataiter:
            print(x, y)

        def func():
            data2 = FakeData(return_dict=False)
            dataiter = TorchLoaderIter(data2, batch_size=2)

        self.assertRaises(Exception, func)
Example #3
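Trains a small model end to end: a user-defined dataset of random samples is wrapped in a TorchLoaderIter whose collate_fn builds tensor batches, which is then fed to a Trainer.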
    def test_udf_dataiter(self):
        import random
        import torch

        class UdfDataSet:
            def __init__(self, num_samples):
                self.num_samples = num_samples

            def __getitem__(self, idx):
                x = [random.random() for _ in range(3)]
                y = random.random()
                return x, y

            def __len__(self):
                return self.num_samples

        def collate_fn(data_list):
            # [(x1,y1), (x2,y2), ...]; the input here is the outputs of UdfDataSet.__getitem__ collected into a list
            xs, ys = [], []
            for l in data_list:
                x, y = l
                xs.append(x)
                ys.append(y)
            x, y = torch.FloatTensor(xs), torch.FloatTensor(ys)
            return {'x': x, 'y': y}, {'y': y}

        dataset = UdfDataSet(10)
        dataset = TorchLoaderIter(dataset, collate_fn=collate_fn)

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc = nn.Linear(3, 1)

            def forward(self, x, y):
                return {'loss': torch.pow(self.fc(x).squeeze(-1) - y, 2).sum()}

            def predict(self, x):
                return {'pred': self.fc(x).squeeze(0)}

        model = Model()
        trainer = Trainer(train_data=dataset,
                          model=model,
                          loss=None,
                          print_every=2,
                          dev_data=dataset,
                          metrics=AccuracyMetric(target='y'),
                          use_tqdm=False)
        trainer.train(load_best_model=False)
Example #4
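Verifies that a custom batch_sampler, which yields two uneven batches, is respected by both DataSetIter and TorchLoaderIter.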
    def test_batch_sampler(self):
        # test whether the batch_sampler of DataSetIter and TorchLoaderIter works correctly
        # DataSetIter
        ds = generate_fake_dataset(5)
        ds.set_input('1')

        class BatchSampler:
            def __init__(self, dataset):
                self.num_samples = len(dataset)

            def __iter__(self):
                index = 0
                indexes = list(range(self.num_samples))
                np.random.shuffle(indexes)
                start_idx = 0
                while index < self.num_samples:
                    if start_idx == 0:
                        end_index = self.num_samples // 2
                    else:
                        end_index = self.num_samples
                    yield indexes[start_idx:end_index]
                    index = end_index
                    start_idx = end_index

            def __len__(self):
                return 2

        batch_sampler = BatchSampler(ds)

        data_iter = DataSetIter(ds,
                                batch_size=10,
                                sampler=batch_sampler,
                                as_numpy=False,
                                num_workers=0,
                                pin_memory=False,
                                drop_last=False,
                                timeout=0,
                                worker_init_fn=None,
                                batch_sampler=batch_sampler)
        num_samples = [len(ds) // 2, len(ds) - len(ds) // 2]
        for idx, (batch_x, batch_y) in enumerate(data_iter):
            self.assertEqual(num_samples[idx], len(batch_x['1']))

        # TorchLoaderIter
        class FakeData:
            def __init__(self):
                self.x = [[1, 2, 3], [4, 5, 6], [1, 2]]

            def __len__(self):
                return len(self.x)

            def __getitem__(self, i):
                x = self.x[i]
                y = 0
                return x, y

        def collate_fn(ins_list):
            xs = [ins[0] for ins in ins_list]
            ys = [ins[1] for ins in ins_list]
            return {'x': xs}, {'y': ys}

        ds = FakeData()
        batch_sampler = BatchSampler(ds)
        data_iter = TorchLoaderIter(ds,
                                    batch_size=10,
                                    sampler=batch_sampler,
                                    num_workers=0,
                                    pin_memory=False,
                                    drop_last=False,
                                    timeout=0,
                                    worker_init_fn=None,
                                    collate_fn=collate_fn,
                                    batch_sampler=batch_sampler)
        num_samples = [len(ds) // 2, len(ds) - len(ds) // 2]
        for idx, (batch_x, batch_y) in enumerate(data_iter):
            self.assertEqual(num_samples[idx], len(batch_x['x']))
Example #5
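A distributed pre-training script for the CoLAKE model: TorchLoaderIter wraps the graph datasets used for training, validation, and optional testing under DistTrainer.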
def train():
    args = parse_args()
    if args.debug:
        fitlog.debug()
        args.save_model = False
    # ================= define =================
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    word_mask_index = tokenizer.mask_token_id
    word_vocab_size = len(tokenizer)

    if get_local_rank() == 0:
        fitlog.set_log_dir(args.log_dir)
        fitlog.commit(__file__, fit_msg=args.name)
        fitlog.add_hyper_in_file(__file__)
        fitlog.add_hyper(args)

    # ================= load data =================
    dist.init_process_group('nccl')
    init_logger_dist()

    n_proc = dist.get_world_size()
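    # per-process batch size: global batch size divided by gradient-accumulation steps and world size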
    bsz = args.batch_size // args.grad_accumulation // n_proc
    args.local_rank = get_local_rank()
    args.save_dir = os.path.join(args.save_dir,
                                 args.name) if args.save_model else None
    if args.save_dir is not None and os.path.exists(args.save_dir):
        raise RuntimeError('save_dir already exists.')
    logger.info('save directory: {}'.format(
        'None' if args.save_dir is None else args.save_dir))
    devices = list(range(torch.cuda.device_count()))
    NUM_WORKERS = 4

    ent_vocab, rel_vocab = load_ent_rel_vocabs()
    logger.info('# entities: {}'.format(len(ent_vocab)))
    logger.info('# relations: {}'.format(len(rel_vocab)))
    ent_freq = get_ent_freq()
    assert len(ent_vocab) == len(ent_freq), '{} {}'.format(
        len(ent_vocab), len(ent_freq))

    #####
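    # collect the highest-numbered file in each data sub-directory; these files are excluded from file_list below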
    root = args.data_dir
    dirs = os.listdir(root)
    drop_files = []
    for dir in dirs:
        path = os.path.join(root, dir)
        max_idx = 0
        for file_name in os.listdir(path):
            if 'large' in file_name:
                continue
            max_idx = int(file_name) if int(file_name) > max_idx else max_idx
        drop_files.append(os.path.join(path, str(max_idx)))
    #####

    file_list = []
    for path, _, filenames in os.walk(args.data_dir):
        for filename in filenames:
            file = os.path.join(path, filename)
            if 'large' in file or file in drop_files:
                continue
            file_list.append(file)
    logger.info('used {} files in {}.'.format(len(file_list), args.data_dir))
    if args.data_prop > 1:
        used_files = file_list[:int(args.data_prop)]
    else:
        used_files = file_list[:round(args.data_prop * len(file_list))]

    data = GraphOTFDataSet(used_files, n_proc, args.local_rank,
                           word_mask_index, word_vocab_size, args.n_negs,
                           ent_vocab, rel_vocab, ent_freq)
    dev_data = GraphDataSet(used_files[0], word_mask_index, word_vocab_size,
                            args.n_negs, ent_vocab, rel_vocab, ent_freq)

    sampler = OTFDistributedSampler(used_files, n_proc, get_local_rank())
    train_data_iter = TorchLoaderIter(dataset=data,
                                      batch_size=bsz,
                                      sampler=sampler,
                                      num_workers=NUM_WORKERS,
                                      collate_fn=data.collate_fn)
    dev_data_iter = TorchLoaderIter(dataset=dev_data,
                                    batch_size=bsz,
                                    sampler=RandomSampler(),
                                    num_workers=NUM_WORKERS,
                                    collate_fn=dev_data.collate_fn)
    if args.test_data is not None:
        test_data = FewRelDevDataSet(path=args.test_data,
                                     label_vocab=rel_vocab,
                                     ent_vocab=ent_vocab)
        test_data_iter = TorchLoaderIter(dataset=test_data,
                                         batch_size=32,
                                         sampler=RandomSampler(),
                                         num_workers=NUM_WORKERS,
                                         collate_fn=test_data.collate_fn)

    if args.local_rank == 0:
        print('full wiki files: {}'.format(len(file_list)))
        print('used wiki files: {}'.format(len(used_files)))
        print('# of trained samples: {}'.format(len(data) * n_proc))
        print('# of trained entities: {}'.format(len(ent_vocab)))
        print('# of trained relations: {}'.format(len(rel_vocab)))

    # ================= prepare model =================
    logger.info('model init')
    if args.rel_emb is not None:  # load pretrained relation embeddings
        rel_emb = np.load(args.rel_emb)
        # add_embs = np.random.randn(3, rel_emb.shape[1])  # add <pad>, <mask>, <unk>
        # rel_emb = np.r_[add_embs, rel_emb]
        rel_emb = torch.from_numpy(rel_emb).float()
        assert rel_emb.shape[0] == len(rel_vocab), '{} {}'.format(
            rel_emb.shape[0], len(rel_vocab))
        # assert rel_emb.shape[1] == args.rel_dim
        logger.info('loaded pretrained relation embeddings. dim: {}'.format(
            rel_emb.shape[1]))
    else:
        rel_emb = None
    if args.model_name is not None:
        logger.info('further pre-train.')
        config = RobertaConfig.from_pretrained('roberta-base',
                                               type_vocab_size=3)
        model = CoLAKE(config=config,
                       num_ent=len(ent_vocab),
                       num_rel=len(rel_vocab),
                       ent_dim=args.ent_dim,
                       rel_dim=args.rel_dim,
                       ent_lr=args.ent_lr,
                       ip_config=args.ip_config,
                       rel_emb=None,
                       emb_name=args.emb_name)
        states_dict = torch.load(args.model_name)
        model.load_state_dict(states_dict, strict=True)
    else:
        model = CoLAKE.from_pretrained(
            'roberta-base',
            num_ent=len(ent_vocab),
            num_rel=len(rel_vocab),
            ent_lr=args.ent_lr,
            ip_config=args.ip_config,
            rel_emb=rel_emb,
            emb_name=args.emb_name,
            cache_dir=PYTORCH_PRETRAINED_BERT_CACHE /
            'dist_{}'.format(args.local_rank))
        model.extend_type_embedding(token_type=3)
    # if args.local_rank == 0:
    #     for name, param in model.named_parameters():
    #         if param.requires_grad is True:
    #             print('{}: {}'.format(name, param.shape))

    # ================= train model =================
    # lr=1e-4 for peak value, lr=5e-5 for initial value
    logger.info('trainer init')
    no_decay = [
        'bias', 'LayerNorm.bias', 'LayerNorm.weight', 'layer_norm.bias',
        'layer_norm.weight'
    ]
    param_optimizer = list(model.named_parameters())
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    word_acc = WordMLMAccuracy(pred='word_pred',
                               target='masked_lm_labels',
                               seq_len='word_seq_len')
    ent_acc = EntityMLMAccuracy(pred='entity_pred',
                                target='ent_masked_lm_labels',
                                seq_len='ent_seq_len')
    rel_acc = RelationMLMAccuracy(pred='relation_pred',
                                  target='rel_masked_lm_labels',
                                  seq_len='rel_seq_len')
    metrics = [word_acc, ent_acc, rel_acc]

    if args.test_data is not None:
        test_metric = [rel_acc]
        tester = Tester(data=test_data_iter,
                        model=model,
                        metrics=test_metric,
                        device=list(range(torch.cuda.device_count())))
        # tester.test()
    else:
        tester = None

    optimizer = optim.AdamW(optimizer_grouped_parameters,
                            lr=args.lr,
                            betas=(0.9, args.beta),
                            eps=1e-6)
    # warmup_callback = WarmupCallback(warmup=args.warm_up, schedule='linear')
    fitlog_callback = MyFitlogCallback(tester=tester,
                                       log_loss_every=100,
                                       verbose=1)
    gradient_clip_callback = GradientClipCallback(clip_value=1,
                                                  clip_type='norm')
    emb_callback = EmbUpdateCallback(model.ent_embeddings)
    all_callbacks = [gradient_clip_callback, emb_callback]
    if args.save_dir is None:
        master_callbacks = [fitlog_callback]
    else:
        save_callback = SaveModelCallback(args.save_dir,
                                          model.ent_embeddings,
                                          only_params=True)
        master_callbacks = [fitlog_callback, save_callback]

    if args.do_test:
        states_dict = torch.load(os.path.join(args.save_dir,
                                              args.model_name)).state_dict()
        model.load_state_dict(states_dict)
        data_iter = TorchLoaderIter(dataset=data,
                                    batch_size=args.batch_size,
                                    sampler=RandomSampler(),
                                    num_workers=NUM_WORKERS,
                                    collate_fn=data.collate_fn)
        tester = Tester(data=data_iter,
                        model=model,
                        metrics=metrics,
                        device=devices)
        tester.test()
    else:
        trainer = DistTrainer(train_data=train_data_iter,
                              dev_data=dev_data_iter,
                              model=model,
                              optimizer=optimizer,
                              loss=LossInForward(),
                              batch_size_per_gpu=bsz,
                              update_every=args.grad_accumulation,
                              n_epochs=args.epoch,
                              metrics=metrics,
                              callbacks_master=master_callbacks,
                              callbacks_all=all_callbacks,
                              validate_every=5000,
                              use_tqdm=True,
                              fp16='O1' if args.fp16 else '')
        trainer.train(load_best_model=False)
Example #6
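Streams samples on the fly from a temporary file: the dataset records the byte offset of each line and seeks to it in __getitem__, and the resulting TorchLoaderIter drives a short Trainer run.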
    def test_onthefly_iter(self):
        import tempfile
        import random
        import torch
        tmp_file_handler, tmp_file_path = tempfile.mkstemp(text=True)
        try:
            num_samples = 10
            data = []
            for _ in range(num_samples):
                x, y = [random.random() for _ in range(3)], random.random()
                data.append(x + [y])
            with open(tmp_file_path, 'w') as f:
                for d in data:
                    f.write(' '.join(map(str, d)) + '\n')

            class FileDataSet:
                def __init__(self, tmp_file):
                    num_samples = 0
                    line_pos = [0]  # line_pos[idx] is the byte offset at which line idx starts
                    self.tmp_file_handler = open(tmp_file,
                                                 'r',
                                                 encoding='utf-8')
                    line = self.tmp_file_handler.readline()
                    while line:
                        if line.strip():
                            num_samples += 1
                            line_pos.append(self.tmp_file_handler.tell())
                        line = self.tmp_file_handler.readline()
                    self.tmp_file_handler.seek(0)
                    self.num_samples = num_samples
                    self.line_pos = line_pos

                def __getitem__(self, idx):
                    line_start, line_end = self.line_pos[idx], self.line_pos[
                        idx + 1]
                    self.tmp_file_handler.seek(line_start)
                    line = self.tmp_file_handler.read(line_end -
                                                      line_start).strip()
                    values = list(map(float, line.split()))
                    gold_d = data[idx]
                    assert all([g == v for g, v in zip(gold_d, values)
                                ]), "Should have the same data"
                    x, y = values[:3], values[-1]
                    return x, y

                def __len__(self):
                    return self.num_samples

            def collate_fn(data_list):
                # [(x1,y1), (x2,y2), ...]; the input here is the outputs of FileDataSet.__getitem__ collected into a list
                xs, ys = [], []
                for l in data_list:
                    x, y = l
                    xs.append(x)
                    ys.append(y)
                x, y = torch.FloatTensor(xs), torch.FloatTensor(ys)
                return {'x': x, 'y': y}, {'y': y}

            dataset = FileDataSet(tmp_file_path)
            dataset = TorchLoaderIter(dataset, collate_fn=collate_fn)

            class Model(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.fc = nn.Linear(3, 1)

                def forward(self, x, y):
                    return {
                        'loss': torch.pow(self.fc(x).squeeze(-1) - y, 2).sum()
                    }

                def predict(self, x):
                    return {'pred': self.fc(x).squeeze(-1)}

            model = Model()
            trainer = Trainer(train_data=dataset,
                              model=model,
                              loss=None,
                              print_every=2,
                              dev_data=dataset,
                              metrics=AccuracyMetric(target='y'),
                              use_tqdm=False,
                              n_epochs=2)
            trainer.train(load_best_model=False)

        finally:
            import os
            if os.path.exists(tmp_file_path):
                os.remove(tmp_file_path)
Example #7
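Drives a Trainer with a TorchLoaderIter that uses a custom batch_sampler instead of a fixed batch_size.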
    def test_batch_sampler_dataiter(self):
        import random
        import torch

        class BatchSampler:
            def __init__(self, dataset):
                self.num_samples = len(dataset)

            def __iter__(self):
                index = 0
                indexes = list(range(self.num_samples))
                np.random.shuffle(indexes)
                start_idx = 0
                while index < self.num_samples:
                    if start_idx == 0:
                        end_index = self.num_samples // 2
                    else:
                        end_index = self.num_samples
                    yield indexes[start_idx:end_index]
                    index = end_index
                    start_idx = end_index

            def __len__(self):
                return 2

        class UdfDataSet:
            def __init__(self, num_samples):
                self.num_samples = num_samples

            def __getitem__(self, idx):
                x = [random.random() for _ in range(3)]
                y = random.random()
                return x, y

            def __len__(self):
                return self.num_samples

        def collate_fn(data_list):
            # [(x1,y1), (x2,y2), ...]; the input here is the outputs of UdfDataSet.__getitem__ collected into a list
            xs, ys = [], []
            for l in data_list:
                x, y = l
                xs.append(x)
                ys.append(y)
            x, y = torch.FloatTensor(xs), torch.FloatTensor(ys)
            return {'x': x, 'y': y}, {'y': y}

        dataset = UdfDataSet(11)
        batch_sampler = BatchSampler(dataset)
        dataset = TorchLoaderIter(dataset,
                                  collate_fn=collate_fn,
                                  batch_sampler=batch_sampler)

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc = nn.Linear(3, 1)

            def forward(self, x, y):
                return {'loss': torch.pow(self.fc(x).squeeze(-1) - y, 2).sum()}

            def predict(self, x):
                return {'pred': self.fc(x).squeeze(-1)}

        model = Model()
        trainer = Trainer(train_data=dataset,
                          model=model,
                          loss=None,
                          print_every=2,
                          dev_data=dataset,
                          metrics=AccuracyMetric(target='y'),
                          use_tqdm=False)
        trainer.train(load_best_model=False)
Example #8
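A stand-alone fragment of the helper used in Example #2: constructing a TorchLoaderIter on a dataset that returns plain tuples, without a collate_fn, is expected to raise.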
def func():
    data2 = FakeData(return_dict=False)
    dataiter = TorchLoaderIter(data2, batch_size=2)
Example #9
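A fine-tuning script for relation extraction with CoLAKEForRE: TorchLoaderIter provides the train, dev, and test iterators for Trainer and Tester.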
def main():
    args = parse_args()

    if args.debug:
        fitlog.debug()

    fitlog.set_log_dir(args.log_dir)
    fitlog.commit(__file__)
    fitlog.add_hyper_in_file(__file__)
    fitlog.add_hyper(args)
    if args.gpu != 'all':
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    train_set, dev_set, test_set, temp_ent_vocab = load_fewrel_graph_data(
        data_dir=args.data_dir)

    print('data directory: {}'.format(args.data_dir))
    print('# of train samples: {}'.format(len(train_set)))
    print('# of dev samples: {}'.format(len(dev_set)))
    print('# of test samples: {}'.format(len(test_set)))

    ent_vocab, rel_vocab = load_ent_rel_vocabs(path='../')

    # load entity embeddings
    ent_index = []
    for k, v in temp_ent_vocab.items():
        ent_index.append(ent_vocab[k])
    ent_index = torch.tensor(ent_index)
    ent_emb = np.load(os.path.join(args.model_path, 'entities.npy'))
    ent_embedding = nn.Embedding.from_pretrained(torch.from_numpy(ent_emb))
    ent_emb = ent_embedding(ent_index.view(1, -1)).squeeze().detach()

    # load CoLAKE parameters
    config = RobertaConfig.from_pretrained('roberta-base', type_vocab_size=3)
    model = CoLAKEForRE(config,
                        num_types=len(train_set.label_vocab),
                        ent_emb=ent_emb)
    states_dict = torch.load(os.path.join(args.model_path, 'model.bin'))
    model.load_state_dict(states_dict, strict=False)
    print('parameters below are randomly initialized:')
    for name, param in model.named_parameters():
        if name not in states_dict:
            print(name)

    # tie relation classification head
    rel_index = []
    for k, v in train_set.label_vocab.items():
        rel_index.append(rel_vocab[k])
    rel_index = torch.LongTensor(rel_index)
    rel_embeddings = nn.Embedding.from_pretrained(
        states_dict['rel_embeddings.weight'])
    rel_index = rel_index.cuda()
    rel_cls_weight = rel_embeddings(rel_index.view(1, -1)).squeeze()
    model.tie_rel_weights(rel_cls_weight)

    model.rel_head.dense.weight.data = states_dict['rel_lm_head.dense.weight']
    model.rel_head.dense.bias.data = states_dict['rel_lm_head.dense.bias']
    model.rel_head.layer_norm.weight.data = states_dict[
        'rel_lm_head.layer_norm.weight']
    model.rel_head.layer_norm.bias.data = states_dict[
        'rel_lm_head.layer_norm.bias']

    model.resize_token_embeddings(
        len(RobertaTokenizer.from_pretrained('roberta-base')) + 4)
    print('parameters of CoLAKE have been loaded.')

    # fine-tune
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'embedding']
    param_optimizer = list(model.named_parameters())
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        args.weight_decay
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    optimizer = optim.AdamW(optimizer_grouped_parameters,
                            lr=args.lr,
                            betas=(0.9, args.beta),
                            eps=1e-6)

    metrics = [MacroMetric(pred='pred', target='target')]

    test_data_iter = TorchLoaderIter(dataset=test_set,
                                     batch_size=args.batch_size,
                                     sampler=RandomSampler(),
                                     num_workers=4,
                                     collate_fn=test_set.collate_fn)
    devices = list(range(torch.cuda.device_count()))
    tester = Tester(data=test_data_iter,
                    model=model,
                    metrics=metrics,
                    device=devices)
    # tester.test()

    fitlog_callback = FitlogCallback(tester=tester,
                                     log_loss_every=100,
                                     verbose=1)
    gradient_clip_callback = GradientClipCallback(clip_value=1,
                                                  clip_type='norm')
    warmup_callback = WarmupCallback(warmup=args.warm_up, schedule='linear')

    bsz = args.batch_size // args.grad_accumulation

    train_data_iter = TorchLoaderIter(dataset=train_set,
                                      batch_size=bsz,
                                      sampler=RandomSampler(),
                                      num_workers=4,
                                      collate_fn=train_set.collate_fn)
    dev_data_iter = TorchLoaderIter(dataset=dev_set,
                                    batch_size=bsz,
                                    sampler=RandomSampler(),
                                    num_workers=4,
                                    collate_fn=dev_set.collate_fn)

    trainer = Trainer(
        train_data=train_data_iter,
        dev_data=dev_data_iter,
        model=model,
        optimizer=optimizer,
        loss=LossInForward(),
        batch_size=bsz,
        update_every=args.grad_accumulation,
        n_epochs=args.epoch,
        metrics=metrics,
        callbacks=[fitlog_callback, gradient_clip_callback, warmup_callback],
        device=devices,
        use_tqdm=True)

    trainer.train(load_best_model=False)