def main(args):
    """Entry point: build the models, optionally train, then evaluate on the test split.

    Reads the YAML config at ``args.config``, loads the optional teacher model and
    the student model, runs ``train`` unless ``args.test_only`` is set (reloading
    the student checkpoint afterwards), and finally evaluates the teacher (when one
    is configured and ``args.student_only`` is not set) followed by the student.

    Args:
        args: Parsed command-line namespace. The attributes read here are ``log``,
            ``world_size``, ``dist_url``, ``seed``, ``config``, ``device``,
            ``log_config``, ``test_only``, ``iou_types`` and ``student_only``.
    """
    log_path = args.log
    if is_main_process() and log_path is not None:
        setup_log_file(os.path.expanduser(log_path))

    distributed, device_ids = init_distributed_mode(args.world_size,
                                                    args.dist_url)
    logger.info(args)
    cudnn.benchmark = True
    set_seed(args.seed)
    config = yaml_util.load_yaml_file(os.path.expanduser(args.config))
    device = torch.device(args.device)
    dataset_dict = util.get_all_datasets(config['datasets'])

    models_config = config['models']
    # The teacher is optional; when it is absent only the student is used.
    teacher_model_config = models_config.get('teacher_model', None)
    teacher_model = None
    if teacher_model_config is not None:
        teacher_model = load_model(teacher_model_config, device)

    # Fall back to the plain 'model' entry when there is no 'student_model' entry.
    student_key = 'student_model' if 'student_model' in models_config else 'model'
    student_model_config = models_config[student_key]
    ckpt_file_path = student_model_config['ckpt']
    student_model = load_model(student_model_config, device)
    if args.log_config:
        logger.info(config)

    if not args.test_only:
        train(teacher_model, student_model, dataset_dict, ckpt_file_path,
              device, device_ids, distributed, config, args)
        # Reload the checkpoint at the student's 'ckpt' path into the unwrapped model.
        if module_util.check_if_wrapped(student_model):
            bare_student = student_model.module
        else:
            bare_student = student_model
        load_ckpt(student_model_config['ckpt'], model=bare_student, strict=True)

    test_config = config['test']
    loader_config = test_config['test_data_loader']
    test_dataset = dataset_dict[loader_config['dataset_id']]
    test_data_loader = util.build_data_loader(test_dataset, loader_config,
                                              distributed)
    log_freq = test_config.get('log_freq', 1000)
    iou_types = args.iou_types

    def _run_test(model, title):
        # Shared evaluation call; only the model and the report title differ.
        evaluate(model, test_data_loader, iou_types, device, device_ids,
                 distributed, log_freq=log_freq, title=title)

    if not args.student_only and teacher_model is not None:
        _run_test(teacher_model,
                  '[Teacher: {}]'.format(teacher_model_config['name']))
    _run_test(student_model,
              '[Student: {}]'.format(student_model_config['name']))
Exemple #2
0
def train_one_epoch(training_box, device, epoch, log_freq):
    """Iterate once over the training data loader, updating parameters per batch.

    For each batch the inputs and targets are moved to ``device``, the loss is
    obtained by calling ``training_box`` and passed to
    ``training_box.update_params``. The loss, the learning rate of the first
    optimizer parameter group and the throughput ('img/s') are recorded in a
    ``MetricLogger``.

    Args:
        training_box: Callable object exposing ``train_data_loader``,
            ``optimizer`` and ``update_params``.
        device: Device the batch and targets are moved to.
        epoch: Epoch index, used only in the log header.
        log_freq: Logging frequency forwarded to ``MetricLogger.log_every``.

    Raises:
        ValueError: On the main process, when the loss is NaN or infinite.
    """
    meter_log = MetricLogger(delimiter='  ')
    meter_log.add_meter('lr', SmoothedValue(window_size=1, fmt='{value}'))
    meter_log.add_meter('img/s', SmoothedValue(window_size=10, fmt='{value}'))
    epoch_header = 'Epoch: [{}]'.format(epoch)
    batch_iter = meter_log.log_every(training_box.train_data_loader, log_freq,
                                     epoch_header)
    for sample_batch, targets, supp_dict in batch_iter:
        tick = time.time()
        sample_batch = sample_batch.to(device)
        targets = targets.to(device)
        loss = training_box(sample_batch, targets, supp_dict)
        training_box.update_params(loss)

        num_samples = sample_batch.shape[0]
        loss_value = loss.item()
        current_lr = training_box.optimizer.param_groups[0]['lr']
        meter_log.update(loss=loss_value, lr=current_lr)
        meter_log.meters['img/s'].update(num_samples / (time.time() - tick))

        # The check runs after the update; only the main process raises.
        loss_not_finite = torch.isnan(loss) or torch.isinf(loss)
        if loss_not_finite and is_main_process():
            raise ValueError('The training loop was broken due to loss = {}'.format(loss))
