def test_gradients():
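    """Gradients can be switched on and off per parameter, and they flow both through
    the stored parameters and through tensors passed as runtime arguments."""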
    tm = TransitModule(time=time_array,
                       **params_dicts['scalar'],
                       secondary=True)
    for param in list(tm._parameters.keys()) + ['rp_over_rs']:
        if param == 'time':
            continue
        tm.zero_grad()
        tm.activate_grad(param)
        assert getattr(tm, param).requires_grad
        flux = tm()
        assert flux.requires_grad
        flux.sum().backward()
        g = getattr(tm, param).grad
        assert not torch.isnan(g) and g.item() != 0.
        tm.deactivate_grad(param)
        assert not getattr(tm, param).requires_grad

        # external argument
        if param in tm._parnames:
            value = torch.tensor(pars[param], requires_grad=True)
        else:
            value = torch.tensor(pars[PLC_ALIASES[param]], requires_grad=True)
        assert value.requires_grad
        flux = tm(**{param: value})
        assert flux.requires_grad
        assert not getattr(tm, param).requires_grad
        flux.sum().backward()
        # the gradient now flows to the externally passed tensor, not the stored parameter
        g = value.grad
        assert not torch.isnan(g) and g.item() != 0.


def test_ldc_methods():
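    """Every supported limb-darkening law accepts a coefficient vector of the expected
    length and the forward pass still runs."""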
    pars_ldc = {
        'linear': np.random.rand(1)[None, :],
        'sqrt': np.random.rand(2)[None, :],
        'quad': np.random.rand(2)[None, :],
        'claret': np.random.rand(4)[None, :]
    }
    tm = TransitModule(**params_dicts['scalar'], time=time_tensor)
    for method in ['linear', 'sqrt', 'quad', 'claret']:
        tm.reset_param('ldc')
        tm.set_method(method)
        tm.set_param('ldc', pars_ldc[method])
        tm.get_drop_s()
        flux = tm()
        assert not torch.isnan(flux).any()


def test_pytorch_inherited_attr():
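    """Methods inherited from torch.nn.Module (dtype/device casting, train/eval,
    parameter introspection, zero_grad) operate on the module's parameters and time tensor."""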
    tm = TransitModule()
    tm.set_param("e", 0.)
    tm.set_time(range(10))

    tm.float()
    assert tm.e.dtype == torch.float32
    assert tm.time.dtype == torch.float32
    tm.double()
    assert tm.e.dtype == torch.float64
    assert tm.time.dtype == torch.float64

    assert tm.to('cpu').e.device.type == 'cpu'
    if torch.cuda.is_available():
        assert tm.to('cuda').e.device.type == 'cuda'

    tm.train()
    tm.eval()
    # consume the generators so the inherited introspection methods are actually exercised
    list(tm.named_modules())
    list(tm.named_parameters())
    assert 'e' in tm._parameters
    tm.zero_grad()


def test_transit_type():
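    """Default and explicit primary/secondary flags, the error raised when both are
    disabled, and the resulting epoch_type."""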
    assert TransitModule().primary
    assert not TransitModule().secondary
    assert not TransitModule(primary=False, secondary=True).primary
    assert TransitModule(secondary=True).primary
    assert TransitModule(secondary=True).secondary

    # at least one of primary/secondary must be enabled
    with pytest.raises(RuntimeError):
        TransitModule(primary=False, secondary=False)

    # epoch_type
    for tm in [
            TransitModule(),
            TransitModule(primary=True),
            TransitModule(secondary=True),
            TransitModule(primary=False, secondary=True, epoch_type='primary')
    ]:
        assert tm.epoch_type == 'primary'

    for tm in [
            TransitModule(primary=False, secondary=True),
            TransitModule(epoch_type='secondary')
    ]:
        assert tm.epoch_type == 'secondary'


def test_cuda():
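    """Forward pass on GPU and moving the module between devices; skipped when no GPU
    is available."""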
    if not torch.cuda.is_available():
        pytest.skip('no available gpu')
    tm = TransitModule(time_tensor, secondary=True,
                       **params_dicts['scalar']).cuda()
    tm()

    tm.cpu()
    tm.reset_time()
    try:
        tm.set_time(time_tensor)
        tm()
    except RuntimeError:
        # a device mismatch here is the expected behaviour, since the time tensor
        # is not supposed to have been converted automatically
        print("RuntimeError caught: expected behaviour, the time tensor is not "
              "supposed to have been converted")
        tm.set_time(time_tensor)
        tm.cuda()
        tm()


def test_cache():
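    """Position caching: it speeds up get_position, is disabled (with a warning) when a
    gradient is activated, is not altered by runtime arguments, and is refreshed when a
    position parameter is set."""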
    tm = TransitModule(time=time_array,
                       **params_dicts['scalar'],
                       secondary=True)
    tm_cache = TransitModule(time=time_array,
                             **params_dicts['scalar'],
                             secondary=True,
                             cache_pos=True)

    time = timeit.timeit(lambda: tm.get_position(), number=20)
    time_cache = timeit.timeit(lambda: tm_cache.get_position(), number=20)
    assert time_cache < time / 5
    # check that activating a gradient deactivates the cache
    with pytest.warns(UserWarning):
        tm_cache.activate_grad('P')
    assert not tm_cache.cache_pos

    # check that runtime computation won't affect the cached vector
    tm_cache = TransitModule(time=time_array,
                             **params_dicts['scalar'],
                             secondary=True,
                             cache_pos=True)
    flux = tm_cache()
    tm_cache(i=93)
    assert tm_cache.cache_pos
    assert (tm_cache() == flux).all()

    # check that setting a position parameter will update the cached vector
    tm_cache = TransitModule(time=time_array,
                             **params_dicts['scalar'],
                             secondary=True,
                             cache_pos=True)
    flux = tm_cache()
    tm_cache.set_param('i', 91.)
    assert tm_cache.cache_pos
    assert not (flux == tm_cache()).all()


def test_time_tensor():
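    """Stored and runtime time tensors of various shapes, including batched times,
    batched parameters and mismatched batch sizes."""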
    tm = TransitModule(**params_dicts['scalar'])
    tm.set_time(torch.linspace(0, 10, 100))
    tm()

    tm.set_time(torch.linspace(0, 10, 100)[None, :].repeat(5, 1))
    tm()

    tm = TransitModule(
        **map_dict(pars, lambda x: torch.tensor(x)[None, None].repeat(5, 1)))
    tm.set_time(torch.linspace(0, 10, 100)[None, :].repeat(5, 1))
    tm()

    tm = TransitModule(
        **map_dict(pars, lambda x: torch.tensor(x)[None, None].repeat(5, 1)))
    # batch size of the time tensor (6) does not match the parameters' batch size (5)
    with pytest.raises(RuntimeError):
        tm.set_time(torch.linspace(0, 10, 100)[None, :].repeat(6, 1))

    # Runtime mode
    tm = TransitModule(**params_dicts['scalar'])
    flux = tm(time=torch.linspace(0, 10, 100))
    assert flux.shape == (1, 100)

    tm = TransitModule(**params_dicts['scalar'])
    tm.set_time(torch.linspace(0, 10, 100))
    # a time tensor passed at call time takes precedence over the stored one
    flux = tm(time=torch.linspace(0, 10, 150))
    assert flux.shape == (1, 150)


def test_transit_params():
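    """For each parameter dictionary: parameter dtypes, reset/set round trip, finite
    outputs, and agreement between stored and runtime parameters; unknown parameter
    names are rejected."""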
    for k, d in params_dicts.items():
        tm = TransitModule(**d)
        tm.set_time(time_array)

        for p in tm._parnames:
            assert p == 'method' or getattr(tm, p).data.dtype == torch.float64
        assert tm.time.dtype == torch.float64

        tm.reset_params()
        attr = np.random.choice(list(tm._parnames))
        assert getattr(tm, attr) is None
        tm.set_params(**d)
        tm.set_time(time_array)

        for x in [tm.proj_dist, tm.drop_p, tm.forward(), tm()]:
            assert isinstance(x, torch.Tensor)
            assert not torch.isnan(x).any()

        flux_1 = tm()
        # External arguments
        flux_2 = tm(**d)
        assert torch.isclose(flux_1, flux_2).all()

    # wrong argument names should be rejected
    with pytest.raises(RuntimeError):
        tm.set_param('wrong_argument', 0)

    with pytest.raises(RuntimeError):
        tm(wrong_argument=0)