Esempio n. 1
0
def test_standard_scaler_with_episode(observation_shape, batch_size):
    """Check a StandardScaler fitted on episodes against NumPy statistics.

    Builds an ``MDPDataset`` from random data and fits ``StandardScaler`` via
    ``scaler.fit(dataset.episodes)``. It then asserts that
    ``scaler.transform(x)`` matches ``(x - mean) / std``, where ``mean`` and
    ``std`` are computed with NumPy over all raw observations.

    Args:
        observation_shape: Shape of a single observation (tuple); supplied
            by the test harness outside this view.
        batch_size: Number of generated transitions.
    """
    shape = (batch_size, ) + observation_shape
    observations = np.random.random(shape).astype("f4")
    actions = np.random.random((batch_size, 1)).astype("f4")
    rewards = np.random.random(batch_size).astype("f4")
    terminals = np.random.randint(2, size=batch_size)
    # Force the final step to be terminal, presumably so every observation
    # lies inside a completed episode and the episode-based statistics cover
    # the same rows as observations.mean() below (assumes MDPDataset does not
    # emit a trailing non-terminated episode -- confirm).
    terminals[-1] = 1

    dataset = MDPDataset(
        observations=observations,
        actions=actions,
        rewards=rewards,
        terminals=terminals,
    )

    # Per-feature statistics over the batch axis; shape == observation_shape.
    mean = observations.mean(axis=0)
    std = observations.std(axis=0)

    scaler = StandardScaler()
    scaler.fit(dataset.episodes)

    x = torch.rand((batch_size, ) + observation_shape)

    y = scaler.transform(x)

    # mean/std have shape observation_shape, so NumPy broadcasting aligns them
    # with the trailing axes of x for any observation rank (a flattening
    # reshape((1, -1)) only lines up for 1-D observations).
    ref_y = (x.numpy() - mean) / std

    assert np.allclose(y.numpy(), ref_y, atol=1e-6)
Esempio n. 2
0
def test_standard_scaler_with_dataset(observation_shape, batch_size):
    """Check a StandardScaler built from a dataset against NumPy statistics.

    Constructs ``StandardScaler(dataset)`` from a random ``MDPDataset``. It
    then asserts that ``scaler.transform(x)`` matches ``(x - mean) / std``,
    with ``mean``/``std`` computed over all raw observations.

    Args:
        observation_shape: Shape of a single observation (tuple); supplied
            by the test harness outside this view.
        batch_size: Number of generated transitions.
    """
    shape = (batch_size, ) + observation_shape
    observations = np.random.random(shape).astype("f4")
    actions = np.random.random((batch_size, 1)).astype("f4")
    rewards = np.random.random(batch_size).astype("f4")
    terminals = np.random.randint(2, size=batch_size)
    # Random terminals may leave trailing steps outside any terminated
    # episode, or produce no terminal at all. Closing the last step keeps
    # this test deterministic and consistent with the episode-based test.
    terminals[-1] = 1

    dataset = MDPDataset(observations, actions, rewards, terminals)

    # Per-feature statistics over the batch axis; shape == observation_shape.
    mean = observations.mean(axis=0)
    std = observations.std(axis=0)

    scaler = StandardScaler(dataset)

    x = torch.rand((batch_size, ) + observation_shape)

    y = scaler.transform(x)

    # Broadcasting aligns mean/std with the trailing axes of x for any
    # observation rank.
    ref_y = (x.numpy() - mean) / std

    # float32 arithmetic: allow a small absolute tolerance near zero.
    assert np.allclose(y.numpy(), ref_y, atol=1e-6)
Esempio n. 3
0
def test_standard_scaler(observation_shape, batch_size):
    """Check a StandardScaler built from explicit mean/std.

    Verifies four things:
    - ``transform`` computes ``(x - mean) / std``;
    - ``get_type()`` returns ``"standard"``;
    - ``get_params()`` echoes back the supplied ``mean`` and ``std``;
    - ``reverse_transform`` inverts ``transform`` within float tolerance.

    Args:
        observation_shape: Shape of a single observation (tuple); supplied
            by the test harness outside this view.
        batch_size: Number of generated observations.
    """
    shape = (batch_size, ) + observation_shape
    observations = np.random.random(shape).astype("f4")

    # Per-feature statistics over the batch axis; shape == observation_shape.
    mean = observations.mean(axis=0)
    std = observations.std(axis=0)

    scaler = StandardScaler(mean=mean, std=std)

    x = torch.rand((batch_size, ) + observation_shape)

    y = scaler.transform(x)

    # Broadcasting aligns mean/std with the trailing axes of x for any
    # observation rank.
    ref_y = (x.numpy() - mean) / std

    # float32 arithmetic: allow a small absolute tolerance near zero.
    assert np.allclose(y.numpy(), ref_y, atol=1e-6)

    assert scaler.get_type() == "standard"
    params = scaler.get_params()
    assert np.all(params["mean"] == mean)
    assert np.all(params["std"] == std)
    # Round trip: reverse_transform(transform(x)) should recover x.
    assert torch.allclose(scaler.reverse_transform(y), x, atol=1e-6)