def train(teacher_model, student_model, dataset_dict, ckpt_file_path, device,
          device_ids, distributed, config, args):
    """Train the student and checkpoint it whenever the validation mAP improves.

    A plain training box is used when ``teacher_model`` is None, otherwise a
    distillation box. If a checkpoint already exists at ``ckpt_file_path``, the
    optimizer and LR scheduler states are restored from it and its stored score
    becomes the initial best value.

    Args:
        teacher_model: Teacher model, or None to train without distillation.
        student_model: Model being trained (possibly wrapped; see ``check_if_wrapped``).
        dataset_dict: Datasets handed to the training/distillation box factory.
        ckpt_file_path: Path the best checkpoint is read from and written to.
        device: Device used for training and validation.
        device_ids: Device IDs forwarded to the box factory and ``evaluate``.
        distributed: Whether distributed training is active.
        config: Full configuration; ``config['train']`` drives the training box.
        args: Parsed arguments; reads ``world_size``, ``adjust_lr``,
            ``iou_types`` and ``start_epoch``.
    """
    logger.info('Start training')
    train_config = config['train']
    lr_factor = 1
    if distributed and args.adjust_lr:
        lr_factor = args.world_size

    if teacher_model is None:
        training_box = get_training_box(student_model, dataset_dict, train_config,
                                        device, device_ids, distributed, lr_factor)
    else:
        training_box = get_distillation_box(teacher_model, student_model, dataset_dict,
                                            train_config, device, device_ids,
                                            distributed, lr_factor)

    best_val_map = 0.0
    optimizer = training_box.optimizer
    lr_scheduler = training_box.lr_scheduler
    if file_util.check_if_exists(ckpt_file_path):
        restored = load_ckpt(ckpt_file_path,
                             optimizer=optimizer,
                             lr_scheduler=lr_scheduler)
        best_val_map, _, _ = restored

    log_freq = train_config['log_freq']
    iou_types = args.iou_types
    # The first configured IoU type decides which score is tracked; default is 'bbox'.
    val_iou_type = 'bbox'
    if isinstance(iou_types, (list, tuple)) and len(iou_types) > 0:
        val_iou_type = iou_types[0]

    if module_util.check_if_wrapped(student_model):
        student_model_without_ddp = student_model.module
    else:
        student_model_without_ddp = student_model

    start_time = time.time()
    for epoch in range(args.start_epoch, training_box.num_epochs):
        training_box.pre_process(epoch=epoch)
        train_one_epoch(training_box, device, epoch, log_freq)
        val_coco_evaluator = evaluate(student_model, training_box.val_data_loader,
                                      iou_types, device, device_ids, distributed,
                                      log_freq=log_freq, header='Validation:')
        # stats[0]: Average Precision (AP) @[ IoU=0.50:0.95 | area=all | maxDets=100 ]
        val_map = val_coco_evaluator.coco_eval[val_iou_type].stats[0]
        improved = val_map > best_val_map
        if improved and is_main_process():
            logger.info('Best mAP ({}): {:.4f} -> {:.4f}'.format(
                val_iou_type, best_val_map, val_map))
            logger.info('Updating ckpt at {}'.format(ckpt_file_path))
            best_val_map = val_map
            save_ckpt(student_model_without_ddp, optimizer, lr_scheduler,
                      best_val_map, config, args, ckpt_file_path)
        training_box.post_process()

    if distributed:
        dist.barrier()

    elapsed_seconds = int(time.time() - start_time)
    total_time_str = str(datetime.timedelta(seconds=elapsed_seconds))
    logger.info('Training time {}'.format(total_time_str))
    training_box.clean_modules()
