def test_sampler(penv: ParallelEnv, is_recurrent: bool) -> None:
    """Check that ``RolloutSampler`` yields full minibatches with consistent RNN states.

    Fills a ``RolloutStorage`` with ``NSTEP`` steps of random values/policies and
    iterates the sampler with minibatches of 12 transitions. In the recurrent case
    the per-worker RNN state is built from ``torch.arange(NWORKERS)``. The test then
    checks two things:

    * each sampled initial state holds ``MINIBATCH // NSTEP`` rows;
    * enough distinct worker states are seen across all minibatches.

    The environment is closed in a ``finally`` block, so a failing assertion does
    not leak the parallel environment.
    """
    try:
        NWORKERS = penv.nworkers
        # Tag each worker's state with its index so sampled states can be traced back.
        rnns = (
            TeState(torch.arange(NWORKERS))
            if is_recurrent
            else recurrent.DummyRnn.DUMMY_STATE
        )
        storage = RolloutStorage(NSTEP, NWORKERS, Device())
        storage.set_initial_state(penv.reset(), rnn_state=rnns)
        policy_dist = CategoricalDist(ACTION_DIM)
        for _ in range(NSTEP):
            state, reward, done, _ = penv.step([None] * NWORKERS)
            value = torch.rand(NWORKERS, dtype=torch.float32)
            policy = policy_dist(torch.rand(NWORKERS, ACTION_DIM))
            storage.push(state, reward, done, rnn_state=rnns, policy=policy, value=value)
        MINIBATCH = 12
        rnn_test = set()
        for batch in RolloutSampler(storage, penv, MINIBATCH):
            length = len(batch.states)
            assert length == MINIBATCH
            if isinstance(batch.rnn_init, TeState):
                assert batch.rnn_init.h.size(0) == MINIBATCH // NSTEP
                rnn_test.update(batch.rnn_init.h.cpu().tolist())
        if is_recurrent:
            assert len(rnn_test) > NWORKERS - (MINIBATCH // NSTEP)
    finally:
        penv.close()
def test_oc_storage() -> None:
    """Check that ``AOCRolloutStorage`` produces ``beta_adv`` of shape (NSTEP, NWORKERS).

    Each step pushes the following random data:

    * option values of shape (NWORKERS, NOPTIONS);
    * a categorical policy;
    * option indices;
    * 0/1 option-termination flags.

    ``set_ac_returns`` is then called with a bootstrap value, and the shape of the
    resulting ``beta_adv`` is checked. The environment is always closed, even when
    an assertion fails.
    """
    penv = DummyParallelEnv(lambda: DummyEnv(array_dim=(16, 16)), 6)
    try:
        NWORKERS = penv.nworkers
        NOPTIONS = 4
        storage = AOCRolloutStorage(NSTEP, NWORKERS, Device(), NOPTIONS)
        storage.set_initial_state(penv.reset())
        policy_dist = CategoricalDist(ACTION_DIM)
        for _ in range(NSTEP):
            state, reward, done, _ = penv.step([None] * NWORKERS)
            value = torch.rand(NWORKERS, NOPTIONS)
            policy = policy_dist(torch.rand(NWORKERS, ACTION_DIM))
            options = torch.randint(
                NOPTIONS, (NWORKERS, ), device=storage.device.unwrapped
            )
            opt_terminals = torch.randint(
                2, (NWORKERS, ), device=storage.device.unwrapped
            ).byte()
            storage.push(
                state,
                reward,
                done,
                options=options,
                opt_terminals=opt_terminals,
                value=value,
                policy=policy,
                epsilon=0.5,
            )
        # Bootstrap with the maximum option value of a random next state.
        next_value = torch.randn(NWORKERS, NOPTIONS).max(dim=-1)[0]
        storage.set_ac_returns(next_value, 0.99, 0.01)
        assert tuple(storage.beta_adv.shape) == (NSTEP, NWORKERS)
    finally:
        penv.close()
def __init__(
    self,
    input_dim: Tuple[int, int, int],
    action_dim: int,
    conv_channels: List[int] = [32, 32, 32],
    conv_args: List[tuple] = [(4, 2, 1), (3, 1, 1), (3, 1, 1)],
    h_dim: int = 256,
    z_dim: int = 64,
    output_channels: int = 0,
    device: Device = Device(),
) -> None:
    """Build a conv encoder, latent heads, actor/critic heads and a mirrored decoder.

    Args:
        input_dim: (channels, height, width). ``input_dim[0]`` is the first conv's
            input channel count, and ``input_dim[1:]`` sizes the conv output.
        action_dim: Number of outputs of the actor head.
        conv_channels: Output channels of the three encoder convolutions. Only the
            first three entries are used.
        conv_args: Positional ``nn.Conv2d`` arguments for the three encoder layers.
            The decoder reuses them in reverse order for its ``ConvTranspose2d``
            layers. Only the first three entries are used.
        h_dim: Width of the fully-connected layer after the convolutions.
        z_dim: Size of the latent vector produced by ``z_fc`` / ``logvar_fc``.
        output_channels: Channels of the decoder output; 0 means ``input_dim[0]``.
        device: Device the network is moved to (encoder/decoder are wrapped with
            ``device.data_parallel``).
    """
    # NOTE(review): calls the parent of ActorCriticNet, i.e. deliberately skips
    # ActorCriticNet.__init__ (presumably because this class builds its own modules
    # instead of taking body/head blocks) — confirm against the class definition.
    super(ActorCriticNet, self).__init__()
    # Exactly three conv layers are built below. Size the flattened feature map from
    # those same three entries so longer lists cannot make `conved` disagree with
    # the layers that actually exist.
    n_layers = 3
    cnn_hidden = calc_cnn_hidden(conv_args[:n_layers], *input_dim[1:])
    conved = cnn_hidden[0] * cnn_hidden[1] * conv_channels[n_layers - 1]
    self.encoder = nn.Sequential(
        nn.Conv2d(input_dim[0], conv_channels[0], *conv_args[0]),
        nn.ReLU(True),
        nn.Conv2d(conv_channels[0], conv_channels[1], *conv_args[1]),
        nn.ReLU(True),
        nn.Conv2d(conv_channels[1], conv_channels[2], *conv_args[2]),
        nn.ReLU(True),
        Flatten(),
        nn.Linear(conved, h_dim),
        nn.ReLU(True),
    )
    # Two parallel latent heads (names suggest mean and log-variance of a VAE).
    self.z_fc = LinearHead(h_dim, z_dim)
    self.logvar_fc = LinearHead(h_dim, z_dim)
    # Small-gain orthogonal init on the actor output layer.
    self.actor = LinearHead(z_dim, action_dim, Initializer(weight_init=orthogonal(0.01)))
    self.critic = LinearHead(z_dim, 1)
    if output_channels == 0:
        output_channels = input_dim[0]
    # Mirror of the encoder: FC layers back to the conv feature map, then transposed
    # convs that reuse the encoder's conv_args in reverse order.
    self.decoder = nn.Sequential(
        LinearHead(z_dim, h_dim),
        nn.ReLU(True),
        LinearHead(h_dim, conved),
        nn.ReLU(True),
        UnFlatten(cnn_hidden),
        nn.ConvTranspose2d(conv_channels[2], conv_channels[1], *conv_args[2]),
        nn.ReLU(True),
        nn.ConvTranspose2d(conv_channels[1], conv_channels[0], *conv_args[1]),
        nn.ReLU(True),
        nn.ConvTranspose2d(conv_channels[0], output_channels, *conv_args[0]),
    )
    CNN_INIT(self.encoder)
    CNN_INIT(self.decoder)
    self.encoder = device.data_parallel(self.encoder)
    self.decoder = device.data_parallel(self.decoder)
    self.device = device
    self._state_dim = input_dim
    self.policy_head = CategoricalHead(action_dim=action_dim)
    self.to(device.unwrapped)
    self._rnn = DummyRnn()
def test_dim(net_gen: callable, input_dim: tuple) -> None:
    """Check the output shapes of a VAE actor-critic net built by ``net_gen``.

    The net is built on the default device and run without gradients on a random
    batch. The following shapes are checked:

    * the VAE output ``x`` matches the input shape;
    * the policy has one probability per action;
    * the value is one scalar per sample.

    The net is printed at the end.
    """
    device = Device()
    net = net_gen(input_dim, ACTION_DIM, device=device)
    raw = torch.randn(BATCH_SIZE, *input_dim)
    with torch.no_grad():
        vae_out, policy, value = net(device.tensor(raw))
    assert vae_out.x.shape == torch.Size((BATCH_SIZE, *input_dim))
    assert policy.dist.probs.shape == torch.Size((BATCH_SIZE, ACTION_DIM))
    assert value.shape == torch.Size((BATCH_SIZE, ))
    print(net)
def test_storage_and_irew() -> None:
    """Check intrinsic-value rollout storage and the RND sampler's output shapes.

    The test pushes ``NSTEPS`` steps of random data into
    ``IntValueRolloutStorage``: extrinsic values, ``pvalue``, and a categorical
    policy over ``ACTION_DIM`` actions. It then does the following:

    * computes intrinsic returns from random rewards;
    * checks the batched state shape;
    * checks the shapes of the ``RNDRolloutSampler`` tensors;
    * checks that every minibatch it yields is full.

    The environment is always closed, even when an assertion fails.
    """
    penv = DummyParallelEnv(lambda: DummyEnv(array_dim=(16, 16)), 6)
    try:
        NSTEPS = 4
        ACTION_DIM = 3
        NWORKERS = penv.nworkers
        states = penv.reset()
        storage = IntValueRolloutStorage(
            NSTEPS, NWORKERS, Device(), 0.99, Device(use_cpu=True)
        )
        storage.set_initial_state(states)
        policy_head = CategoricalDist(ACTION_DIM)
        for _ in range(NSTEPS):
            state, reward, done, _ = penv.step([None] * NWORKERS)
            value = torch.randn(NWORKERS)
            pvalue = torch.randn(NWORKERS)
            # One logit per action: shape (NWORKERS, ACTION_DIM). The former
            # (NWORKERS, 1) logits did not match CategoricalDist(ACTION_DIM).
            policy = policy_head(torch.randn(NWORKERS, ACTION_DIM))
            storage.push(state, reward, done, value=value, policy=policy, pvalue=pvalue)
        rewards = torch.randn(NWORKERS * NSTEPS)
        storage.calc_int_returns(torch.randn(NWORKERS), rewards, gamma=0.99, lambda_=0.95)
        batch = storage.batch_states(penv)
        batch_shape = torch.Size((NSTEPS * NWORKERS, ))
        assert batch.shape == torch.Size((*batch_shape, 16, 16))
        MINIBATCH = 12
        sampler = rnd.rollout.RNDRolloutSampler(
            RolloutSampler(storage, penv, MINIBATCH),
            storage,
            torch.randn(NSTEPS * NWORKERS),
            1.0,
            1.0,
        )
        assert sampler.int_returns.shape == batch_shape
        assert sampler.int_values.shape == batch_shape
        assert sampler.advantages.shape == batch_shape
        for batch in sampler:
            assert len(batch.states) == MINIBATCH
    finally:
        penv.close()
def test_ffmodel_for_atari() -> None:
    """Check that the feed-forward ACVP model decodes Atari-sized frames.

    Steps Breakout 10 times with action 0 and feeds the resulting frames
    (transposed to channel-first) and actions to the model. The decoded output
    must have shape (10, 3, 210, 160). The gym environment is closed even when the
    test fails.
    """
    import numpy as np  # local import: only needed to stack the frames

    atari = gym.make("BreakoutNoFrameskip-v0")
    try:
        acvp_netfn = models.prepare_ff()
        d = Device()
        acvp_net = acvp_netfn((3, 210, 160), 4, d)
        states, actions = [], []
        atari.reset()
        for _ in range(10):
            s, _, _, _ = atari.step(0)
            # HWC frame -> CHW, matching the (3, 210, 160) model input.
            states.append(s.transpose(2, 0, 1))
            actions.append(0)
        # Stack into a single ndarray first; building a tensor from a Python list
        # of ndarrays is very slow in PyTorch.
        states = d.tensor(np.stack(states))
        actions = d.tensor(actions, dtype=torch.long)
        s_decoded = acvp_net(states, actions)
        assert tuple(s_decoded.shape) == (10, 3, 210, 160)
    finally:
        atari.close()
def __init__(
    self,
    body: NetworkBlock,
    actor_head: NetworkBlock,
    critic_head: NetworkBlock,
    policy_dist: PolicyDist,
    recurrent_body: RnnBlock = DummyRnn(),
    device: Device = Device(),
    int_critic_head: Optional[NetworkBlock] = None,
) -> None:
    """Actor-critic net with an additional critic head for intrinsic values.

    All arguments except ``int_critic_head`` are forwarded to the parent
    constructor unchanged. When ``int_critic_head`` is None, a deep copy of the
    already-built ``self.critic_head`` is used instead. The copy starts with the
    same weights but has its own parameters. The intrinsic head is then moved to
    ``device.unwrapped``.
    """
    super().__init__(body, actor_head, critic_head, policy_dist, recurrent_body, device)
    if int_critic_head is None:
        head = copy.deepcopy(self.critic_head)
    else:
        head = int_critic_head
    self.int_critic_head = head
    self.int_critic_head.to(device.unwrapped)
def test_tcnet(state_dim: tuple):
    """Check the output shapes of the shared termination-critic networks.

    Multi-dimensional state shapes use the conv variant; 1-D shapes use the FC
    variant. The following outputs must each be (BATCH_SIZE, NUM_OPTIONS):

    * ``beta`` logits;
    * ``p``;
    * ``p_mu``;
    * ``baseline``.
    """
    BATCH_SIZE = 10
    NUM_OPTIONS = 3
    factory = (
        termination_critic.tc_conv_shared
        if len(state_dim) > 1
        else termination_critic.tc_fc_shared
    )
    net = factory(num_options=NUM_OPTIONS)(state_dim, 1, Device())
    first, second = (torch.randn(BATCH_SIZE, *state_dim) for _ in range(2))
    out = net(first, second)
    expected = (BATCH_SIZE, NUM_OPTIONS)
    assert tuple(out.beta.dist.logits.shape) == expected
    for field in ("p", "p_mu", "baseline"):
        assert tuple(getattr(out, field).shape) == expected
def test_save_and_load(irew_gen) -> None:
    """Check that ``irew_gen.rff_rms.mean`` of an RND agent survives save/load.

    The test runs in three stages:

    1. Generate intrinsic rewards once, record ``rff_rms.mean``, and save the
       agent to ``Results/Test/agent.pth``.
    2. Load the file into a fresh agent configured for CPU.
    3. Compare the two means.

    Both agents are closed even when a step fails.
    """
    c = config()
    c._int_reward_gen = irew_gen
    agent = rnd.RNDAgent(c)
    try:
        agent.irew_gen.gen_rewards(torch.randn(4 * 4, 2, 84, 84))
        nonep = agent.irew_gen.rff_rms.mean.cpu().numpy()
        savedir = Path("Results/Test")
        # exist_ok avoids the check-then-create race of `if not exists(): mkdir()`.
        savedir.mkdir(parents=True, exist_ok=True)
        agent.save("agent.pth", savedir)
    finally:
        agent.close()
    c.device = Device(use_cpu=True)
    agent = rnd.RNDAgent(c)
    try:
        agent.load("agent.pth", savedir)
        nonep_new = agent.irew_gen.rff_rms.mean.cpu().numpy()
        assert_array_almost_equal(nonep, nonep_new)
    finally:
        agent.close()
def test_storage(penv: ParallelEnv) -> None:
    """Check the tensor shapes produced by ``RolloutStorage`` and ``RolloutSampler``.

    Pushes ``NSTEP`` steps of random values and categorical policies, then checks
    two things:

    * the batched states are (NSTEP * NWORKERS, 16, 16);
    * every per-transition sampler tensor holds one entry per transition.

    The environment is closed even when an assertion fails.
    """
    try:
        NWORKERS = penv.nworkers
        storage = RolloutStorage(NSTEP, NWORKERS, Device())
        storage.set_initial_state(penv.reset())
        policy_dist = CategoricalDist(ACTION_DIM)
        for _ in range(NSTEP):
            state, reward, done, _ = penv.step([None] * NWORKERS)
            value = torch.rand(NWORKERS, dtype=torch.float32)
            policy = policy_dist(torch.rand(NWORKERS, ACTION_DIM))
            storage.push(state, reward, done, value=value, policy=policy)
        batch = storage.batch_states(penv)
        batch_shape = torch.Size((NSTEP * NWORKERS, ))
        # The penv fixture is expected to produce 16x16 observations.
        assert batch.shape == torch.Size((*batch_shape, 16, 16))
        sampler = RolloutSampler(storage, penv, 10)
        assert sampler.actions.shape == batch_shape
        assert sampler.returns.shape == batch_shape
        assert sampler.masks.shape == batch_shape
        assert sampler.values.shape == batch_shape
        assert sampler.old_log_probs.shape == batch_shape
    finally:
        penv.close()
def __init__(
    self,
    actor_body: NetworkBlock,
    critic_body: NetworkBlock,
    action_dim: int,
    action_coef: float = 1.0,
    device: Device = Device(),
    init: Initializer = Initializer(weight_init=kaiming_uniform(a=3**0.5)),
) -> None:
    """Build separate actor and critic networks from their body blocks.

    The actor is ``actor_body -> LinearHead(action_dim) -> Tanh``. The critic is
    ``critic_body -> LinearHead(1)``. Both linear heads use ``init``. The module is
    moved to ``device.unwrapped``. ``action_coef`` is stored for use elsewhere
    (presumably to scale the tanh output; not visible here).
    """
    super().__init__()
    # Build the actor's modules before the critic's so initialisation order is unchanged.
    actor_out = LinearHead(actor_body.output_dim, action_dim, init=init)
    self.actor = nn.Sequential(actor_body, actor_out, nn.Tanh())
    critic_out = LinearHead(critic_body.output_dim, 1, init=init)
    self.critic = nn.Sequential(critic_body, critic_out)
    self.to(device.unwrapped)
    self.action_coef = action_coef
    self.device = device
def test_rnn(rnn_gen: Callable[[int, int], RnnBlock]) -> None:
    """Check an RNN block step by step and on a flattened sequence.

    The test has two phases:

    1. Feed TIME_STEP random (BATCH_SIZE, INPUT_DIM) inputs one step at a time,
       carrying the hidden state forward.
    2. Concatenate the same inputs along dim 0 and run them once from a fresh
       initial state.

    The output shape is checked in both phases.
    """
    TIME_STEP, BATCH_SIZE, INPUT_DIM, OUTPUT_DIM = 10, 5, 20, 3
    rnn = rnn_gen(INPUT_DIM, OUTPUT_DIM)
    device = Device()
    rnn.to(device.unwrapped)
    hidden = rnn.initial_state(BATCH_SIZE, device)
    history = []
    for _ in range(TIME_STEP):
        step_in = torch.randn(BATCH_SIZE, INPUT_DIM, device=device.unwrapped)
        history.append(step_in.detach())
        out, hidden = rnn(step_in, hidden)
        assert tuple(out.shape) == (BATCH_SIZE, OUTPUT_DIM)
    # All steps as one (TIME_STEP * BATCH_SIZE, INPUT_DIM) tensor.
    flat = torch.cat(history)
    hidden = rnn.initial_state(BATCH_SIZE, device)
    out, _ = rnn(flat, hidden)
    assert tuple(out.shape) == (TIME_STEP * BATCH_SIZE, OUTPUT_DIM)
def __init__(
    self,
    input_dim: Sequence[int],
    action_dim: int,
    hidden_dim: int = 2048,
    conv_channels: List[int] = [64, 128, 128, 128],
    encoder_args: List[tuple] = [(8, 2, (0, 1)), (6, 2, 1), (6, 2, 1), (4, 2)],
    decoder_args: List[tuple] = [(4, 2), (6, 2, 1), (6, 2, 1), (8, 2, (0, 1))],
    device: Device = Device(),
    init: Initializer = Initializer(orthogonal(nonlinearity="relu")),
) -> None:
    """Build a conv encoder, action-conditioned FC layers and a mirrored deconv decoder.

    Args:
        input_dim: (channels, height, width) of the input frames.
        action_dim: Input size of ``w_action``.
        hidden_dim: Width of the fully-connected layers.
        conv_channels: Output channels of the encoder convolutions. The decoder
            uses the same channel list in reverse.
        encoder_args: Positional ``nn.Conv2d`` arguments, one tuple per layer.
        decoder_args: Positional ``nn.ConvTranspose2d`` arguments, one tuple per layer.
        device: The module is moved to ``device.unwrapped``.
        init: Applied to the conv/deconv layer lists via ``init.make_list``.

    How ``w_enc`` / ``w_action`` / ``fc_action_trans`` are combined is defined in
    ``forward`` (not visible here).
    """
    super().__init__()
    in_channel, height, width = input_dim
    # New list, so the default `conv_channels` argument is never mutated.
    channels = [in_channel] + conv_channels
    self.conv = init.make_list([
        nn.Conv2d(channels[i], channels[i + 1], *encoder_args[i])
        for i in range(len(channels) - 1)
    ])
    # Cast to a Python int: np.prod returns a NumPy scalar, which should not leak
    # into nn.Linear's in/out feature sizes.
    conved_dim = int(
        np.prod(calc_cnn_hidden(encoder_args, height, width)) * channels[-1]
    )
    self.fc_enc = nn.Linear(conved_dim, hidden_dim)
    self.w_enc = nn.Linear(hidden_dim, hidden_dim, bias=False)
    self.w_action = nn.Linear(action_dim, hidden_dim, bias=False)
    self.fc_action_trans = nn.Linear(hidden_dim, hidden_dim)
    self.fc_dec = nn.Linear(hidden_dim, conved_dim)
    # Decoder mirrors the encoder: channels run from the deepest back to in_channel.
    channels.reverse()
    self.deconv = init.make_list([
        nn.ConvTranspose2d(channels[i], channels[i + 1], *decoder_args[i])
        for i in range(len(channels) - 1)
    ])
    self.action_dim = action_dim
    self.device = device
    self.to(device.unwrapped)
CNNBodyWithoutFC, GruBlock, LstmBlock, actor_critic, termination_critic, ) from rainy.net.init import Initializer, kaiming_normal, kaiming_uniform from rainy.utils import Device ACTION_DIM = 10 @pytest.mark.parametrize( "net, state_dim, batch_size", [ (actor_critic.fc_shared()((4, ), ACTION_DIM, Device()), (4, ), 32), ( actor_critic.conv_shared()((4, 84, 84), ACTION_DIM, Device()), (4, 84, 84), 32, ), ( actor_critic.conv_shared(rnn=GruBlock)( (4, 84, 84), ACTION_DIM, Device()), (4, 84, 84), 32, ), ( actor_critic.conv_shared(rnn=LstmBlock)( (4, 84, 84), ACTION_DIM, Device()), (4, 84, 84),
import numpy as np
import pytest
from rainy.net import actor_critic, DqnConv, GruBlock, LstmBlock
from rainy.net.init import Initializer, kaiming_normal, kaiming_uniform
from rainy.utils import Device
from test_env import DummyEnv
import torch
from typing import Optional, Tuple

ACTION_DIM = 10


@pytest.mark.parametrize('net, state_dim, batch_size', [
    (actor_critic.fc_shared()((4,), ACTION_DIM, Device()), (4,), 32),
    (actor_critic.ac_conv()((4, 84, 84), ACTION_DIM, Device()), (4, 84, 84), 32),
    (
        actor_critic.ac_conv(rnn=GruBlock)((4, 84, 84), ACTION_DIM, Device()),
        (4, 84, 84),
        32,
    ),
    (
        actor_critic.ac_conv(rnn=LstmBlock)((4, 84, 84), ACTION_DIM, Device()),
        (4, 84, 84),
        32,
    ),
    (actor_critic.impala_conv()((4, 84, 84), ACTION_DIM, Device()), (4, 84, 84), 32),
])
def test_acnet(net: actor_critic.ActorCriticNet, state_dim: tuple, batch_size: int) -> None:
    """Check an actor-critic net's declared dimensions and per-sample output shapes.

    Builds a batch from ``batch_size`` DummyEnv observations converted to
    ``state_dim`` and runs the net on it. The following must each have exactly one
    entry per sample:

    * the policy's sampled actions;
    * the policy's log-probabilities;
    * the policy's entropies;
    * the values.
    """
    assert net.state_dim == state_dim
    assert net.action_dim == ACTION_DIM
    env = DummyEnv()
    observations = [
        env.step(None)[0].to_array(state_dim) for _ in range(batch_size)
    ]
    states = np.stack(observations)
    # The third output is unused here (presumably the recurrent state).
    policy, values, _ = net(states)
    per_sample = torch.Size([batch_size])
    assert policy.action().shape == per_sample
    assert policy.log_prob().shape == per_sample
    assert policy.entropy().shape == per_sample
    assert values.shape == per_sample
def test_eps_greedy():
    """Check that EpsGreedy's epsilon reaches 0.1 after 100 action selections.

    Epsilon starts at 1.0 and is driven by a ``LinearCooler(1.0, 0.1, 100)``. The
    value network is a CPU FC net over 100-dim states with 10 actions.
    """
    cooler = LinearCooler(1.0, 0.1, 100)
    eg = EpsGreedy(1.0, cooler)
    value_pred = value.fc()((100, ), 10, Device(use_cpu=True))
    n_selections = 100
    while n_selections > 0:
        eg.select_action(np.arange(100), value_pred)
        n_selections -= 1
    assert eg.epsilon == 0.1