def test_networkbody_lstm():
    """Fit a memory-enabled NetworkBody to a constant target of ones.

    Builds a NetworkBody from NetworkSettings that carry MemorySettings
    (sequence length 6, memory size 12), then runs 300 Adam steps that
    minimise the MSE between the encoding of an all-ones observation
    sequence and an all-ones target. Every element of the encoding from the
    final iteration must be within 0.1 of 1.0.
    """
    torch.manual_seed(0)

    obs_size = 4
    seq_len = 6
    memory_size = 12

    memory_settings = NetworkSettings.MemorySettings(
        sequence_length=seq_len, memory_size=memory_size
    )
    network_settings = NetworkSettings(memory=memory_settings)
    observation_specs = create_observation_specs_with_shapes([(obs_size,)])
    networkbody = NetworkBody(observation_specs, network_settings)

    optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-4)
    sample_obs = torch.ones((seq_len, obs_size))
    mse_loss = torch.nn.functional.mse_loss

    for _step in range(300):
        # A fresh all-ones memory tensor is supplied on every forward pass;
        # the memories returned by the network are not used.
        initial_memories = torch.ones(1, 1, memory_size)
        encoded, _memories_out = networkbody(
            [sample_obs], memories=initial_memories, sequence_length=seq_len
        )
        # Push every output element towards 1.
        target = torch.ones(encoded.shape)
        loss = mse_loss(encoded, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # `encoded` is the output of the last forward pass (computed before the
    # final optimizer step); its values should already be close to 1.
    final_values = encoded.flatten().tolist()
    for value in final_values:
        assert value == pytest.approx(1.0, abs=0.1)
def test_networkbody_visual():
    """Fit a NetworkBody with a vector and an image observation to ones.

    The network is built from two observation shapes, ``(4,)`` and
    ``(84, 84, 3)``, with default NetworkSettings. For 150 Adam steps the
    encoding of a fixed observation pair is regressed (MSE) towards an
    all-ones target. Each step also checks that the encoding has shape
    ``(1, hidden_units)``. Every element of the encoding from the final
    iteration must be within 0.1 of 1.0.
    """
    torch.manual_seed(0)

    vec_obs_size = 4
    visual_shape = (84, 84, 3)
    network_settings = NetworkSettings()

    observation_specs = create_observation_specs_with_shapes(
        [(vec_obs_size,), visual_shape]
    )
    networkbody = NetworkBody(observation_specs, network_settings)
    optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)

    # Batch of one: vector observation first, image observation second,
    # matching the order of the observation shapes above.
    sample_vec_obs = torch.ones((1, vec_obs_size))
    sample_visual_obs = 0.1 * torch.ones((1, *visual_shape))
    obs = [sample_vec_obs, sample_visual_obs]

    expected_shape = (1, network_settings.hidden_units)
    mse_loss = torch.nn.functional.mse_loss

    for _step in range(150):
        encoded, _memories_out = networkbody(obs)
        assert encoded.shape == expected_shape
        # Push every output element towards 1.
        target = torch.ones(encoded.shape)
        loss = mse_loss(encoded, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # `encoded` is the output of the last forward pass (computed before the
    # final optimizer step); its values should already be close to 1.
    final_values = encoded.flatten().tolist()
    for value in final_values:
        assert value == pytest.approx(1.0, abs=0.1)
def test_networkbody_vector():
    """Fit a NetworkBody with a vector observation and an action input to ones.

    The network is built from a single ``(4,)`` observation shape, default
    NetworkSettings and ``encoded_act_size=2``. For 300 Adam steps the
    encoding of a fixed (observation, action) pair is regressed (MSE) towards
    an all-ones target. Each step also checks that the encoding has shape
    ``(1, hidden_units)``. Every element of the encoding from the final
    iteration must be within 0.1 of 1.0.
    """
    torch.manual_seed(0)
    obs_size = 4
    network_settings = NetworkSettings()
    obs_shapes = [(obs_size,)]

    networkbody = NetworkBody(
        create_observation_specs_with_shapes(obs_shapes),
        network_settings,
        encoded_act_size=2,
    )
    optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
    sample_obs = 0.1 * torch.ones((1, obs_size))
    sample_act = 0.1 * torch.ones((1, 2))

    for _step in range(300):
        # The second return value is not used by this test.
        encoded, _memories_out = networkbody([sample_obs], sample_act)
        assert encoded.shape == (1, network_settings.hidden_units)
        # Try to force output to 1
        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # In the last step, values should be close to 1.
    # Convert to plain Python floats before comparing: iterating the tensor
    # directly yields 0-d tensors that are still attached to the autograd
    # graph, and comparing those against pytest.approx depends on reflected
    # __eq__ / array coercion that varies across torch and pytest versions.
    for _enc in encoded.flatten().tolist():
        assert _enc == pytest.approx(1.0, abs=0.1)