def test_ctsplitter_mask_sorted(point_process_dataset):
    """Check the indices ``_mask_sorted`` returns for bounds ``1`` and ``2``.

    Results are compared with ``np.array_equal`` so that both the length
    and the values must match; an element-wise ``zip`` comparison stops at
    the shorter sequence and would let an empty or truncated result pass.
    """
    d = next(iter(point_process_dataset))
    ia_times = d["target"][0, :]

    # Cumulative sum of the first target row is the array that gets masked.
    ts = np.cumsum(ia_times)

    splitter = transform.ContinuousTimeInstanceSplitter(
        past_interval_length=2,
        future_interval_length=1,
        instance_sampler=transform.ContinuousTimeUniformSampler(
            num_instances=10,
            min_past=2,
            min_future=1,
        ),
        freq=to_offset(point_process_dataset.freq),
    )

    # no boundary conditions
    res = splitter._mask_sorted(ts, 1, 2)
    assert np.array_equal(res, [2, 3, 4])

    # lower bound equal, exclusive of upper bound
    res = splitter._mask_sorted(np.array([1, 2, 3, 4, 5, 6]), 1, 2)
    assert np.array_equal(res, [0])
def test_ctsplitter_train_samples_correct_times(point_process_dataset):
    """Every training entry's ``forecast_start`` lies in the fixed window."""
    sampler = transform.ContinuousTimeUniformSampler(20)
    splitter = transform.ContinuousTimeInstanceSplitter(
        1.25, 1.25, train_sampler=sampler
    )

    # Inclusive bounds that each sampled forecast start must respect.
    earliest = pd.Timestamp("2011-01-01 01:15:00")
    latest = pd.Timestamp("2011-01-01 01:45:00")

    within_window = [
        earliest <= entry["forecast_start"] <= latest
        for entry in splitter(point_process_dataset, is_train=True)
    ]
    assert all(within_window)
def test_ctsplitter_train_samples_correct_times(point_process_dataset):
    """Every training entry's ``forecast_start`` lies in the fixed window."""
    sampler = transform.ContinuousTimeUniformSampler(
        num_instances=20,
        min_past=1.25,
        min_future=1.25,
    )
    splitter = transform.ContinuousTimeInstanceSplitter(
        past_interval_length=1.25,
        future_interval_length=1.25,
        instance_sampler=sampler,
    )

    # Inclusive bounds that each sampled forecast start must respect.
    earliest = pd.Timestamp("2011-01-01 01:15:00")
    latest = pd.Timestamp("2011-01-01 01:45:00")

    within_window = [
        earliest <= entry["forecast_start"] <= latest
        for entry in splitter(point_process_dataset, is_train=True)
    ]
    assert all(within_window)
def test_ctsplitter_no_train_last_point(point_process_dataset):
    """With ``is_train=False`` the first output carries past fields only."""
    splitter = transform.ContinuousTimeInstanceSplitter(
        2,
        1,
        train_sampler=transform.ContinuousTimeUniformSampler(num_instances=10),
    )

    first_entry = next(iter(splitter(point_process_dataset, is_train=False)))

    # Future fields must be absent, past fields must be present.
    for absent_key in ("future_target", "future_valid_length"):
        assert absent_key not in first_entry
    for present_key in ("past_target", "past_valid_length"):
        assert present_key in first_entry

    assert first_entry["past_valid_length"] == 6

    expected_past = [0.1, 0.5, 0.3, 0.3, 0.2, 0.1]
    observed_past = first_entry["past_target"][..., 0]
    assert np.allclose(expected_past, observed_past, atol=0.01)
def test_ctsplitter_mask_sorted(point_process_dataset):
    """Check the indices ``_mask_sorted`` returns for bounds ``1`` and ``2``.

    Results are compared with ``np.array_equal`` so that both the length
    and the values must match; an element-wise ``zip`` comparison stops at
    the shorter sequence and would let an empty or truncated result pass.
    """
    d = next(iter(point_process_dataset))
    ia_times = d["target"][0, :]

    # Cumulative sum of the first target row is the array that gets masked.
    ts = np.cumsum(ia_times)

    splitter = transform.ContinuousTimeInstanceSplitter(
        2,
        1,
        train_sampler=transform.ContinuousTimeUniformSampler(num_instances=10),
    )

    # no boundary conditions
    res = splitter._mask_sorted(ts, 1, 2)
    assert np.array_equal(res, [2, 3, 4])

    # lower bound equal, exclusive of upper bound
    res = splitter._mask_sorted(np.array([1, 2, 3, 4, 5, 6]), 1, 2)
    assert np.array_equal(res, [0])
def test_continuous_time_sampler():
    """A plain clone equals the sampler; a clone with an override does not."""
    sampler = transform.ContinuousTimeUniformSampler(num_instances=4)

    identical_copy = clone(sampler)
    assert equals(sampler, identical_copy)

    altered_copy = clone(sampler, {"num_instances": 5})
    assert not equals(sampler, altered_copy)