def main(args):
    """Entry point: load teacher and student, optionally distill, then evaluate both.

    Reads the YAML config at ``args.config`` and builds the teacher and student
    models from ``config['models']``. Unless ``args.test_only`` is set it runs
    ``distill`` and reloads the student checkpoint. The teacher is evaluated on the
    test data loader unless ``args.student_only`` is set; the student always is.

    Args:
        args: Parsed command-line namespace. The attributes read here are ``log``,
            ``world_size``, ``dist_url``, ``config``, ``device``, ``test_only``,
            ``num_classes`` and ``student_only``.
    """
    log_path = args.log
    if is_main_process() and log_path is not None:
        setup_log_file(os.path.expanduser(log_path))

    distributed, device_ids = init_distributed_mode(args.world_size,
                                                    args.dist_url)
    logger.info(args)
    cudnn.benchmark = True
    config = yaml_util.load_yaml_file(os.path.expanduser(args.config))
    device = torch.device(args.device)
    dataset_dict = util.get_all_dataset(config['datasets'])

    models_config = config['models']
    teacher_model_config = models_config['teacher_model']
    teacher_model = get_model(teacher_model_config, device)
    student_model_config = models_config['student_model']
    student_model = get_model(student_model_config, device)

    if not args.test_only:
        distill(teacher_model, student_model, dataset_dict, device, device_ids,
                distributed, config, args)
        # Reload the checkpoint at the student's 'ckpt' path into the unwrapped model.
        if module_util.check_if_wrapped(student_model):
            bare_student = student_model.module
        else:
            bare_student = student_model
        load_ckpt(student_model_config['ckpt'], model=bare_student, strict=True)

    test_config = config['test']
    loader_config = test_config['test_data_loader']
    test_dataset = dataset_dict[loader_config['dataset_id']]
    test_data_loader = util.build_data_loader(test_dataset, loader_config,
                                              distributed)
    num_classes = args.num_classes

    def _run_test(model, title):
        # Shared evaluation call; only the model and the report title differ.
        evaluate(model, test_data_loader, device, device_ids, distributed,
                 num_classes=num_classes, title=title)

    if not args.student_only:
        _run_test(teacher_model,
                  '[Teacher: {}]'.format(teacher_model_config['name']))
    _run_test(student_model,
              '[Student: {}]'.format(student_model_config['name']))
def distill(teacher_model, student_model, dataset_dict, device, device_ids,
            distributed, config, args):
    """Distill the teacher into the student, checkpointing on validation improvement.

    Builds a distillation box from ``config['train']``. If a checkpoint exists at
    the student's configured 'ckpt' path, the optimizer and LR scheduler states are
    restored and its stored score becomes the initial best value. After every epoch
    the student is evaluated on the validation loader, and the main process saves a
    checkpoint when the returned score exceeds the best seen so far.

    Args:
        teacher_model: Teacher model.
        student_model: Model being trained (possibly wrapped; see ``check_if_wrapped``).
        dataset_dict: Datasets handed to the distillation box factory.
        device: Device used for training and validation.
        device_ids: Device IDs forwarded to the box factory and ``evaluate``.
        distributed: Whether distributed training is active.
        config: Full configuration; reads ``config['train']`` and
            ``config['models']['student_model']['ckpt']``.
        args: Parsed arguments; reads ``world_size``, ``adjust_lr`` and
            ``start_epoch``.
    """
    logger.info('Start distillation')
    train_config = config['train']
    lr_factor = 1
    if distributed and args.adjust_lr:
        lr_factor = args.world_size

    distillation_box = get_distillation_box(teacher_model, student_model,
                                            dataset_dict, train_config, device,
                                            device_ids, distributed, lr_factor)
    ckpt_file_path = config['models']['student_model']['ckpt']
    best_val_top1_accuracy = 0.0
    optimizer = distillation_box.optimizer
    lr_scheduler = distillation_box.lr_scheduler
    if file_util.check_if_exists(ckpt_file_path):
        restored = load_ckpt(ckpt_file_path,
                             optimizer=optimizer,
                             lr_scheduler=lr_scheduler)
        best_val_top1_accuracy, _, _ = restored

    log_freq = train_config['log_freq']
    if module_util.check_if_wrapped(student_model):
        student_model_without_ddp = student_model.module
    else:
        student_model_without_ddp = student_model

    start_time = time.time()
    for epoch in range(args.start_epoch, distillation_box.num_epochs):
        distillation_box.pre_process(epoch=epoch)
        distill_one_epoch(distillation_box, device, epoch, log_freq)
        val_top1_accuracy = evaluate(student_model,
                                     distillation_box.val_data_loader,
                                     device, device_ids, distributed,
                                     log_freq=log_freq, header='Validation:')
        improved = val_top1_accuracy > best_val_top1_accuracy
        if improved and is_main_process():
            logger.info('Updating ckpt (Best top1 accuracy: '
                        '{:.4f} -> {:.4f})'.format(best_val_top1_accuracy,
                                                   val_top1_accuracy))
            best_val_top1_accuracy = val_top1_accuracy
            save_ckpt(student_model_without_ddp, optimizer, lr_scheduler,
                      best_val_top1_accuracy, config, args, ckpt_file_path)
        distillation_box.post_process()

    if distributed:
        dist.barrier()

    elapsed_seconds = int(time.time() - start_time)
    total_time_str = str(datetime.timedelta(seconds=elapsed_seconds))
    logger.info('Training time {}'.format(total_time_str))
    distillation_box.clean_modules()
