Example #1
def load_masks(masks_dir):
    masks_path = [
        os.path.join(masks_dir, f) for f in os.listdir(masks_dir)
        if not f.startswith("init")
    ]
    masks_path = list(
        sorted(
            filter(lambda f: os.path.isfile(f), masks_path),
            key=lambda s: int(os.path.basename(s).split("_")[0]),
        ))
    masks = []
    logger.info("loading masks")
    for path in masks_path:
        dump = torch.load(path, "cpu")
        assert "mask" in dump and "pruning_time" in dump
        logger.info("loading pruning_time {}, mask in {}".format(
            dump["pruning_time"], path))
        masks.append(dump["mask"])

    # sanity check
    assert len(masks) == len(masks_path)
    for mi in masks:
        for name, m in mi.items():
            assert isinstance(m, torch.Tensor)
            mi[name] = m.bool()
    return masks
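For reference, a minimal sketch of the writer side that produces dumps load_masks can read; the save_mask helper and the file-name pattern are assumptions, inferred from the numeric sorting key and the asserted dict keys:

import os
import torch

def save_mask(masks_dir, pruning_time, mask):
    # hypothetical writer: the file name must start with "<pruning_time>_"
    # so that load_masks can sort the dumps numerically
    path = os.path.join(masks_dir, "{}_mask.pkl".format(pruning_time))
    torch.save({"mask": mask, "pruning_time": pruning_time}, path)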
Example #2
    def test_stdout(self, mock_out):
        for i in range(3):
            logger.info(self.msg)
            logger.debug('aabbc')

        self.assertEqual([self.msg for i in range(3)],
                         mock_out.getvalue().strip().split('\n'))
Example #3
def print_info(*inp, islog=False, sep=' '):
    from fastNLP import logger
    if islog:
        print(*inp, sep=sep)
    else:
        inp = sep.join(map(str, inp))
        logger.info(inp)
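Usage mirrors the built-in print; with the default islog=False the values are joined with sep and routed through logger.info:

print_info('epoch', 3, 'loss', 0.25)              # -> logger.info('epoch 3 loss 0.25')
print_info('epoch', 3, 'loss', 0.25, islog=True)  # -> print('epoch', 3, 'loss', 0.25)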
Example #4
 def test_add_file(self):
     fn = os.path.join(self.tmpdir, 'log.txt')
     logger.add_file(fn)
     logger.info(self.msg)
     with open(fn, 'r') as f:
         line = f.read()
         print(line)
     self.assertTrue(self.msg in line)
Example #5
def iterative_train_and_prune_single_task(get_trainer, args, model, train_set,
                                          dev_set, test_set, device,
                                          save_path=None):

    '''
    Iteratively train and prune the model on a single task.

    :param get_trainer: factory that builds a fresh Trainer for every pruning iteration
    :param args: pruning config; args.prune is the final remaining rate, args.iter the number of pruning iterations
    :param model: model to prune; must expose a per-task mask dict and a now_task attribute
    :param train_set, dev_set, test_set: datasets of the task
    :param device: device to train on
    :param save_path: should be a directory, which will be filled with the per-iteration mask dumps
    :return:
    '''



    from fastNLP import Trainer
    import torch
    import math
    import copy
    PRUNE = args.prune
    ITER = args.iter
    trainer = get_trainer(args,model,train_set,dev_set,test_set,device)
    optimizer_init_state_dict = copy.deepcopy(trainer.optimizer.state_dict())
    model_init_state_dict = copy.deepcopy(trainer.model.state_dict())
    if save_path is not None:
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        # if not os.path.exists(os.path.join(save_path, 'model_init.pkl')):
        #     f = open(os.path.join(save_path, 'model_init.pkl'), 'wb')
        #     torch.save(trainer.model.state_dict(),f)


    mask_count = 0
    model = trainer.model
    task = trainer.model.now_task
    for name, p in model.mask[task].items():
        mask_count += torch.sum(p).item()
    init_mask_count = mask_count
    logger.info('init mask count:{}'.format(mask_count))
    # logger.info('{}th traning mask count: {} / {} = {}%'.format(i, mask_count, init_mask_count,
    #                                                             mask_count / init_mask_count * 100))

    prune_per_iter = math.pow(PRUNE, 1 / ITER)


    for i in range(ITER):
        trainer = get_trainer(args,model,train_set,dev_set,test_set,device)
        one_time_train_and_prune_single_task(trainer,prune_per_iter,optimizer_init_state_dict,model_init_state_dict)
        if save_path is not None:
            with open(os.path.join(save_path, task + '_mask_' + str(i) + '.pkl'), 'wb') as f:
                torch.save(model.mask[task], f)

        mask_count = 0
        for name, p in model.mask[task].items():
            mask_count += torch.sum(p).item()
        logger.info('{}th training mask count: {} / {} = {}%'.format(
            i, mask_count, init_mask_count, mask_count / init_mask_count * 100))
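one_time_train_and_prune_single_task is not shown here. A rough sketch, assuming it trains once, keeps only the top prune_per_iter fraction (by weight magnitude) of each still-unpruned entry in the task's mask, and then rewinds model and optimizer to their initial states, lottery-ticket style; this is an illustration, not the original implementation:

def one_time_train_and_prune_single_task(trainer, keep_rate,
                                         optimizer_init_state_dict,
                                         model_init_state_dict):
    # hypothetical sketch, not the original implementation
    trainer.train()
    model = trainer.model
    task = model.now_task
    params = dict(model.named_parameters())
    for name, mask in model.mask[task].items():
        weight = params[name].detach().abs()
        remaining = weight[mask]                       # still-unpruned entries
        k = max(int(remaining.numel() * keep_rate), 1)
        threshold = remaining.topk(k).values.min()     # keep the k largest
        model.mask[task][name] = mask & (weight >= threshold)
    # rewind to the initial weights and optimizer state
    model.load_state_dict(model_init_state_dict)
    trainer.optimizer.load_state_dict(optimizer_init_state_dict)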
Example #6
 def __init__(self, model, masks):
     self.model = model
     self.masks = masks
     self.weights = []
     if self.masks is None:
         mask = {}
         for name, param in self.model.named_parameters():
             m = torch.zeros_like(param.data).bool()
             mask[name] = m
         self.masks = mask
     logger.info("has masks %d, %s", len(self.masks), type(self.masks))
Example #7
def print_info(*inp, islog=True, sep=' '):
    """
    打印日志或者写到日志文件
    :param inp:
    :param islog:
    :param sep:
    :return:
    """
    if islog:
        print(*inp, sep=sep)
    else:
        inp = sep.join(map(str, inp))
        logger.info(inp)
Example #8
 def to(self, device):
     # logger.info(type(self.model), type(self.masks), device)
     logger.info("model to %s", device)
     self.model.to(device)
     if self.masks is None:
         return
     if isinstance(self.masks, dict):
         masks = [self.masks]
     else:
         masks = self.masks
     for i, mask in enumerate(masks):
         logger.info("mask {} to {}".format(i, device))
         for name, m in mask.items():
             mask[name] = m.to(device)
Example #9
    def load(self, path):
        state = torch.load(path)
        # self.backup_weights = state['init_weights']
        self.remain_mask = state["mask"]
        self.prune_times = state["pruning_time"]

        # sanity check
        for name, _ in self._model.named_parameters():
            assert name in self.backup_weights
        for name, m in self.remain_mask.items():
            assert name in self.backup_weights
        self._model.to("cuda")
        logger.info("load mask from %s", path)
        logger.info("current pruning time %d", self.prune_times)
Example #10
def data_summary(task_lst, vocabs=None):
    logger.info("******** DATA SUMMARY ********")
    logger.info("Contain {} tasks".format(len(task_lst)))
    for task in task_lst:
        logger.info(
            "Task {}: {},\tnum of samples: train {}, dev {}, test {}".format(
                task.task_id,
                task.task_name,
                len(task.train_set),
                len(task.dev_set),
                len(task.test_set),
            ))
    if vocabs is None:
        return
    logger.info("Contain {} vocabs".format(len(vocabs)))
    for name, v in vocabs.items():
        logger.info("Vocab {}: has length {},\t{}".format(name, len(v), v))
Example #11
def load_model(model, path):
    dumps = torch.load(path, map_location="cpu")

    if model is None:
        assert isinstance(dumps,
                          nn.Module), "model is None but load %s" % type(dumps)
        model = dumps
    else:
        if isinstance(dumps, nn.Module):
            dumps = dumps.state_dict()
        else:
            assert isinstance(dumps, dict), type(dumps)
        res = model.load_state_dict(dumps, strict=False)
        assert len(res.unexpected_keys) == 0, res.unexpected_keys
        logger.info("missing keys in init-weights %s", res.missing_keys)
    logger.info("load init-weights from %s", path)
    return model
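A typical call site, assuming the init weights were saved via torch.save(model.state_dict(), path):

model = load_model(model, os.path.join(args.save_path, 'init_weights.th'))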
Example #12
def check_words_same(dataset_1, dataset_2, field_1, field_2):
    if len(dataset_1[field_1]) != len(dataset_2[field_2]):
        logger.info('CHECK: example num not same!')
        return False

    for i, words in enumerate(dataset_1[field_1]):
        if len(dataset_1[field_1][i]) != len(dataset_2[field_2][i]):
            logger.info('CHECK {} th example length not same'.format(i))
            logger.info('1:{}'.format(dataset_1[field_1][i]))
            logger.info('2:{}'.format(dataset_2[field_2][i]))
            return False

        # for j,w in enumerate(words):
        #     if dataset_1[field_1][i][j] != dataset_2[field_2][i][j]:
        #         print('CHECK', i, 'th example has words different!')
        #         print('1:',dataset_1[field_1][i])
        #         print('2:',dataset_2[field_2][i])
        #         return False

    logger.info('CHECK: totally same!')

    return True
Example #13
def get_appropriate_cuda(task_scale='s'):
    if task_scale not in {'s', 'm', 'l'}:
        logger.info('task scale wrong!')
        exit(2)
    import pynvml
    pynvml.nvmlInit()
    total_cuda_num = pynvml.nvmlDeviceGetCount()
    for i in range(total_cuda_num):
        logger.info(i)
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # the argument is the GPU id
        memInfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
        utilizationInfo = pynvml.nvmlDeviceGetUtilizationRates(handle)
        logger.info('%d mem: %f util: %f', i, memInfo.used / memInfo.total,
                    utilizationInfo.gpu)
        # note: pynvml reports utilization.gpu as a percentage in [0, 100]
        if memInfo.used / memInfo.total < 0.15 and utilizationInfo.gpu < 20:
            logger.info('%d %f', i, memInfo.used / memInfo.total)
            return 'cuda:' + str(i)

    if task_scale == 's':
        max_memory = 2000
    elif task_scale == 'm':
        max_memory = 6000
    else:
        max_memory = 9000

    max_id = -1
    for i in range(total_cuda_num):
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # use GPU i, not a hard-coded 0
        memInfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
        utilizationInfo = pynvml.nvmlDeviceGetUtilizationRates(handle)
        free_mb = memInfo.free / 1024 ** 2  # nvml reports bytes; max_memory above is in MB
        if max_memory < free_mb:
            max_memory = free_mb
            max_id = i

    if max_id == -1:
        logger.info('no appropriate gpu, wait!')
        exit(2)

    return 'cuda:' + str(max_id)
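A usage sketch (requires the pynvml package; the model object here is assumed):

device = get_appropriate_cuda(task_scale='m')
model.to(device)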
Example #14
def num_parameters(model):
    sum_params = 0
    for name, param in model.named_parameters():
        logger.info("{}: {}".format(name, param.shape))
        sum_params += param.numel()
    return sum_params
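A typical call site, logging each parameter's name and shape as a side effect:

logger.info('total parameters: {}'.format(num_parameters(model)))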
Example #15
    embed_dropout = 0.3
    cls_dropout = 0.1
    weight_decay = 1e-5

    def __init__(self):
        self.datadir = os.path.join(os.environ['HOME'], self.datadir)
        self.datapath = {
            k: os.path.join(self.datadir, v)
            for k, v in self.datafile.items()
        }


ops = Config()

set_rng_seeds(ops.seed)
logger.info('RNG SEED %d' % ops.seed)

# 1. Task-related info: load the dataInfo via the dataloader


@cache_results(ops.model_dir_or_name + '-data-cache')
def load_data():
    datainfo = YelpFullPipe(lower=True,
                            tokenizer='raw').process_from_file(ops.datapath)
    for ds in datainfo.datasets.values():
        ds.apply_field(len, C.INPUT, C.INPUT_LEN)
        ds.set_input(C.INPUT, C.INPUT_LEN)
        ds.set_target(C.TARGET)

    return datainfo
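The first call processes the raw files and writes the cache; fastNLP's cache_results also accepts a _refresh keyword on the wrapped function to force re-processing:

datainfo = load_data()               # processed once, then cached on disk
datainfo = load_data(_refresh=True)  # rebuild the cache from scratch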
Example #16
    parser.add_argument("--tasks",
                        type=str,
                        default=None,
                        help='the task ids for MTL, default using all tasks')
    parser.add_argument("--trainer",
                        type=str,
                        choices=['re-seq-label', 'seq-label'],
                        default='seq-label',
                        help='the trainer type')
    # fmt: on

    args = parser.parse_args()

    utils.init_prog(args)

    logger.info(args)
    torch.save(args, os.path.join(args.save_path, "args.th"))

    n_gpu = torch.cuda.device_count()
    print("# of gpu: {}".format(n_gpu))

    logger.info("========== Loading Datasets ==========")
    task_lst, vocabs = utils.get_data(args.data_path)
    if args.tasks is not None:
        args.tasks = list(
            map(int, map(lambda s: s.strip(), args.tasks.split(","))))
        logger.info("activate tasks %s", args.tasks)
    logger.info("# of Tasks: {}.".format(len(task_lst)))
    for task in task_lst:
        logger.info("Task {}: {}".format(task.task_id, task.task_name))
    for task in task_lst:
Example #17
def main():

    parser = argparse.ArgumentParser()
    arg_options.add_path_options(parser)
    arg_options.add_para_options(parser)
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'

    n_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1

    print("num gpus: {}".format(n_gpus))
    is_distributed = n_gpus > 1
    if is_distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group('nccl', init_method='env://')
        args.world_size = dist.get_world_size()
        args.local_rank = int(args.local_rank)
        # synchronize()

    # Setup logging
    log_file_path = os.path.join(
        args.log_output_dir, 'log-{}.txt'.format(
            time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))))
    init_logger_dist()
    logger.add_file(log_file_path, level='INFO')
    # logging.basicConfig(
    #     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    #     datefmt="%m/%d/%Y %H:%M:%S",
    #     level=logging.INFO,
    # )
    # logging_fh = logging.FileHandler(log_file_path)
    # logging_fh.setLevel(logging.DEBUG)
    # logger.addHandler(logging_fh)

    args.test_logging_name = log_file_path.split('/')[-1].split(
        '.')[0].replace('log-', '')
    print(log_file_path.split('/')[-1].split('.')[0].replace('log-', ''))

    # cpt prep data load
    if args.local_rank == 0:
        print("Load prep data...\n")
    cpt2words, cpt2id, et2id, sememes, cpt_tree = load_annotated_concepts_new()
    verb_cpts = load_verb_concepts()
    CptEmb = CptEmbedding(sememes, cpt2words, cpt2id, et2id, cpt_tree,
                          args.cpt_max_num, args.random_cpt_num)

    word2cpt_ids = CptEmb.word2cpt_idx
    verb_cpt_ids = [cpt2id[cc] for cc in verb_cpts]
    sememe2id = CptEmb.sememe2id

    cpt_vec = torch.load(
        CptEmb.cpt_vec_in_bert_file
    )[:34442]  # [:10907]   # [:34442]    # 34443 * 768, padding index = 34442
    logger.info("cpt embedding file: {}".format(CptEmb.cpt_vec_in_bert_file))
    logger.info("cpt vec length: {}".format(len(cpt_vec)))

    et2cpts = CptEmb.et2cpts
    cpt2center_sem = CptEmb.cpt2center_sem
    cpt_id2center_sem_id = {
        cpt2id[cc]: sememe2id[sem]
        for cc, sem in cpt2center_sem.items()
    }

    id2cpt = {idx: cc for cc, idx in cpt2id.items()}
    id2et = {id: et for et, id in et2id.items()}
    anno_cpt2et = defaultdict(list)
    et_id2cpt_ids = defaultdict(list)
    for et, cpts in et2cpts.items():
        # print(self.et_id2cpt_ids)
        for cc in cpts:
            anno_cpt2et[cc].append(et)
            et_id2cpt_ids[et2id[et]].append(cpt2id[cc])
    cpt_id2et_id = {
        cpt2id[cc]: [et2id[et] for et in ets]
        for cc, ets in anno_cpt2et.items()
    }

    args.cpt_num = CptEmb.cpt_num
    logger.info("cpt nums: {}\n".format(args.cpt_num))
    logger.info("HowNet words cnt: {}".format(len(word2cpt_ids)))

    # pred DataSet

    train_samples = MyData(args, 'train', args.world_size, args.local_rank)
    dev_samples = MyData(args, 'dev', args.world_size, args.local_rank)
    dev_ace_samples = MyTestData(
        args,
        os.path.join(config.cached_data_dir, "cached_devACE_fixed_samples"),
        args.local_rank)
    logger.info("rank {} / {} load dataset with length: {}.".format(
        args.local_rank, args.world_size, len(train_samples)))
    test_ace_samples = None
    # ************** train data ************************
    train_sampler = DistributedSampler(train_samples,
                                       rank=args.local_rank,
                                       num_replicas=args.world_size)
    train_loader = DataLoader(train_samples,
                              batch_size=args.per_gpu_train_batch_size,
                              pin_memory=True,
                              sampler=train_sampler,
                              num_workers=args.num_workers,
                              collate_fn=train_samples.collate_fn)
    # ************** dev data ************************
    dev_loader = DataLoader(dev_samples,
                            batch_size=args.per_gpu_dev_batch_size,
                            collate_fn=dev_samples.collate_fn)
    dev_ace_loader = DataLoader(dev_ace_samples,
                                batch_size=args.per_gpu_eval_batch_size,
                                collate_fn=dev_ace_samples.collate_fn)
    # ************** test data ************************
    # self.test_loader = DataLoader(test_ace_samples, batch_size=args.per_gpu_eval_batch_size,
    #                               collate_fn=test_ace_samples.collate_fn)

    # ************** init model ***************************
    tokenizer = BertTokenizer.from_pretrained(args.bert_model_dir)
    bert_config = BertConfig.from_pretrained(args.bert_model_dir)
    bert_config.is_decoder = False
    cpt_model = commonCptODEE(args, bert_config, cpt_vec, len(cpt_vec[0]))

    # pred Trainer
    trainer = Trainer(args=args,
                      train_samples=train_loader,
                      dev_samples=dev_loader,
                      dev_ace_samples=dev_ace_loader,
                      test_ace_samples=None,
                      cpt_model=cpt_model,
                      id2cpt=id2cpt,
                      id2et=id2et,
                      cpt_id2et_id=cpt_id2et_id)
    trainer.train()
Example #18
def train():
    args = parse_args()
    if args.debug:
        fitlog.debug()
        args.save_model = False
    # ================= define =================
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    word_mask_index = tokenizer.mask_token_id
    word_vocab_size = len(tokenizer)

    if get_local_rank() == 0:
        fitlog.set_log_dir(args.log_dir)
        fitlog.commit(__file__, fit_msg=args.name)
        fitlog.add_hyper_in_file(__file__)
        fitlog.add_hyper(args)

    # ================= load data =================
    dist.init_process_group('nccl')
    init_logger_dist()

    n_proc = dist.get_world_size()
    bsz = args.batch_size // args.grad_accumulation // n_proc
    args.local_rank = get_local_rank()
    args.save_dir = os.path.join(args.save_dir,
                                 args.name) if args.save_model else None
    if args.save_dir is not None and os.path.exists(args.save_dir):
        raise RuntimeError('save_dir already exists.')
    logger.info('save directory: {}'.format(
        'None' if args.save_dir is None else args.save_dir))
    devices = list(range(torch.cuda.device_count()))
    NUM_WORKERS = 4

    ent_vocab, rel_vocab = load_ent_rel_vocabs()
    logger.info('# entities: {}'.format(len(ent_vocab)))
    logger.info('# relations: {}'.format(len(rel_vocab)))
    ent_freq = get_ent_freq()
    assert len(ent_vocab) == len(ent_freq), '{} {}'.format(
        len(ent_vocab), len(ent_freq))

    #####
    root = args.data_dir
    dirs = os.listdir(root)
    drop_files = []
    for dir in dirs:
        path = os.path.join(root, dir)
        max_idx = 0
        for file_name in os.listdir(path):
            if 'large' in file_name:
                continue
            max_idx = int(file_name) if int(file_name) > max_idx else max_idx
        drop_files.append(os.path.join(path, str(max_idx)))
    #####

    file_list = []
    for path, _, filenames in os.walk(args.data_dir):
        for filename in filenames:
            file = os.path.join(path, filename)
            if 'large' in file or file in drop_files:
                continue
            file_list.append(file)
    logger.info('used {} files in {}.'.format(len(file_list), args.data_dir))
    if args.data_prop > 1:
        used_files = file_list[:int(args.data_prop)]
    else:
        used_files = file_list[:round(args.data_prop * len(file_list))]

    data = GraphOTFDataSet(used_files, n_proc, args.local_rank,
                           word_mask_index, word_vocab_size, args.n_negs,
                           ent_vocab, rel_vocab, ent_freq)
    dev_data = GraphDataSet(used_files[0], word_mask_index, word_vocab_size,
                            args.n_negs, ent_vocab, rel_vocab, ent_freq)

    sampler = OTFDistributedSampler(used_files, n_proc, get_local_rank())
    train_data_iter = TorchLoaderIter(dataset=data,
                                      batch_size=bsz,
                                      sampler=sampler,
                                      num_workers=NUM_WORKERS,
                                      collate_fn=data.collate_fn)
    dev_data_iter = TorchLoaderIter(dataset=dev_data,
                                    batch_size=bsz,
                                    sampler=RandomSampler(),
                                    num_workers=NUM_WORKERS,
                                    collate_fn=dev_data.collate_fn)
    if args.test_data is not None:
        test_data = FewRelDevDataSet(path=args.test_data,
                                     label_vocab=rel_vocab,
                                     ent_vocab=ent_vocab)
        test_data_iter = TorchLoaderIter(dataset=test_data,
                                         batch_size=32,
                                         sampler=RandomSampler(),
                                         num_workers=NUM_WORKERS,
                                         collate_fn=test_data.collate_fn)

    if args.local_rank == 0:
        print('full wiki files: {}'.format(len(file_list)))
        print('used wiki files: {}'.format(len(used_files)))
        print('# of trained samples: {}'.format(len(data) * n_proc))
        print('# of trained entities: {}'.format(len(ent_vocab)))
        print('# of trained relations: {}'.format(len(rel_vocab)))

    # ================= prepare model =================
    logger.info('model init')
    if args.rel_emb is not None:  # load pretrained relation embeddings
        rel_emb = np.load(args.rel_emb)
        # add_embs = np.random.randn(3, rel_emb.shape[1])  # add <pad>, <mask>, <unk>
        # rel_emb = np.r_[add_embs, rel_emb]
        rel_emb = torch.from_numpy(rel_emb).float()
        assert rel_emb.shape[0] == len(rel_vocab), '{} {}'.format(
            rel_emb.shape[0], len(rel_vocab))
        # assert rel_emb.shape[1] == args.rel_dim
        logger.info('loaded pretrained relation embeddings. dim: {}'.format(
            rel_emb.shape[1]))
    else:
        rel_emb = None
    if args.model_name is not None:
        logger.info('further pre-train.')
        config = RobertaConfig.from_pretrained('roberta-base',
                                               type_vocab_size=3)
        model = CoLAKE(config=config,
                       num_ent=len(ent_vocab),
                       num_rel=len(rel_vocab),
                       ent_dim=args.ent_dim,
                       rel_dim=args.rel_dim,
                       ent_lr=args.ent_lr,
                       ip_config=args.ip_config,
                       rel_emb=None,
                       emb_name=args.emb_name)
        states_dict = torch.load(args.model_name)
        model.load_state_dict(states_dict, strict=True)
    else:
        model = CoLAKE.from_pretrained(
            'roberta-base',
            num_ent=len(ent_vocab),
            num_rel=len(rel_vocab),
            ent_lr=args.ent_lr,
            ip_config=args.ip_config,
            rel_emb=rel_emb,
            emb_name=args.emb_name,
            cache_dir=PYTORCH_PRETRAINED_BERT_CACHE /
            'dist_{}'.format(args.local_rank))
        model.extend_type_embedding(token_type=3)
    # if args.local_rank == 0:
    #     for name, param in model.named_parameters():
    #         if param.requires_grad is True:
    #             print('{}: {}'.format(name, param.shape))

    # ================= train model =================
    # lr=1e-4 for peak value, lr=5e-5 for initial value
    logger.info('trainer init')
    no_decay = [
        'bias', 'LayerNorm.bias', 'LayerNorm.weight', 'layer_norm.bias',
        'layer_norm.weight'
    ]
    param_optimizer = list(model.named_parameters())
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    word_acc = WordMLMAccuracy(pred='word_pred',
                               target='masked_lm_labels',
                               seq_len='word_seq_len')
    ent_acc = EntityMLMAccuracy(pred='entity_pred',
                                target='ent_masked_lm_labels',
                                seq_len='ent_seq_len')
    rel_acc = RelationMLMAccuracy(pred='relation_pred',
                                  target='rel_masked_lm_labels',
                                  seq_len='rel_seq_len')
    metrics = [word_acc, ent_acc, rel_acc]

    if args.test_data is not None:
        test_metric = [rel_acc]
        tester = Tester(data=test_data_iter,
                        model=model,
                        metrics=test_metric,
                        device=list(range(torch.cuda.device_count())))
        # tester.test()
    else:
        tester = None

    optimizer = optim.AdamW(optimizer_grouped_parameters,
                            lr=args.lr,
                            betas=(0.9, args.beta),
                            eps=1e-6)
    # warmup_callback = WarmupCallback(warmup=args.warm_up, schedule='linear')
    fitlog_callback = MyFitlogCallback(tester=tester,
                                       log_loss_every=100,
                                       verbose=1)
    gradient_clip_callback = GradientClipCallback(clip_value=1,
                                                  clip_type='norm')
    emb_callback = EmbUpdateCallback(model.ent_embeddings)
    all_callbacks = [gradient_clip_callback, emb_callback]
    if args.save_dir is None:
        master_callbacks = [fitlog_callback]
    else:
        save_callback = SaveModelCallback(args.save_dir,
                                          model.ent_embeddings,
                                          only_params=True)
        master_callbacks = [fitlog_callback, save_callback]

    if args.do_test:
        states_dict = torch.load(os.path.join(args.save_dir,
                                              args.model_name)).state_dict()
        model.load_state_dict(states_dict)
        data_iter = TorchLoaderIter(dataset=data,
                                    batch_size=args.batch_size,
                                    sampler=RandomSampler(),
                                    num_workers=NUM_WORKERS,
                                    collate_fn=data.collate_fn)
        tester = Tester(data=data_iter,
                        model=model,
                        metrics=metrics,
                        device=devices)
        tester.test()
    else:
        trainer = DistTrainer(train_data=train_data_iter,
                              dev_data=dev_data_iter,
                              model=model,
                              optimizer=optimizer,
                              loss=LossInForward(),
                              batch_size_per_gpu=bsz,
                              update_every=args.grad_accumulation,
                              n_epochs=args.epoch,
                              metrics=metrics,
                              callbacks_master=master_callbacks,
                              callbacks_all=all_callbacks,
                              validate_every=5000,
                              use_tqdm=True,
                              fp16='O1' if args.fp16 else '')
        trainer.train(load_best_model=False)
Example #19
"""测试log模块"""
from fastNLP import logger

if __name__ == "__main__":
    logger.info('testing info')
    logger.debug('testing debug')
    logger.warning('testing warning')
    logger.error('testing error')
Example #20
 def on_epoch_end(self):
     for ith, param in enumerate(self._optimizer.param_groups):
         logger.info(f"{ith} lr: {param['lr']}")
Example #21
    parser.add_argument('--type',
                        choices=['conll03', 'ontonotes', 'ptb'],
                        help='multi task data type')
    parser.add_argument('--out',
                        type=str,
                        default='data',
                        help='processed data output dir')
    # fmt: on
    args = parser.parse_args()
    assert args.pos is not None
    return args


if __name__ == "__main__":
    args = get_args()
    set_seed(1)
    parse_table = {
        "conll03": prepare_conll03,
        "ontonotes": prepare_ontonotes,
        "ptb": prepare_ptb,
    }
    logger.info(args)
    assert args.type in parse_table
    task_lst, vocabs = parse_table[args.type](args)
    os.makedirs(args.out, exist_ok=True)
    data_summary(task_lst, vocabs)
    path = os.path.join(args.out, args.type + ".pkl")
    logger.info("saving data to " + path)
    pdump({"task_lst": task_lst, "vocabs": vocabs}, path)
Example #22
    def __init__(self,
                 vocab: Vocabulary,
                 embed_size: int = 30,
                 char_emb_size: int = 30,
                 word_dropout: float = 0,
                 dropout: float = 0,
                 pool_method: str = 'max',
                 activation='relu',
                 min_char_freq: int = 2,
                 requires_grad=True,
                 include_word_start_end=True,
                 char_attn_type='adatrans',
                 char_n_head=3,
                 char_dim_ffn=60,
                 char_scale=False,
                 char_pos_embed=None,
                 char_dropout=0.15,
                 char_after_norm=False):
        """
        :param vocab: 词表
        :param embed_size: TransformerCharEmbed的输出维度。默认值为50.
        :param char_emb_size: character的embedding的维度。默认值为50. 同时也是Transformer的d_model大小
        :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
        :param dropout: 以多大概率drop character embedding的输出以及最终的word的输出。
        :param pool_method: 支持'max', 'avg'。
        :param activation: 激活函数,支持'relu', 'sigmoid', 'tanh', 或者自定义函数.
        :param min_char_freq: character的最小出现次数。默认值为2.
        :param requires_grad:
        :param include_word_start_end: 是否使用特殊的tag标记word的开始与结束
        :param char_attn_type: adatrans or naive.
        :param char_n_head: 多少个head
        :param char_dim_ffn: transformer中ffn中间层的大小
        :param char_scale: 是否使用scale
        :param char_pos_embed: None, 'fix', 'sin'. What kind of position embedding. When char_attn_type=relative, None is
            ok
        :param char_dropout: Dropout in Transformer encoder
        :param char_after_norm: the normalization place.
        """
        super(TransformerCharEmbed, self).__init__(vocab,
                                                   word_dropout=word_dropout,
                                                   dropout=dropout)

        assert char_emb_size % char_n_head == 0, "d_model should divide n_head."

        assert pool_method in ('max', 'avg')
        self.pool_method = pool_method
        # activation function
        if isinstance(activation, str):
            if activation.lower() == 'relu':
                self.activation = F.relu
            elif activation.lower() == 'sigmoid':
                self.activation = F.sigmoid
            elif activation.lower() == 'tanh':
                self.activation = F.tanh
        elif activation is None:
            self.activation = lambda x: x
        elif callable(activation):
            self.activation = activation
        else:
            raise Exception(
                "Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]"
            )

        logger.info("Start constructing character vocabulary.")
        # build the character vocabulary
        self.char_vocab = _construct_char_vocab_from_vocab(
            vocab,
            min_freq=min_char_freq,
            include_word_start_end=include_word_start_end)
        self.char_pad_index = self.char_vocab.padding_idx
        logger.info(
            f"In total, there are {len(self.char_vocab)} distinct characters.")
        # index every word in the vocab
        max_word_len = max(map(lambda x: len(x[0]), vocab))
        if include_word_start_end:
            max_word_len += 2
        self.register_buffer(
            'words_to_chars_embedding',
            torch.full((len(vocab), max_word_len),
                       fill_value=self.char_pad_index,
                       dtype=torch.long))
        self.register_buffer('word_lengths', torch.zeros(len(vocab)).long())
        for word, index in vocab:
            # if index != vocab.padding_idx:  # the pad entry used to stay as pad_value; changed to treat pad like any other word
            if include_word_start_end:
                word = ['<bow>'] + list(word) + ['<eow>']
            self.words_to_chars_embedding[index, :len(word)] = \
                torch.LongTensor([self.char_vocab.to_index(c) for c in word])
            self.word_lengths[index] = len(word)

        self.char_embedding = get_embeddings(
            (len(self.char_vocab), char_emb_size))
        self.transformer = TransformerEncoder(1,
                                              char_emb_size,
                                              char_n_head,
                                              char_dim_ffn,
                                              dropout=char_dropout,
                                              after_norm=char_after_norm,
                                              attn_type=char_attn_type,
                                              pos_embed=char_pos_embed,
                                              scale=char_scale)
        self.fc = nn.Linear(char_emb_size, embed_size)

        self._embed_size = embed_size

        self.requires_grad = requires_grad
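A usage sketch; the toy vocabulary here is an assumption:

from fastNLP import Vocabulary

vocab = Vocabulary()
vocab.add_word_lst(['hello', 'world'])
char_embed = TransformerCharEmbed(vocab, embed_size=50, char_emb_size=30)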
Example #23
def tester(model, test_batch, write_out=False):
    res = []
    prf = utils.CWSEvaluator(i2t)
    prf_dataset = {}
    oov_dataset = {}

    logger.info("start evaluation")
    # import ipdb; ipdb.set_trace()
    with torch.no_grad():
        for batch_x, batch_y in test_batch:
            batch_to_device(batch_x, device)
            # batch_to_device(batch_y, device)
            if bigram_embedding is not None:
                out = model(
                    batch_x["task"],
                    batch_x["uni"],
                    batch_x["seq_len"],
                    batch_x["bi1"],
                    batch_x["bi2"],
                )
            else:
                out = model(batch_x["task"], batch_x["uni"], batch_x["seq_len"])
            out = out["pred"]
            # print(out)
            num = out.size(0)
            out = out.detach().cpu().numpy()
            for i in range(num):
                length = int(batch_x["seq_len"][i])

                out_tags = out[i, 1:length].tolist()
                sentence = batch_x["words"][i]
                gold_tags = batch_y["tags"][i][1:length].numpy().tolist()
                dataset_name = sentence[0]
                sentence = sentence[1:]
                # print(out_tags,gold_tags)
                assert utils.is_dataset_tag(dataset_name), dataset_name
                assert len(gold_tags) == len(out_tags) and len(gold_tags) == len(
                    sentence
                )

                if dataset_name not in prf_dataset:
                    prf_dataset[dataset_name] = utils.CWSEvaluator(i2t)
                    oov_dataset[dataset_name] = utils.CWS_OOV(
                        word_dic[dataset_name[1:-1]]
                    )

                prf_dataset[dataset_name].add_instance(gold_tags, out_tags)
                prf.add_instance(gold_tags, out_tags)

                if write_out:
                    gold_strings = utils.to_tag_strings(i2t, gold_tags)
                    obs_strings = utils.to_tag_strings(i2t, out_tags)

                    word_list = utils.bmes_to_words(sentence, obs_strings)
                    oov_dataset[dataset_name].update(
                        utils.bmes_to_words(sentence, gold_strings), word_list
                    )

                    raw_string = " ".join(word_list)
                    res.append(dataset_name + " " + raw_string + " " + dataset_name)

        Ap = 0.0
        Ar = 0.0
        Af = 0.0
        Aoov = 0.0
        tot = 0
        nw = 0.0
        for dataset_name, performance in sorted(prf_dataset.items()):
            p = performance.result()
            if write_out:
                nw = oov_dataset[dataset_name].oov()
                # nw = 0
                logger.info(
                    "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format(
                        dataset_name, p[0], p[1], p[2], nw
                    )
                )
            else:
                logger.info(
                    "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format(
                        dataset_name, p[0], p[1], p[2]
                    )
                )
            Ap += p[0]
            Ar += p[1]
            Af += p[2]
            Aoov += nw
            tot += 1

        prf = prf.result()
        logger.info(
            "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format("TOT", prf[0], prf[1], prf[2])
        )
        if not write_out:
            logger.info(
                "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format(
                    "AVG", Ap / tot, Ar / tot, Af / tot
                )
            )
        else:
            logger.info(
                "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format(
                    "AVG", Ap / tot, Ar / tot, Af / tot, Aoov / tot
                )
            )
    return prf[-1], res
Example #24
    else:
        logger.setLevel(logging.WARNING)
    return logger


# ===-----------------------------------------------------------------------===
# Set up logging
# ===-----------------------------------------------------------------------===
# logger = init_logger()
logger.add_file("{}/info.log".format(root_dir), "INFO")
logger.setLevel(logging.INFO if dist.get_rank() == 0 else logging.WARNING)

# ===-----------------------------------------------------------------------===
# Log some stuff about this run
# ===-----------------------------------------------------------------------===
logger.info(" ".join(sys.argv))
logger.info("")
logger.info(options)

if options.debug:
    logger.info("DEBUG MODE")
    options.num_epochs = 2
    options.batch_size = 20

random.seed(options.python_seed)
np.random.seed(options.python_seed % (2 ** 32 - 1))
torch.cuda.manual_seed_all(options.python_seed)
logger.info("Python random seed: {}".format(options.python_seed))

# ===-----------------------------------------------------------------------===
# Read in dataset
Example #25
    def __init__(self,
                 vocab: Vocabulary,
                 embed_size: int = 30,
                 char_emb_size: int = 30,
                 word_dropout: float = 0,
                 dropout: float = 0,
                 pool_method: str = 'max',
                 activation='relu',
                 min_char_freq: int = 2,
                 requires_grad=True,
                 include_word_start_end=True,
                 char_attn_type='adatrans',
                 char_n_head=3,
                 char_dim_ffn=60,
                 char_scale=False,
                 char_pos_embed=None,
                 char_dropout=0.15,
                 char_after_norm=False):
        super(TransformerCharEmbed, self).__init__(vocab,
                                                   word_dropout=word_dropout,
                                                   dropout=dropout)

        assert char_emb_size % char_n_head == 0, "d_model should divide n_head."

        assert pool_method in ('max', 'avg')
        self.pool_method = pool_method
        # activation function
        if isinstance(activation, str):
            if activation.lower() == 'relu':
                self.activation = F.relu
            elif activation.lower() == 'sigmoid':
                self.activation = F.sigmoid
            elif activation.lower() == 'tanh':
                self.activation = F.tanh
        elif activation is None:
            self.activation = lambda x: x
        elif callable(activation):
            self.activation = activation
        else:
            raise Exception(
                "Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]"
            )

        logger.info("Start constructing character vocabulary.")

        self.char_vocab = _construct_char_vocab_from_vocab(
            vocab,
            min_freq=min_char_freq,
            include_word_start_end=include_word_start_end)
        self.char_pad_index = self.char_vocab.padding_idx
        logger.info(
            f"In total, there are {len(self.char_vocab)} distinct characters.")

        max_word_len = max(map(lambda x: len(x[0]), vocab))
        if include_word_start_end:
            max_word_len += 2
        self.register_buffer(
            'words_to_chars_embedding',
            torch.full((len(vocab), max_word_len),
                       fill_value=self.char_pad_index,
                       dtype=torch.long))
        self.register_buffer('word_lengths', torch.zeros(len(vocab)).long())
        for word, index in vocab:

            if include_word_start_end:
                word = ['<bow>'] + list(word) + ['<eow>']
            self.words_to_chars_embedding[index, :len(word)] = \
                torch.LongTensor([self.char_vocab.to_index(c) for c in word])
            self.word_lengths[index] = len(word)

        self.char_embedding = get_embeddings(
            (len(self.char_vocab), char_emb_size))
        self.transformer = TransformerEncoder(1,
                                              char_emb_size,
                                              char_n_head,
                                              char_dim_ffn,
                                              dropout=char_dropout,
                                              after_norm=char_after_norm,
                                              attn_type=char_attn_type,
                                              pos_embed=char_pos_embed,
                                              scale=char_scale)
        self.fc = nn.Linear(char_emb_size, embed_size)

        self._embed_size = embed_size

        self.requires_grad = requires_grad
Example #26
def eval_mtl_single(args):
    global logger
    # import ipdb; ipdb.set_trace()
    args = torch.load(os.path.join(args.save_path, "args"))
    print(args)
    logger.info(args)
    task_lst, vocabs = utils.get_data(args.data_path)
    task_db = task_lst[args.task_id]
    train_data = task_db.train_set
    dev_data = task_db.dev_set
    test_data = task_db.test_set
    task_name = task_db.task_name

    # text classification
    for ds in [train_data, dev_data, test_data]:
        ds.rename_field("words_idx", "x")
        ds.rename_field("label", "y")
        ds.set_input("x", "y", "task_id")
        ds.set_target("y")
    # seq label
    if task_name in SEQ_LABEL_TASK:
        for ds in [train_data, dev_data, test_data]:
            ds.set_input("seq_len")
            ds.set_target("seq_len")

    logger = utils.get_logger(__name__)
    logger.info("task name: {}, task id: {}".format(task_db.task_name, task_db.task_id))
    logger.info(
        "train len {}, dev len {}, test len {}".format(
            len(train_data), len(dev_data), len(test_data)
        )
    )

    # init model
    model = get_model(args, task_lst, vocabs)
    # logger.info('model: \n{}'.format(model))

    if task_name not in SEQ_LABEL_TASK or task_name == "pos":
        metrics = [
            AccuracyMetric(target="y"),
            # MetricInForward(val_name='loss')
        ]
    else:
        metrics = [
            SpanFPreRecMetric(
                tag_vocab=vocabs[task_name],
                pred="pred",
                target="y",
                seq_len="seq_len",
                encoding_type="bioes" if task_name == "ner" else "chunk",
            ),
            AccuracyMetric(target="y")
            # MetricInForward(val_name='loss')
        ]

    cur_best = 0.0
    init_best = None
    eval_time = 0
    paths = [path for path in os.listdir(args.save_path) if path.startswith("best")]
    paths = sorted(paths, key=lambda x: int(x.split("_")[1]))
    for path in paths:
        path = os.path.join(args.save_path, path)
        state = torch.load(path, map_location="cpu")
        model.load_state_dict(state)
        tester = Tester(
            test_data,
            model,
            metrics=metrics,
            batch_size=args.batch_size,
            num_workers=4,
            device="cuda",
            use_tqdm=False,
        )
        res = tester.test()
        val = 0.0
        for metric_name, metric_dict in res.items():
            if task_name == "pos" and "acc" in metric_dict:
                val = metric_dict["acc"]
                break
            elif "f" in metric_dict:
                val = metric_dict["f"]
                break

        if init_best is None:
            init_best = val
        logger.info(
            "No #%d: best %f, %s, path: %s, is better: %s",
            eval_time,
            val,
            tester._format_eval_results(res),
            path,
            val > init_best,
        )

        eval_time += 1
Example #27
def train_mlt_single(args):
    global logger
    logger.info(args)
    task_lst, vocabs = utils.get_data(args.data_path)
    task_db = task_lst[args.task_id]
    train_data = task_db.train_set
    dev_data = task_db.dev_set
    test_data = task_db.test_set
    task_name = task_db.task_name

    if args.debug:
        train_data = train_data[:200]
        dev_data = dev_data[:200]
        test_data = test_data[:200]
        args.epochs = 3
        args.pruning_iter = 3

    summary_writer = SummaryWriter(
        log_dir=os.path.join(args.tb_path, "global/%s" % task_name)
    )

    logger.info("task name: {}, task id: {}".format(task_db.task_name, task_db.task_id))
    logger.info(
        "train len {}, dev len {}, test len {}".format(
            len(train_data), len(dev_data), len(test_data)
        )
    )

    # init model
    model = get_model(args, task_lst, vocabs)

    logger.info("model: \n{}".format(model))
    if args.init_weights is not None:
        utils.load_model(model, args.init_weights)

    if utils.need_acc(task_name):
        metrics = [AccuracyMetric(target="y"), MetricInForward(val_name="loss")]
        metric_key = "acc"

    else:
        metrics = [
            YangJieSpanMetric(
                tag_vocab=vocabs[task_name],
                pred="pred",
                target="y",
                seq_len="seq_len",
                encoding_type="bioes" if task_name == "ner" else "bio",
            ),
            MetricInForward(val_name="loss"),
        ]
        metric_key = "f"
    logger.info(metrics)

    need_cut_names = list(set([s.strip() for s in args.need_cut.split(",")]))
    prune_names = []
    for name, p in model.named_parameters():
        if not p.requires_grad or "bias" in name:
            continue
        for n in need_cut_names:
            if n in name:
                prune_names.append(name)
                break

    # get Pruning class
    pruner = Pruning(
        model, prune_names, final_rate=args.final_rate, pruning_iter=args.pruning_iter
    )
    if args.init_masks is not None:
        pruner.load(args.init_masks)
        pruner.apply_mask(pruner.remain_mask, pruner._model)
    # save checkpoint
    os.makedirs(args.save_path, exist_ok=True)

    logger.info('Saving init-weights to {}'.format(args.save_path))
    torch.save(
        model.cpu().state_dict(), os.path.join(args.save_path, "init_weights.th")
    )
    torch.save(args, os.path.join(args.save_path, "args.th"))
    # start training and pruning
    summary_writer.add_scalar("remain_rate", 100.0, 0)
    summary_writer.add_scalar("cutoff", 0.0, 0)

    if args.init_weights is not None:
        init_tester = Tester(
            test_data,
            model,
            metrics=metrics,
            batch_size=args.batch_size,
            num_workers=4,
            device="cuda",
            use_tqdm=False,
        )
        res = init_tester.test()
        logger.info("No init testing, Result: {}".format(res))
        del res, init_tester

    for prune_step in range(pruner.pruning_iter + 1):
        # reset optimizer every time
        optim_params = [p for p in model.parameters() if p.requires_grad]
        # utils.get_logger(__name__).debug(optim_params)
        utils.get_logger(__name__).debug(len(optim_params))
        optimizer = get_optim(args.optim, optim_params)
        # optimizer = TriOptim(optimizer, args.n_filters, args.warmup, args.decay)
        factor = pruner.cur_rate / 100.0
        factor = 1.0
        # print(factor, pruner.cur_rate)
        for pg in optimizer.param_groups:
            pg["lr"] = factor * pg["lr"]
        utils.get_logger(__name__).info(optimizer)

        trainer = Trainer(
            train_data,
            model,
            loss=LossInForward(),
            optimizer=optimizer,
            metric_key=metric_key,
            metrics=metrics,
            print_every=200,
            batch_size=args.batch_size,
            num_workers=4,
            n_epochs=args.epochs,
            dev_data=dev_data,
            save_path=None,
            sampler=fastNLP.BucketSampler(batch_size=args.batch_size),
            callbacks=[
                pruner,
                # LRStep(lstm.WarmupLinearSchedule(optimizer, args.warmup, int(len(train_data)/args.batch_size*args.epochs)))
                GradientClipCallback(clip_type="norm", clip_value=5),
                LRScheduler(
                    lr_scheduler=LambdaLR(optimizer, lambda ep: 1 / (1 + 0.05 * ep))
                ),
                LogCallback(path=os.path.join(args.tb_path, "No", str(prune_step))),
            ],
            use_tqdm=False,
            device="cuda",
            check_code_level=-1,
        )
        res = trainer.train()
        logger.info("No #{} training, Result: {}".format(pruner.prune_times, res))
        name, val = get_metric(res)
        summary_writer.add_scalar("prunning_dev_acc", val, prune_step)
        tester = Tester(
            test_data,
            model,
            metrics=metrics,
            batch_size=args.batch_size,
            num_workers=4,
            device="cuda",
            use_tqdm=False,
        )
        res = tester.test()
        logger.info("No #{} testing, Result: {}".format(pruner.prune_times, res))
        name, val = get_metric(res)
        summary_writer.add_scalar("pruning_test_acc", val, prune_step)

        # prune and save
        torch.save(
            model.state_dict(),
            os.path.join(
                args.save_path,
                "best_{}_{}.th".format(pruner.prune_times, pruner.cur_rate),
            ),
        )
        pruner.pruning_model()
        summary_writer.add_scalar("remain_rate", pruner.cur_rate, prune_step + 1)
        summary_writer.add_scalar("cutoff", pruner.last_cutoff, prune_step + 1)

        pruner.save(
            os.path.join(
                args.save_path, "{}_{}.th".format(pruner.prune_times, pruner.cur_rate)
            )
        )
Example #28
    args.embed_dropout = over_all_dropout
    args.output_dropout = over_all_dropout
    args.pre_dropout = over_all_dropout
    args.post_dropout = over_all_dropout
    args.ff_dropout = over_all_dropout
    args.attn_dropout = over_all_dropout

if args.lattice and args.use_rel_pos:
    args.train_clip = True

# fitlog.commit(__file__, fit_msg='switched to the new absolute position encoding')
# fitlog.set_log_dir('../output/logs')
now_time = get_peking_time()
logger.add_file(f'../output/logs/{args.dataset}_{args.status}/bert{args.use_bert}_scheme{args.new_tag_scheme}'
                f'_ple{args.ple_channel_num}_plstm{int(args.use_ple_lstm)}_trainrate{args.train_dataset_rate}/{now_time}.log', level='info')
logger.info('Arguments')
for arg in vars(args):
    logger.info("{}: {}".format(arg, getattr(args, arg)))

# fitlog.add_hyper(now_time, 'time')
if args.debug:
    # args.dataset = 'toy'
    pass

if args.device != 'cpu':
    assert args.device.isdigit()
    device = torch.device('cuda:{}'.format(args.device))
else:
    device = torch.device('cpu')

Example #29
def print_info(*inp, islog=True, sep=' '):
    if islog:
        print(*inp, sep=sep)
    else:
        inp = sep.join(map(str, inp))
        logger.info(inp)