def main(args):
    """Optionally train the student model, then evaluate teacher and student.

    Steps, in order:
      1. Set up the log file (main process only, when ``args.log`` is set).
      2. Initialise distributed mode, log ``args``, enable
         ``cudnn.benchmark`` and seed the RNGs.
      3. Load the YAML config and build the datasets and the teacher and
         student models. The teacher is optional; the student comes from
         ``'student_model'`` or, failing that, ``'model'``.
      4. Unless ``args.test_only`` is set, run ``train`` and then reload the
         student weights from the student's ``'ckpt'`` path.
      5. Evaluate the teacher (if one was configured and
         ``args.student_only`` is not set), then evaluate the student.

    Args:
        args: Parsed command-line namespace. This function reads ``log``,
            ``world_size``, ``dist_url``, ``seed``, ``config``, ``device``,
            ``log_config``, ``test_only``, ``iou_types`` and
            ``student_only``.
    """
    # NOTE(review): this source defines ``main`` a second time further down.
    # If both live in the same module, the later definition shadows this
    # one. Confirm which is intended.
    log_path = args.log
    if is_main_process() and log_path is not None:
        setup_log_file(os.path.expanduser(log_path))

    distributed, device_ids = init_distributed_mode(args.world_size, args.dist_url)
    logger.info(args)
    cudnn.benchmark = True
    set_seed(args.seed)

    config_path = os.path.expanduser(args.config)
    config = yaml_util.load_yaml_file(config_path)
    device = torch.device(args.device)
    dataset_dict = util.get_all_datasets(config['datasets'])

    models_config = config['models']
    # The teacher is optional: without a 'teacher_model' entry, no teacher is built.
    teacher_config = models_config.get('teacher_model', None)
    if teacher_config is None:
        teacher_model = None
    else:
        teacher_model = load_model(teacher_config, device)

    # Accept either a dedicated 'student_model' entry or a plain 'model' entry.
    if 'student_model' in models_config:
        student_config = models_config['student_model']
    else:
        student_config = models_config['model']
    ckpt_file_path = student_config['ckpt']
    student_model = load_model(student_config, device)
    if args.log_config:
        logger.info(config)

    if not args.test_only:
        train(teacher_model, student_model, dataset_dict, ckpt_file_path,
              device, device_ids, distributed, config, args)
        # Load the checkpoint into the bare model: take ``.module`` when
        # ``module_util.check_if_wrapped`` reports a wrapper.
        if module_util.check_if_wrapped(student_model):
            bare_student = student_model.module
        else:
            bare_student = student_model
        # Reload student weights from its checkpoint path (presumably the
        # file written by ``train``; confirm).
        load_ckpt(student_config['ckpt'], model=bare_student, strict=True)

    test_config = config['test']
    loader_config = test_config['test_data_loader']
    test_dataset = dataset_dict[loader_config['dataset_id']]
    test_data_loader = util.build_data_loader(test_dataset, loader_config, distributed)
    log_freq = test_config.get('log_freq', 1000)
    # Positional arguments shared by both ``evaluate`` calls below.
    shared_eval_args = (test_data_loader, args.iou_types, device, device_ids, distributed)

    if not args.student_only and teacher_model is not None:
        teacher_title = '[Teacher: {}]'.format(teacher_config['name'])
        evaluate(teacher_model, *shared_eval_args, log_freq=log_freq, title=teacher_title)

    student_title = '[Student: {}]'.format(student_config['name'])
    evaluate(student_model, *shared_eval_args, log_freq=log_freq, title=student_title)
def main(args):
    """Optionally distill the student from the teacher, then evaluate both.

    Steps, in order:
      1. Set up the log file (main process only, when ``args.log`` is set).
      2. Initialise distributed mode, log ``args`` and enable
         ``cudnn.benchmark``.
      3. Load the YAML config and build the datasets and both models from
         ``config['models']['teacher_model']`` and ``['student_model']``.
      4. Unless ``args.test_only`` is set, run ``distill`` and then reload
         the student weights from the student's ``'ckpt'`` path.
      5. Evaluate the teacher (unless ``args.student_only`` is set), then
         evaluate the student. Both runs are passed ``args.num_classes``.

    Args:
        args: Parsed command-line namespace. This function reads ``log``,
            ``world_size``, ``dist_url``, ``config``, ``device``,
            ``test_only``, ``num_classes`` and ``student_only``.
    """
    # NOTE(review): this source defines ``main`` twice. If both live in the
    # same module, this later definition shadows the earlier one. Confirm
    # which is intended.
    log_path = args.log
    if is_main_process() and log_path is not None:
        setup_log_file(os.path.expanduser(log_path))

    distributed, device_ids = init_distributed_mode(args.world_size, args.dist_url)
    logger.info(args)
    cudnn.benchmark = True

    config_path = os.path.expanduser(args.config)
    config = yaml_util.load_yaml_file(config_path)
    device = torch.device(args.device)
    dataset_dict = util.get_all_dataset(config['datasets'])

    models_config = config['models']
    # Both models are required here; a missing key raises KeyError.
    teacher_config = models_config['teacher_model']
    teacher_model = get_model(teacher_config, device)
    student_config = models_config['student_model']
    student_model = get_model(student_config, device)

    if not args.test_only:
        distill(teacher_model, student_model, dataset_dict, device, device_ids,
                distributed, config, args)
        # Load the checkpoint into the bare model: take ``.module`` when
        # ``module_util.check_if_wrapped`` reports a wrapper.
        if module_util.check_if_wrapped(student_model):
            bare_student = student_model.module
        else:
            bare_student = student_model
        # Reload student weights from its checkpoint path (presumably the
        # file written by ``distill``; confirm).
        load_ckpt(student_config['ckpt'], model=bare_student, strict=True)

    test_config = config['test']
    loader_config = test_config['test_data_loader']
    test_dataset = dataset_dict[loader_config['dataset_id']]
    test_data_loader = util.build_data_loader(test_dataset, loader_config, distributed)
    num_classes = args.num_classes
    # Positional arguments shared by both ``evaluate`` calls below.
    shared_eval_args = (test_data_loader, device, device_ids, distributed)

    if not args.student_only:
        teacher_title = '[Teacher: {}]'.format(teacher_config['name'])
        evaluate(teacher_model, *shared_eval_args, num_classes=num_classes, title=teacher_title)

    student_title = '[Student: {}]'.format(student_config['name'])
    evaluate(student_model, *shared_eval_args, num_classes=num_classes, title=student_title)