def train(teacher_model, student_model, dataset_dict, ckpt_file_path, device,
          device_ids, distributed, config, args):
    """Train the student and checkpoint it whenever the validation mIoU improves.

    A plain training box is used when ``teacher_model`` is None, otherwise a
    distillation box. If a checkpoint already exists at ``ckpt_file_path``, the
    optimizer and LR scheduler states are restored from it and its stored score
    becomes the initial best value. After every epoch the mean of the per-class
    IoU values returned by the validation evaluator is compared with the best so far.

    Args:
        teacher_model: Teacher model, or None to train without distillation.
        student_model: Model being trained (possibly wrapped; see ``check_if_wrapped``).
        dataset_dict: Datasets handed to the training/distillation box factory.
        ckpt_file_path: Path the best checkpoint is read from and written to.
        device: Device used for training and validation.
        device_ids: Device IDs forwarded to the box factory and ``evaluate``.
        distributed: Whether distributed training is active.
        config: Full configuration; ``config['train']`` drives the training box.
        args: Parsed arguments; reads ``world_size``, ``adjust_lr``,
            ``num_classes`` and ``start_epoch``.
    """
    logger.info('Start training')
    train_config = config['train']
    lr_factor = 1
    if distributed and args.adjust_lr:
        lr_factor = args.world_size

    if teacher_model is None:
        training_box = get_training_box(student_model, dataset_dict, train_config,
                                        device, device_ids, distributed, lr_factor)
    else:
        training_box = get_distillation_box(teacher_model, student_model, dataset_dict,
                                            train_config, device, device_ids,
                                            distributed, lr_factor)

    best_val_miou = 0.0
    optimizer = training_box.optimizer
    lr_scheduler = training_box.lr_scheduler
    if file_util.check_if_exists(ckpt_file_path):
        restored = load_ckpt(ckpt_file_path,
                             optimizer=optimizer,
                             lr_scheduler=lr_scheduler)
        best_val_miou, _, _ = restored

    log_freq = train_config['log_freq']
    if module_util.check_if_wrapped(student_model):
        student_model_without_ddp = student_model.module
    else:
        student_model_without_ddp = student_model

    start_time = time.time()
    for epoch in range(args.start_epoch, training_box.num_epochs):
        training_box.pre_process(epoch=epoch)
        train_one_epoch(training_box, device, epoch, log_freq)
        val_seg_evaluator = evaluate(student_model, training_box.val_data_loader,
                                     device, device_ids, distributed,
                                     num_classes=args.num_classes,
                                     log_freq=log_freq, header='Validation:')

        # compute() yields three values; only the last (IoU) is used here.
        _, _, val_iou = val_seg_evaluator.compute()
        val_miou = val_iou.mean().item()
        improved = val_miou > best_val_miou
        if improved and is_main_process():
            logger.info('Updating ckpt (Best mIoU: {:.4f} -> {:.4f})'.format(
                best_val_miou, val_miou))
            best_val_miou = val_miou
            save_ckpt(student_model_without_ddp, optimizer, lr_scheduler,
                      best_val_miou, config, args, ckpt_file_path)
        training_box.post_process()

    if distributed:
        dist.barrier()

    elapsed_seconds = int(time.time() - start_time)
    total_time_str = str(datetime.timedelta(seconds=elapsed_seconds))
    logger.info('Training time {}'.format(total_time_str))
    training_box.clean_modules()
Exemple #7
0
def save_module_ckpt(module, ckpt_file_path):
    """Save ``module``'s state dict to ``ckpt_file_path`` via ``save_on_master``.

    The parent directories of the target path are created on the main process
    first. When ``check_if_wrapped`` reports the module as wrapped, the state dict
    of its inner ``.module`` is saved instead.

    Args:
        module: Module (possibly wrapped) whose state dict is saved.
        ckpt_file_path: Destination path of the checkpoint file.
    """
    if is_main_process():
        make_parent_dirs(ckpt_file_path)

    unwrapped = module.module if check_if_wrapped(module) else module
    save_on_master(unwrapped.state_dict(), ckpt_file_path)
