def test_multinetworkbody_lstm(with_actions):
    """Fit a memory-enabled MultiAgentNetworkBody so its encoding approaches 1.

    The body is built with a memory configuration (sequence length 16,
    memory size 12) and optimised with Adam against an all-ones MSE target
    for 300 iterations on constant inputs.

    Args:
        with_actions: When truthy, the first agent is passed through
            ``obs_only`` and the remaining agents through ``obs`` together
            with their ``actions``. Otherwise every agent is passed through
            ``obs_only`` and ``obs``/``actions`` are empty lists.
    """
    # Seed before the network is built so parameter initialisation is repeatable.
    torch.manual_seed(0)

    obs_size = 4
    act_size = 2
    seq_len = 16
    n_agents = 3
    memory_size = 12

    memory_settings = NetworkSettings.MemorySettings(
        sequence_length=seq_len, memory_size=memory_size
    )
    network_settings = NetworkSettings(memory=memory_settings)
    observation_specs = create_observation_specs_with_shapes([(obs_size,)])
    # act_size is used as the first ActionSpec argument, and as both the
    # number of entries and the value of each entry in the second.
    action_spec = ActionSpec(act_size, (act_size,) * act_size)

    networkbody = MultiAgentNetworkBody(
        observation_specs, network_settings, action_spec
    )
    optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-4)

    # One single-tensor observation list per agent, constant at 0.1.
    sample_obs = [[0.1 * torch.ones((seq_len, obs_size))] for _ in range(n_agents)]
    # Simulate the baseline in POCA: actions exist for all agents but one.
    sample_act = [
        AgentAction(
            0.1 * torch.ones((seq_len, act_size)),
            [0.1 * torch.ones(seq_len) for _ in range(act_size)],
        )
        for _ in range(n_agents - 1)
    ]

    for _ in range(300):
        if with_actions:
            obs_only, obs, actions = sample_obs[:1], sample_obs[1:], sample_act
        else:
            obs_only, obs, actions = sample_obs, [], []

        # A fresh all-ones memory tensor is supplied on every forward pass.
        encoded, _ = networkbody(
            obs_only=obs_only,
            obs=obs,
            actions=actions,
            memories=torch.ones(1, 1, memory_size),
            sequence_length=seq_len,
        )

        # Try to force the output to 1.
        target = torch.ones(encoded.shape)
        loss = torch.nn.functional.mse_loss(encoded, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # The encoding from the final iteration should be close to 1.
    for value in encoded.flatten().tolist():
        assert value == pytest.approx(1.0, abs=0.1)
def test_multinetworkbody_visual(with_actions):
    """Fit a MultiAgentNetworkBody on vector + visual inputs toward an output of 1.

    Each agent supplies a vector observation of size 4 and an 84x84x3
    observation. The body uses default ``NetworkSettings`` and is optimised
    with Adam against an all-ones MSE target for 300 iterations; the encoded
    shape is checked on every iteration.

    Args:
        with_actions: When truthy, the first agent is passed through
            ``obs_only`` and the remaining agents through ``obs`` together
            with their ``actions``. Otherwise every agent is passed through
            ``obs_only`` and ``obs``/``actions`` are empty lists.
    """
    # Seed before the network is built so parameter initialisation is repeatable.
    torch.manual_seed(0)

    act_size = 2
    n_agents = 3
    obs_size = 4
    vis_obs_size = (84, 84, 3)

    network_settings = NetworkSettings()
    observation_specs = create_observation_specs_with_shapes(
        [(obs_size,), vis_obs_size]
    )
    # act_size is used as the first ActionSpec argument, and as both the
    # number of entries and the value of each entry in the second.
    action_spec = ActionSpec(act_size, (act_size,) * act_size)

    networkbody = MultiAgentNetworkBody(
        observation_specs, network_settings, action_spec
    )
    optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)

    # Per agent: [vector obs, visual obs], each with a leading dimension of 1.
    sample_obs = [
        [
            0.1 * torch.ones((1, obs_size)),
            0.1 * torch.ones((1, *vis_obs_size)),
        ]
        for _ in range(n_agents)
    ]
    # Simulate the baseline in POCA: actions exist for all agents but one.
    sample_act = [
        AgentAction(
            0.1 * torch.ones((1, act_size)),
            [0.1 * torch.ones(1) for _ in range(act_size)],
        )
        for _ in range(n_agents - 1)
    ]

    for _ in range(300):
        if with_actions:
            obs_only, obs, actions = sample_obs[:1], sample_obs[1:], sample_act
        else:
            obs_only, obs, actions = sample_obs, [], []

        encoded, _ = networkbody(obs_only=obs_only, obs=obs, actions=actions)
        assert encoded.shape == (1, network_settings.hidden_units)

        # Try to force the output to 1.
        target = torch.ones(encoded.shape)
        loss = torch.nn.functional.mse_loss(encoded, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # The encoding from the final iteration should be close to 1.
    for value in encoded.flatten().tolist():
        assert value == pytest.approx(1.0, abs=0.1)