from __future__ import division, print_function, unicode_literals

import itertools

import numpy as np
import pytest

from brainstorm.handlers import NumpyHandler
from brainstorm.optional import has_pycuda

# Optional (non-default) handlers and their id strings, kept as two parallel
# lists with matching order and length.
# The PyCUDA-backed handler is imported and instantiated only when
# `has_pycuda` is true; otherwise both lists stay empty.
# NOTE(review): `handler_ids` is presumably consumed as pytest parametrize
# ids further down the file -- confirm against the fixtures that use it.
if has_pycuda:
    from brainstorm.handlers import PyCudaHandler
    non_default_handlers = [PyCudaHandler()]
    handler_ids = ["PyCudaHandler"]
else:
    non_default_handlers = []
    handler_ids = []

# NOTE: RNG seeding is currently disabled; the original call is kept below
# for reference.
# np.random.seed(1234)
# Dtype passed to the reference handler constructed on the next line.
ref_dtype = np.float32
# Reference NumpyHandler instance: `operation_check` runs each operation on
# `ref` as well as on the handler under test.
ref = NumpyHandler(ref_dtype)
# 2-D shapes: a single element, single-column and single-row shapes, a square
# shape, and both orientations of a non-square shape.
some_2d_shapes = ((1, 1), (4, 1), (1, 4), (5, 5), (3, 4), (4, 3))
# Higher-rank shapes (3-D and 4-D), several of which contain size-1 axes.
some_nd_shapes = ((1, 1, 4), (1, 1, 3, 3), (3, 4, 2, 1))

# Widen numpy's printed line width (numpy's default is 75 characters) so
# printed arrays wrap less often. This changes a process-wide numpy setting
# at import time.
np.set_printoptions(linewidth=150)


def operation_check(handler, op_name, ref_args, ignored_args=(), atol=1e-8):
    args = get_args_from_ref_args(handler, ref_args)
    getattr(ref, op_name)(*ref_args)
    getattr(handler, op_name)(*args)