Exemple #8
0
def main(args):
    """Entry point: set up Accelerate, optionally train, evaluate, and predict.

    Reads the YAML config at ``args.config``, creates an ``Accelerator``, loads the
    optional teacher and the student tokenizer/model pairs, and builds the datasets
    for ``args.task_name``. Unless ``args.test_only`` is set it runs ``train`` and
    saves the student tokenizer to the checkpoint directory. The teacher (when
    configured and ``args.student_only`` is not set) and a freshly reloaded student
    are then evaluated on the test data loader. Finally, on the main process,
    ``predict_private`` runs if both a 'private' config entry and
    ``args.private_output`` are given.

    Args:
        args: Parsed command-line namespace. The attributes read here are ``log``,
            ``world_size``, ``seed``, ``config``, ``task_name``, ``test_only``,
            ``student_only`` and ``private_output``.
    """
    log_path = args.log
    if is_main_process() and log_path is not None:
        setup_log_file(os.path.expanduser(log_path))

    world_size = args.world_size
    logger.info(args)
    cudnn.benchmark = True
    set_seed(args.seed)
    config = yaml_util.load_yaml_file(os.path.expanduser(args.config))

    # Device placement is delegated to the accelerator.
    accelerator = Accelerator()
    distributed = accelerator.state.distributed_type == DistributedType.MULTI_GPU
    device_ids = [accelerator.device.index]
    if distributed:
        setup_for_distributed(is_main_process())

    logger.info(accelerator.state)
    device = accelerator.device

    # Only one process per machine (the local main process) logs verbosely.
    if accelerator.is_local_main_process:
        logger.setLevel(logging.INFO)
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        logger.setLevel(logging.ERROR)
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Tokenizers and models; the teacher is optional.
    task_name = args.task_name
    models_config = config['models']
    teacher_model_config = models_config.get('teacher_model', None)
    teacher_model = None
    if teacher_model_config is not None:
        _, teacher_model = load_tokenizer_and_model(teacher_model_config,
                                                    task_name, True)

    # Fall back to the plain 'model' entry when there is no 'student_model' entry.
    student_key = 'student_model' if 'student_model' in models_config else 'model'
    student_model_config = models_config[student_key]
    student_tokenizer, student_model = load_tokenizer_and_model(
        student_model_config, task_name)
    ckpt_dir_path = student_model_config['ckpt']

    # Datasets are built with the student's tokenizer and model.
    dataset_dict, label_names_dict, is_regression = get_all_datasets(
        config['datasets'], task_name, student_tokenizer, student_model)

    # Update config with dataset size len(data_loader)
    customize_lr_config(config, dataset_dict, world_size)

    # Register the padding collate function (multiple of 8 when fp16 is in use).
    pad_multiple = 8 if accelerator.use_fp16 else None
    register_collate_func(
        DataCollatorWithPadding(student_tokenizer,
                                pad_to_multiple_of=pad_multiple))

    metric = get_metrics(task_name)

    if not args.test_only:
        train(teacher_model, student_model, dataset_dict, is_regression,
              ckpt_dir_path, metric, device, device_ids, distributed, config,
              args, accelerator)
        student_tokenizer.save_pretrained(ckpt_dir_path)

    test_config = config['test']
    loader_config = test_config['test_data_loader']
    test_dataset = dataset_dict[loader_config['dataset_id']]
    test_data_loader = util.build_data_loader(test_dataset, loader_config,
                                              distributed)
    test_data_loader = accelerator.prepare(test_data_loader)

    if not args.student_only and teacher_model is not None:
        teacher_model = teacher_model.to(accelerator.device)
        teacher_title = '[Teacher: {}]'.format(teacher_model_config['name'])
        evaluate(teacher_model, test_data_loader, metric, is_regression,
                 accelerator, title=teacher_title)

    # Reload best checkpoint based on validation result
    student_tokenizer, student_model = load_tokenizer_and_model(
        student_model_config, task_name, True)
    student_model = accelerator.prepare(student_model)
    student_title = '[Student: {}]'.format(student_model_config['name'])
    evaluate(student_model, test_data_loader, metric, is_regression,
             accelerator, title=student_title)

    # Private-set predictions need both the config entry and an output directory.
    private_configs = config.get('private', None)
    private_output_dir_path = args.private_output
    private_requested = (private_configs is not None
                         and private_output_dir_path is not None)
    if private_requested and is_main_process():
        predict_private(student_model, dataset_dict, label_names_dict,
                        is_regression, accelerator, private_configs,
                        private_output_dir_path)
Exemple #9
0
def log_info(*args, **kwargs):
    """Forward to ``logger.info`` on the main process, or on any process if forced.

    Args:
        *args: Positional arguments passed through to ``logger.info``.
        **kwargs: Keyword arguments passed through to ``logger.info``. The extra
            ``force`` keyword (default False) is removed before forwarding; when
            truthy, the message is logged even off the main process.
    """
    forced = kwargs.pop('force', False)
    if not (is_main_process() or forced):
        return
    logger.info(*args, **kwargs)