Code Example #1
def main():
    args = parser.parse_args()
    model_cfgs = []
    model_names = []
    if os.path.isdir(args.checkpoint):
        # validate all checkpoints in a path with same model
        checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
        checkpoints += glob.glob(args.checkpoint + '/*.pth')
        model_names = list_models(args.model)
        model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
    else:
        if args.model == 'all':
            # validate all models in a list of names with pretrained checkpoints
            args.pretrained = True
            model_names = list_models(pretrained=True, exclude_filters=['*in21k'])
            model_cfgs = [(n, '') for n in model_names]
        elif not is_model(args.model):
            # model name doesn't exist, try as wildcard filter
            model_names = list_models(args.model)
            model_cfgs = [(n, '') for n in model_names]
    
    setup_default_logging(log_path=os.path.join(os.path.dirname(args.checkpoint), "eval_print_log.txt"))

    if len(model_cfgs):
        #results_file = args.results_file or './results-all.csv'
        results_file = args.results_file or os.path.join(os.path.dirname(args.checkpoint), "eval_results.csv.txt")
        _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
        results = []
        try:
            start_batch_size = args.batch_size
            for m, c in model_cfgs:
                batch_size = start_batch_size
                args.model = m
                args.checkpoint = c
                result = OrderedDict(model=args.model)
                r = {}
                while not r and batch_size >= args.num_gpu:
                    torch.cuda.empty_cache()
                    try:
                        args.batch_size = batch_size
                        print('Validating with batch size: %d' % args.batch_size)
                        r = validate(args)
                    except RuntimeError as e:
                        if batch_size <= args.num_gpu:
                            print("Validation failed with no ability to reduce batch size. Exiting.")
                            raise e
                        batch_size = max(batch_size // 2, args.num_gpu)
                        print("Validation failed, reducing batch size by 50%")
                result.update(r)
                if args.checkpoint:
                    result['checkpoint'] = args.checkpoint
                results.append(result)
        except KeyboardInterrupt:
            pass
        results = sorted(results, key=lambda x: x['top1'], reverse=True)
        if len(results):
            write_results(results_file, results)
    else:
        validate(args)
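
All of the bulk-validation variants call a write_results helper that is not shown in these excerpts. Based on the inline csv.DictWriter loop in Code Example #5 below, a minimal sketch (the exact signature is an assumption) might be:

import csv

def write_results(results_file, results):
    # Assumed helper: dump a list of per-model result dicts to CSV,
    # taking the field names from the first result, mirroring the
    # inline DictWriter usage in Code Example #5. Callers guard with
    # `if len(results)`, so results is non-empty here.
    with open(results_file, mode='w') as cf:
        dw = csv.DictWriter(cf, fieldnames=results[0].keys())
        dw.writeheader()
        for result in results:
            dw.writerow(result)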
Code Example #2
def main():
    setup_default_logging()
    args = parser.parse_args()
    model_cfgs = []
    model_names = []
    if os.path.isdir(args.checkpoint):
        # validate all checkpoints in a path with same model
        checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
        checkpoints += glob.glob(args.checkpoint + '/*.pth')
        model_names = list_models(args.model)
        model_cfgs = [(args.model, c)
                      for c in sorted(checkpoints, key=natural_key)]
    else:
        if args.model == 'all':
            # validate all models in a list of names with pretrained checkpoints
            args.pretrained = True
            model_names = list_models(
                pretrained=True,
                exclude_filters=['*_in21k', '*_in22k', '*_dino'])
            model_cfgs = [(n, '') for n in model_names]
        elif not is_model(args.model):
            # model name doesn't exist, try as wildcard filter
            model_names = list_models(args.model)
            model_cfgs = [(n, '') for n in model_names]

        if not model_cfgs and os.path.isfile(args.model):
            with open(args.model) as f:
                model_names = [line.rstrip() for line in f]
            model_cfgs = [(n, None) for n in model_names if n]

    if len(model_cfgs):
        results_file = args.results_file or './results-all.csv'
        _logger.info(
            'Running bulk validation on these pretrained models: {}'.format(
                ', '.join(model_names)))
        results = []
        try:
            initial_batch_size = args.batch_size
            for m, c in model_cfgs:
                args.model = m
                args.checkpoint = c
                r = _try_run(args, initial_batch_size)
                if 'error' in r:
                    continue
                if args.checkpoint:
                    r['checkpoint'] = args.checkpoint
                results.append(r)
        except KeyboardInterrupt:
            pass
        results = sorted(results, key=lambda x: x['top1'], reverse=True)
        if len(results):
            write_results(results_file, results)
    else:
        results = validate(args)
    # output results in JSON to stdout w/ delimiter for runner script
    print(f'--result\n{json.dumps(results, indent=4)}')
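
Code Example #2 factors the batch-size retry logic into a _try_run helper that the excerpt does not include. Reconstructed from the inline while-loop in Code Example #1 (the name, signature, and the 'error' key checked above are assumptions), a plausible sketch is:

def _try_run(args, initial_batch_size):
    # Assumed helper: run validate(), halving the batch size on
    # RuntimeError (typically CUDA OOM) until it succeeds or the batch
    # size reaches zero. OrderedDict, torch, and validate come from the
    # surrounding script.
    batch_size = initial_batch_size
    results = OrderedDict(model=args.model)
    while batch_size:
        torch.cuda.empty_cache()
        try:
            args.batch_size = batch_size
            results.update(validate(args))
            return results
        except RuntimeError:
            batch_size //= 2
            print('Validation failed, reducing batch size by 50%% to %d' % batch_size)
    results['error'] = 'validation failed at all batch sizes'
    return results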
Code Example #3
def main():
    setup_default_logging()
    args = parser.parse_args()
    model_cfgs = []
    model_names = []

    if args.model_list:
        args.model = ''
        with open(args.model_list) as f:
            model_names = [line.rstrip() for line in f]
        model_cfgs = [(n, None) for n in model_names]
    elif args.model == 'all':
        # validate all models in a list of names with pretrained checkpoints
        args.pretrained = True
        model_names = list_models(pretrained=True, exclude_filters=['*in21k'])
        model_cfgs = [(n, None) for n in model_names]
    elif not is_model(args.model):
        # model name doesn't exist, try as wildcard filter
        model_names = list_models(args.model)
        model_cfgs = [(n, None) for n in model_names]

    if len(model_cfgs):
        results_file = args.results_file or './benchmark.csv'
        _logger.info(
            'Running bulk validation on these pretrained models: {}'.format(
                ', '.join(model_names)))
        results = []
        try:
            for m, _ in model_cfgs:
                if not m:
                    continue
                args.model = m
                r = benchmark(args)
                if r:
                    results.append(r)
                time.sleep(10)
        except KeyboardInterrupt:
            pass
        sort_key = 'infer_samples_per_sec'
        if 'train' in args.bench:
            sort_key = 'train_samples_per_sec'
        elif 'profile' in args.bench:
            sort_key = 'infer_gmacs'
        results = filter(lambda x: sort_key in x, results)
        results = sorted(results, key=lambda x: x[sort_key], reverse=True)
        if len(results):
            write_results(results_file, results)
    else:
        results = benchmark(args)

    # output results in JSON to stdout w/ delimiter for runner script
    print(f'--result\n{json.dumps(results, indent=4)}')
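
Examples #2 and #3 both print their results as JSON after a '--result' delimiter so a wrapper script can recover them from stdout. A minimal consumer, assuming the script is named benchmark.py, exposes a --model flag, and emits the delimiter exactly once, could be:

import json
import subprocess

# Hypothetical runner: invoke the script and parse the JSON that
# follows the '--result' delimiter on stdout.
proc = subprocess.run(
    ['python', 'benchmark.py', '--model', 'resnet50'],
    capture_output=True, text=True, check=True)
results = json.loads(proc.stdout.split('--result')[-1])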
Code Example #4
def main():
    setup_default_logging()
    args = parser.parse_args()
    model_cfgs = []
    model_names = []
    if os.path.isdir(args.checkpoint):
        # validate all checkpoints in a path with same model
        checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
        checkpoints += glob.glob(args.checkpoint + '/*.pth')
        model_names = list_models(args.model)
        model_cfgs = [(args.model, c)
                      for c in sorted(checkpoints, key=natural_key)]
    else:
        if args.model == 'all':
            # validate all models in a list of names with pretrained checkpoints
            args.pretrained = True
            model_names = list_models(pretrained=True)
            model_cfgs = [(n, '') for n in model_names]
        elif not is_model(args.model):
            # model name doesn't exist, try as wildcard filter
            model_names = list_models(args.model)
            model_cfgs = [(n, '') for n in model_names]

    if len(model_cfgs):
        results_file = args.results_file or './results-all.csv'
        logging.info(
            'Running bulk validation on these pretrained models: {}'.format(
                ', '.join(model_names)))
        results = []
        try:
            for m, c in model_cfgs:
                args.model = m
                args.checkpoint = c
                result = OrderedDict(model=args.model)
                r = validate(args)
                result.update(r)
                if args.checkpoint:
                    result['checkpoint'] = args.checkpoint
                results.append(result)
        except KeyboardInterrupt:
            pass
        results = sorted(results, key=lambda x: x['top1'], reverse=True)
        if len(results):
            write_results(results_file, results)
    else:
        validate(args)
Code Example #5
def main():
    setup_default_logging()
    args = parser.parse_args()
    model_cfgs = []
    model_names = []
    if os.path.isdir(args.checkpoint):
        # validate all checkpoints in a path with same model
        checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
        checkpoints += glob.glob(args.checkpoint + '/*.pth')
        model_names = list_models(args.model)
        model_cfgs = [(args.model, c)
                      for c in sorted(checkpoints, key=natural_key)]
    else:
        if args.model == 'all':
            # validate all models in a list of names with pretrained checkpoints
            args.pretrained = True
            model_names = list_models(pretrained=True)
            model_cfgs = [(n, '') for n in model_names]
        elif not is_model(args.model):
            # model name doesn't exist, try as wildcard filter
            model_names = list_models(args.model)
            model_cfgs = [(n, '') for n in model_names]

    if len(model_cfgs):
        logging.info(
            'Running bulk validation on these pretrained models: {}'.format(
                ', '.join(model_names)))
        header_written = False
        with open('./results-all.csv', mode='w') as cf:
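            # Write and flush each row as it completes, so partial results
            # survive an interrupted or crashed run.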
            for m, c in model_cfgs:
                args.model = m
                args.checkpoint = c
                result = OrderedDict(model=args.model)
                r = validate(args)
                result.update(r)
                if args.checkpoint:
                    result['checkpoint'] = args.checkpoint
                dw = csv.DictWriter(cf, fieldnames=result.keys())
                if not header_written:
                    dw.writeheader()
                    header_written = True
                dw.writerow(result)
                cf.flush()
    else:
        validate(args)
Code Example #6
def main():
    setup_default_logging()
    args = parser.parse_args()
    model_cfgs = []
    model_names = []
    if os.path.isdir(args.checkpoint):
        # validate all checkpoints in a path with same model
        checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
        checkpoints += glob.glob(args.checkpoint + '/*.pth')
        model_names = list_models(args.model)
        model_cfgs = [(args.model, c)
                      for c in sorted(checkpoints, key=natural_key)]
    else:
        if args.model == 'all':
            # validate all models in a list of names with pretrained checkpoints
            args.pretrained = True
            model_names = list_models(pretrained=True)
            model_cfgs = [(n, '') for n in model_names]
        elif not is_model(args.model):
            # model name doesn't exist, try as wildcard filter
            model_names = list_models(args.model)
            model_cfgs = [(n, '') for n in model_names]
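
    # NOTE: model_cfgs / model_names are built above but go unused in
    # this variant; validate(args) runs once regardless.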
    validate(args)
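
Several of these examples sort checkpoint paths with natural_key, which the excerpts do not define. A common implementation of this kind of natural sort key, comparing digit runs as integers so that 'checkpoint-10' sorts after 'checkpoint-9', is:

import re

def natural_key(string_):
    # Split on digit runs and compare the numeric chunks as integers.
    return [int(s) if s.isdigit() else s
            for s in re.split(r'(\d+)', string_.lower())]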
Code Example #7
def main():
    setup_default_logging()

    args, args_text = _parse_args()
    # args = parser.parse_args()

    torch.manual_seed(args.seed)  # ! mostly for aug on test
    random.seed(args.seed)
    np.random.seed(args.seed)

    model_cfgs = []
    model_names = []
    if os.path.isdir(
            args.checkpoint):  # ! can pass in a directory of checkpoints
        # validate all checkpoints in a path with same model
        checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
        checkpoints += glob.glob(args.checkpoint + '/*.pth')
        model_names = list_models(args.model)
        model_cfgs = [(args.model, c)
                      for c in sorted(checkpoints, key=natural_key)]
    else:
        if args.model == 'all':
            # validate all models in a list of names with pretrained checkpoints
            args.pretrained = True
            model_names = list_models(pretrained=True)
            model_cfgs = [(n, '') for n in model_names]
        elif not is_model(args.model):
            # model name doesn't exist, try as wildcard filter
            model_names = list_models(args.model)
            model_cfgs = [(n, '') for n in model_names]

    results_file = args.results_file or './results-all.csv'
    if len(model_cfgs):
        _logger.info(
            'Running bulk validation on these pretrained models: {}'.format(
                ', '.join(model_names)))
        results = []
        try:
            start_batch_size = args.batch_size
            for m, c in model_cfgs:
                batch_size = start_batch_size
                args.model = m
                args.checkpoint = c
                result = OrderedDict(model=args.model)
                r = {}
                while not r and batch_size >= args.num_gpu:
                    torch.cuda.empty_cache()
                    try:
                        args.batch_size = batch_size
                        print('Validating with batch size: %d' %
                              args.batch_size)
                        r, prediction, true_label = validate(args)
                    except RuntimeError as e:
                        if batch_size <= args.num_gpu:
                            print(
                                "Validation failed with no ability to reduce batch size. Exiting."
                            )
                            raise e
                        batch_size = max(batch_size // 2, args.num_gpu)
                        print("Validation failed, reducing batch size by 50%")
                result.update(r)
                if args.checkpoint:
                    result['checkpoint'] = args.checkpoint
                results.append(result)
        except KeyboardInterrupt:
            pass
        results = sorted(results, key=lambda x: x['top1'], reverse=True)
        if len(results):
            write_results(results_file, results)
    else:
        # ! eval one single model
        results_file = re.sub(r'\.csv', '', results_file) + '-standard.csv'
        _, prediction, true_label = validate(args)

        from HAM10000 import helper
        prediction = helper.softmax(
            prediction)  # softmax convert to range 0-1, sum to 1

        if args.has_eval_label:
            from sklearn.metrics import accuracy_score, balanced_accuracy_score
            true_label_onehot = np.identity(
                args.num_classes)[true_label]  # array into 1 hot
            # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html
            _logger.info(
                ' * sklearn multilabel probabilities accuracy_score {:.3f} '.
                format(accuracy_score(true_label_onehot,
                                      np.round(prediction))))
            _logger.info(
                ' * sklearn max probabilities balanced_accuracy_score {:.3f} '.
                format(
                    balanced_accuracy_score(
                        true_label,
                        helper.convert_max_1_other_0(prediction).argmax(
                            axis=1))))

        # output csv, need to reorder columns
        helper.save_output_csv(
            prediction, obs_name=[],
            output_name=results_file)  # no name for each observation, use []
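
Code Example #7 depends on a project-specific HAM10000.helper module that is not reproduced here. Going only by the inline comments, minimal numpy sketches of the two transforms it uses (the names match the calls above; the implementations are assumptions) might be:

import numpy as np

def softmax(x, axis=1):
    # Map raw logits to probabilities in [0, 1] that sum to 1 per row.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def convert_max_1_other_0(prediction):
    # Set the per-row maximum to 1 and everything else to 0.
    out = np.zeros_like(prediction)
    out[np.arange(prediction.shape[0]), prediction.argmax(axis=1)] = 1
    return out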