コード例 #1
0
ファイル: eval_cnn.py プロジェクト: hereismari/ajna
def main(args):
    """Evaluate a CNN checkpoint on a directory of pickled test files.

    Collects every ``*.pickle`` file directly under ``args.test_path``, wraps
    them in a ``DataSource``, builds a ``CNN`` on the data source's tensors,
    hands it to a ``Trainer`` together with ``args.model_checkpoint`` and
    prints the result of ``Trainer.run_eval``.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``test_path``, ``eye_shape``, ``data_format``, ``heatmap_scale``
            and ``model_checkpoint``.

    Raises:
        FileNotFoundError: If ``args.test_path`` contains no ``*.pickle`` files.
    """
    # Get dataset.
    # glob's result order depends on the file system; sort it so that repeated
    # evaluations visit the test files in the same order.
    test_files = sorted(glob.glob(os.path.join(args.test_path, '*.pickle')))
    if not test_files:
        # Fail fast with an actionable message instead of handing an empty
        # file list to DataSource.
        raise FileNotFoundError(
            "No '*.pickle' test files found in {!r}".format(args.test_path))

    # NOTE(review): the first positional argument is presumably the training
    # file list, which is not needed for evaluation — confirm in DataSource.
    datasource = DataSource(None,
                            test_files,
                            shape=tuple(args.eye_shape),
                            data_format=args.data_format,
                            heatmap_scale=args.heatmap_scale)

    # Get model.
    # CNN is constructed with a learning schedule even for evaluation.
    # NOTE(review): the learning rate is presumably unused during eval, and the
    # loss/output names presumably must match those defined by CNN — confirm.
    learning_schedule = [{
        'loss_terms_to_optimize': {
            'heatmaps_mse': ['hourglass'],
            'radius_mse': ['radius'],
        },
        'learning_rate': 1e-3,
    }]
    model = CNN(datasource.tensors, datasource.x_shape, learning_schedule)

    # Get evaluator (args.model_checkpoint is presumably the weights to restore).
    evaluator = Trainer(model, model_checkpoint=args.model_checkpoint)

    # Evaluate and report.
    avg_losses = evaluator.run_eval(datasource)
    print('Average Losses', avg_losses)