Code example #1
import configparser
import os
import random
import uuid

import numpy as np
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer  # assumed to be Hugging Face transformers

# Project-local helpers assumed importable from the surrounding package:
# create_logger, load_task, get_data, setup_customized_tokenizer,
# MultiTaskDataset, MultiTaskBatchSampler, MTLModel, get_optimizer,
# get_scheduler


def main(args):
    config = configparser.ConfigParser(
        interpolation=configparser.ExtendedInterpolation())
    config.read(args.config)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # if the output dir already exists, write into a fresh UUID-named subdirectory
    if os.path.isdir(args.outdir):
        outdir = os.path.join(args.outdir, str(uuid.uuid4()))
    else:
        outdir = args.outdir

    # make the output dir
    os.makedirs(outdir)
    if args.save_best:
        os.makedirs(os.path.join(outdir, 'best_model'))

    # create a logger
    logger = create_logger(__name__,
                           to_disk=True,
                           log_file='{}/{}'.format(outdir, args.logfile))
    tasks = []
    for task_name in args.tasks.split(','):
        task = load_task(
            os.path.join(args.task_spec, '{}.yml'.format(task_name)))
        tasks.append(task)
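    # NOTE: the task spec schema is not shown in this excerpt; judging by the
    # attributes accessed below (task_type, dataset, splits), a hypothetical
    # spec such as task_spec/iula.yml might look like:
    #   task_type: sequence_labeling   # assumed field and value
    #   dataset: iula
    #   splits: [train, dev, test]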
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    if not args.load_checkpoint:
        tokenizer = setup_customized_tokenizer(model=args.tokenizer,
                                               do_lower_case=False,
                                               config=config,
                                               tokenizer_class=BertTokenizer)
    else:
        tokenizer = BertTokenizer.from_pretrained(args.checkpoint)

    train_datasets = {}
    dev_dataloaders = {}
    test_dataloaders = {}

    for task_id, task in enumerate(tasks):
        task.set_task_id(task_id)
        logger.info('Task {}: {} on {}'.format(task_id, task.task_type,
                                               task.dataset))
        if 'train' in task.splits:
            train_datasets[task_id] = get_data(task=task,
                                               split='train',
                                               config=config,
                                               tokenizer=tokenizer)
            train_datasets[task_id].set_task_id(task_id)
            task.set_label_map(train_datasets[task_id].label_map)

        if 'dev' in task.splits:
            dev_data = get_data(task=task,
                                split='dev',
                                config=config,
                                tokenizer=tokenizer)
            dev_data.set_task_id(task_id)
            dev_dataloader = DataLoader(dev_data,
                                        shuffle=False,
                                        batch_size=8,
                                        collate_fn=dev_data.collate_fn)
            dev_dataloaders[task_id] = dev_dataloader

        if 'test' in task.splits:
            test_data = get_data(task=task,
                                 split='test',
                                 config=config,
                                 tokenizer=tokenizer)
            test_data.set_task_id(task_id)

            test_dataloader = DataLoader(test_data,
                                         shuffle=False,
                                         batch_size=8,
                                         collate_fn=test_data.collate_fn)
            test_dataloaders[task_id] = test_dataloader

    # padding label index, assumed shared across all tasks
    padding_label = train_datasets[0].padding_label

    sorted_train_datasets = [ds for _, ds in sorted(train_datasets.items())]

    mtl_dataset = MultiTaskDataset(sorted_train_datasets)
    multi_task_batch_sampler = MultiTaskBatchSampler(
        sorted_train_datasets,
        batch_size=args.bs,
        mix_opt=args.mix_opt,
        extra_task_ratio=args.extra_task_ratio,
        annealed_sampling=args.annealed_sampling,
        max_epochs=args.epochs)
    mtl_train_dataloader = DataLoader(mtl_dataset,
                                      batch_sampler=multi_task_batch_sampler,
                                      collate_fn=mtl_dataset.collate_fn,
                                      pin_memory=False)

    model = MTLModel(bert_encoder=args.bert_model,
                     device=device,
                     tasks=tasks,
                     padding_label_idx=padding_label,
                     load_checkpoint=args.load_checkpoint,
                     checkpoint=(os.path.join(args.checkpoint, 'model.pt')
                                 if args.load_checkpoint else None),
                     tokenizer=tokenizer)

    # get optimizer
    # TODO: in case of loading from checkpoint, initialize optimizer using saved optimizer state dict
    optimizer = get_optimizer(optimizer_name='adamw',
                              model=model,
                              lr=args.lr,
                              eps=args.eps,
                              decay=args.decay)

    # get lr schedule
    total_steps = (len(mtl_dataset) //
                   args.grad_accumulation_steps) * args.epochs
    warmup_steps = int(args.warmup_frac * total_steps)
    logger.info(
        'bs_per_device={}, gradient_accumulation_steps={} --> effective bs={}'
        .format(args.bs, args.grad_accumulation_steps,
                args.bs * args.grad_accumulation_steps))
    logger.info('Total steps: {}'.format(total_steps))
    logger.info('Scheduler: {} with {} warmup steps'.format(
        'warmuplinear', warmup_steps))

    scheduler = get_scheduler(optimizer,
                              scheduler='warmuplinear',
                              warmup_steps=warmup_steps,
                              t_total=total_steps)

    model.fit(tasks,
              optimizer,
              scheduler,
              gradient_accumulation_steps=args.grad_accumulation_steps,
              train_dataloader=mtl_train_dataloader,
              dev_dataloaders=dev_dataloaders,
              test_dataloaders=test_dataloaders,
              epochs=args.epochs,
              evaluation_step=args.evaluation_steps,
              save_best=args.save_best,
              outdir=outdir,
              predict=args.predict)
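
Example #1 drives everything from an `args` namespace. Below is a minimal sketch of the command-line parser it appears to expect; every flag name is inferred from the `args.*` accesses above, and the types and defaults are illustrative assumptions, not taken from the source.

import argparse

def build_arg_parser():
    # Hypothetical parser reconstructed from the args.* accesses in main().
    # Flag names inferred; defaults are illustrative only.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', help='path to the .cfg file read by configparser')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--outdir', default='out')
    parser.add_argument('--logfile', default='train.log')
    parser.add_argument('--tasks', help='comma-separated task names, e.g. "iula,bioconv"')
    parser.add_argument('--task_spec', help='directory holding <task>.yml spec files')
    parser.add_argument('--tokenizer', default='bert-base-cased')
    parser.add_argument('--bert_model', default='bert-base-cased')
    parser.add_argument('--load_checkpoint', action='store_true')
    parser.add_argument('--checkpoint', default=None, help='checkpoint directory')
    parser.add_argument('--bs', type=int, default=8, help='batch size per device')
    parser.add_argument('--mix_opt', type=int, default=0)
    parser.add_argument('--extra_task_ratio', type=float, default=0.0)
    parser.add_argument('--annealed_sampling', type=float, default=0.0)
    parser.add_argument('--epochs', type=int, default=3)
    parser.add_argument('--grad_accumulation_steps', type=int, default=1)
    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument('--eps', type=float, default=1e-6)
    parser.add_argument('--decay', type=float, default=0.0)
    parser.add_argument('--warmup_frac', type=float, default=0.1)
    parser.add_argument('--evaluation_steps', type=int, default=100)
    parser.add_argument('--save_best', action='store_true')
    parser.add_argument('--predict', action='store_true')
    return parser

if __name__ == '__main__':
    main(build_arg_parser().parse_args())
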
Code example #2
    import configparser
    import os

    # fragment of a test/setup routine; assumes BertTokenizer, load_task,
    # setup_customized_tokenizer, MTLModel, and get_optimizer are in scope
    tasks = []

    task = load_task(os.path.join('../task_specs', '{}.yml'.format('iula')))
    task.num_labels = 5
    task.task_id = 0
    tasks.append(task)
    config = configparser.ConfigParser(
        interpolation=configparser.ExtendedInterpolation())
    config.read('../preprocessing/config.cfg')
    tokenizer = setup_customized_tokenizer(model='bert-base-cased',
                                           do_lower_case=False,
                                           config=config,
                                           tokenizer_class=BertTokenizer)
    model = MTLModel(bert_encoder='bert-base-cased',
                     device='cpu',
                     tasks=tasks,
                     padding_label_idx=-1,
                     load_checkpoint=False,
                     tokenizer=tokenizer)
    optimizer = get_optimizer(optimizer_name='adamw',
                              model=model,
                              lr=5e-5,
                              eps=1e-6,
                              decay=0)
    # ensure the checkpoint directory exists before saving
    os.makedirs('checkpoints/test', exist_ok=True)
    tokenizer.save_pretrained('checkpoints/test')
    model.save('checkpoints/test/model.pt', optimizer)
    # round-trip check: reload the model from the saved checkpoint
    # model = MTLModel(bert_encoder=None,
    #                  device='cpu',
    #                  tasks=tasks,
    #                  padding_label_idx=-1,
    #                  load_checkpoint=True,
    #                  checkpoint='checkpoints/test/model.pt')
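
To sanity-check the save, here is a minimal round-trip sketch for the tokenizer side, assuming the standard Hugging Face `from_pretrained` API (the model reload is outlined in the commented call above):

    # reload the tokenizer that was just saved and spot-check a round trip
    reloaded = BertTokenizer.from_pretrained('checkpoints/test')
    sample = 'This is a smoke test.'
    assert reloaded.tokenize(sample) == tokenizer.tokenize(sample)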