def test_standard_scaler_with_episode(observation_shape, batch_size):
    """Check a scaler fitted on dataset episodes against NumPy statistics.

    Random observations, actions, rewards and terminal flags are wrapped in
    an ``MDPDataset``; a ``StandardScaler`` is fitted on ``dataset.episodes``
    and its ``transform`` output is compared with a reference computed from
    the per-feature mean and standard deviation of the raw observations.
    """
    shape = (batch_size, ) + observation_shape

    # Keep the order of the random draws fixed so a seeded run is reproducible.
    obs = np.random.random(shape).astype("f4")
    acts = np.random.random((batch_size, 1)).astype("f4")
    rews = np.random.random(batch_size).astype("f4")
    terminals = np.random.randint(2, size=batch_size)
    # Force the final step to be terminal.
    # NOTE(review): presumably so that every transition ends up inside a
    # closed episode -- confirm against MDPDataset's episode splitting.
    terminals[-1] = 1.0

    dataset = MDPDataset(
        observations=obs,
        actions=acts,
        rewards=rews,
        terminals=terminals,
    )

    # Reference statistics, flattened to a single broadcastable row.
    mean_row = obs.mean(axis=0).reshape((1, -1))
    std_row = obs.std(axis=0).reshape((1, -1))

    scaler = StandardScaler()
    scaler.fit(dataset.episodes)

    x = torch.rand(shape)
    y = scaler.transform(x)

    expected = (x.numpy() - mean_row) / std_row
    assert np.allclose(y.numpy(), expected, atol=1e-6)
def test_standard_scaler_with_dataset(observation_shape, batch_size):
    """Check a scaler constructed directly from an ``MDPDataset``.

    The dataset is built from random arrays and handed to the
    ``StandardScaler`` constructor; ``transform`` output is then compared
    with a reference computed from the per-feature mean and standard
    deviation of the raw observations.
    """
    shape = (batch_size, ) + observation_shape

    # Keep the order of the random draws fixed so a seeded run is reproducible.
    obs = np.random.random(shape).astype('f')
    acts = np.random.random((batch_size, 1)).astype('f')
    rews = np.random.random(batch_size).astype('f')
    terminals = np.random.randint(2, size=batch_size)

    dataset = MDPDataset(obs, acts, rews, terminals)

    # Reference statistics, flattened to a single broadcastable row.
    mean_row = obs.mean(axis=0).reshape((1, -1))
    std_row = obs.std(axis=0).reshape((1, -1))

    scaler = StandardScaler(dataset)

    x = torch.rand(shape)
    y = scaler.transform(x)

    expected = (x.numpy() - mean_row) / std_row
    assert np.allclose(y.numpy(), expected)
def test_standard_scaler(observation_shape, batch_size):
    """Check a scaler built from explicit ``mean`` and ``std`` arrays.

    Verifies that ``transform`` matches a NumPy reference, that
    ``get_type`` reports ``"standard"``, that ``get_params`` returns the
    supplied statistics, and that ``reverse_transform`` recovers the
    original input within an absolute tolerance of 1e-6.
    """
    shape = (batch_size, ) + observation_shape
    obs = np.random.random(shape).astype("f4")

    mean = obs.mean(axis=0)
    std = obs.std(axis=0)

    scaler = StandardScaler(mean=mean, std=std)

    x = torch.rand(shape)
    y = scaler.transform(x)

    # Reference result using the statistics flattened to a broadcastable row.
    mean_row = mean.reshape((1, -1))
    std_row = std.reshape((1, -1))
    expected = (x.numpy() - mean_row) / std_row
    assert np.allclose(y.numpy(), expected)

    assert scaler.get_type() == "standard"

    # The scaler must report exactly the statistics it was given.
    params = scaler.get_params()
    assert np.all(params["mean"] == mean)
    assert np.all(params["std"] == std)

    # Round trip: undoing the transform should give back the input.
    restored = scaler.reverse_transform(y)
    assert torch.allclose(restored, x, atol=1e-6)