def atari_encoder(in_channels):
    """Build the convolutional trunk used for ``input_mode == "atari"``.

    The trunk has two Conv2d layers, each followed by an in-place ReLU:
    ``in_channels -> 16`` (kernel 8, stride 4), then ``16 -> 32`` (kernel 4, stride 2).
    Each conv layer is passed through the project helper ``nn_init`` with
    ``w_scale=sqrt(2)``.

    Args:
        in_channels: number of input channels the first convolution expects.

    Returns:
        An ``nn.Sequential`` containing the four layers in order.
    """
    relu_gain = np.sqrt(2)
    layers = [
        nn_init(nn.Conv2d(in_channels, 16, kernel_size=8, stride=4), w_scale=relu_gain),
        nn.ReLU(True),
        nn_init(nn.Conv2d(16, 32, kernel_size=4, stride=2), w_scale=relu_gain),
        nn.ReLU(True),
    ]
    return nn.Sequential(*layers)
def __init__(self, embed_dim, num_actions):
    """Create a two-layer MLP head over a state embedding.

    The head maps ``embed_dim`` -> 64 hidden units (ReLU) -> ``num_actions``
    outputs. Given the class name, the outputs are presumably one predicted
    reward per action; confirm against the forward pass.

    Args:
        embed_dim: width of the input embedding.
        num_actions: number of output units.
    """
    super(MLPRewardFn, self).__init__()
    self.embedding_dim = embed_dim
    self.num_actions = num_actions

    hidden_layer = nn_init(nn.Linear(embed_dim, 64), w_scale=np.sqrt(2))
    # The small init scale on the output layer keeps initial predictions near zero.
    output_layer = nn_init(nn.Linear(64, num_actions), w_scale=0.01)
    self.mlp = nn.Sequential(hidden_layer, nn.ReLU(inplace=True), output_layer)
def __init__(self, embed_dim, num_actions):
    """Initialise an MLP head: ``embed_dim`` -> 64 (ReLU) -> ``num_actions``.

    Both linear layers go through the project helper ``nn_init``. The first
    uses ``w_scale=sqrt(2)`` and the output layer uses ``w_scale=0.01``.

    Args:
        embed_dim: size of the incoming embedding vector.
        num_actions: number of values produced per input.
    """
    super(MLPRewardFn, self).__init__()
    hidden_units = 64
    self.embedding_dim, self.num_actions = embed_dim, num_actions
    modules = [
        nn_init(nn.Linear(embed_dim, hidden_units), w_scale=np.sqrt(2)),
        nn.ReLU(inplace=True),
        nn_init(nn.Linear(hidden_units, num_actions), w_scale=0.01),
    ]
    self.mlp = nn.Sequential(*modules)
def push_encoder(in_channels):
    """Build the convolutional trunk used for ``input_mode == "push"``.

    The trunk has three Conv2d layers, each followed by an in-place ReLU:
    ``in_channels -> 24`` (3x3, stride 1), ``24 -> 24`` (3x3, stride 1), then
    ``24 -> 48`` (4x4, stride 2). Every conv is passed through ``nn_init``
    with ``w_scale=1.0``.

    Args:
        in_channels: number of input channels the first convolution expects.

    Returns:
        An ``nn.Sequential`` of the six layers in order.
    """
    # (in_channels, out_channels, kernel_size, stride) for each conv, in order.
    conv_specs = (
        (in_channels, 24, 3, 1),
        (24, 24, 3, 1),
        (24, 48, 4, 2),
    )
    layers = []
    for c_in, c_out, kernel, stride in conv_specs:
        conv = nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride)
        layers.append(nn_init(conv, w_scale=1.0))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)
def __init__(self, ob_space, ac_space, nenv, nsteps, nstack, use_actor_critic=False,
             transition_fun_name="matrix", transition_nonlin="tanh", normalise_state=True,
             residual_transition=True, tree_depth=2, embedding_dim=512, predict_rewards=True,
             gamma=0.99, td_lambda=0.8, input_mode="atari", value_aggregation="softmax",
             output_tree=False):
    """Build the policy's submodules: encoder/embedding, value head(s), transition and reward.

    Args:
        ob_space: observation space; ``ob_space.shape`` is unpacked as ``(nh, nw, nc)``.
        ac_space: action space; only ``ac_space.n`` (number of actions) is read.
        nenv: number of environments, stored as ``self.nenv``.
        nsteps: not needed to build the network (it only ever fed a batch size
            that was never read); kept for interface compatibility.
        nstack: number of stacked frames; the encoder sees ``nc * nstack`` channels.
        use_actor_critic: if true, also build ``self.ac_value_fn``.
        transition_fun_name: forwarded to ``build_transition_fn``. ``"two_layer"``
            yields two modules and forces ``self.residual_transition = False``.
        transition_nonlin: ``"tanh"`` or ``"relu"``.
        normalise_state, residual_transition, tree_depth, gamma, td_lambda,
        value_aggregation, output_tree: stored on ``self``; not used in this method.
        embedding_dim: width of the state embedding.
        predict_rewards: if true, build ``self.tree_reward_fun``.
        input_mode: ``"atari"`` or ``"push"``; selects the convolutional encoder.

    Raises:
        ValueError: if ``input_mode`` or ``transition_nonlin`` is not recognised.
    """
    super(TreeQNPolicy, self).__init__()
    nh, nw, nc = ob_space.shape
    # Per-sample encoder input: frames are stacked along the channel axis.
    in_shape = (nc * nstack, nh, nw)

    self.nenv = nenv
    self.num_actions = ac_space.n
    self.embedding_dim = embedding_dim
    self.use_actor_critic = use_actor_critic
    self.obs_scale = 1.0
    self.eps_threshold = 0
    self.predict_rewards = predict_rewards
    self.gamma = gamma
    self.output_tree = output_tree
    self.td_lambda = td_lambda
    self.residual_transition = residual_transition
    if transition_fun_name == "two_layer":
        # Residuals are introduced manually inside the two-layer transition
        # function, so the outer residual connection is turned off.
        self.residual_transition = False
    self.normalise_state = normalise_state
    self.value_aggregation = value_aggregation

    if input_mode == "atari":
        encoder = atari_encoder(in_shape[0])
        # NOTE(review): presumably observations are divided by obs_scale elsewhere.
        self.obs_scale = 255.0
    elif input_mode == "push":
        encoder = push_encoder(in_shape[0])
    else:
        raise ValueError(
            "Input mode {!r} not accepted. use atari, push".format(input_mode))

    # Run one all-zero observation through the encoder to discover the size of
    # its output, which fixes the input width of the embedding Linear layer.
    dummy = Variable(torch.zeros(1, *in_shape))
    conv_dim_out = tuple(encoder(dummy).size())[1:]
    print("CONV DIM OUT", conv_dim_out)
    flat_conv_dim_out = int(np.prod(conv_dim_out))

    self.embed = nn.Sequential(
        encoder,
        View(-1, flat_conv_dim_out),
        nn_init(nn.Linear(flat_conv_dim_out, self.embedding_dim), w_scale=np.sqrt(2)),
        nn.ReLU(True),
    )
    self.value_fn = nn_init(nn.Linear(embedding_dim, 1), w_scale=.01)
    if self.use_actor_critic:
        self.ac_value_fn = nn_init(nn.Linear(embedding_dim, 1), w_scale=1.0)

    self.transition_fun_name = transition_fun_name
    if transition_nonlin == "tanh":
        self.transition_nonlin = nn.Tanh()
    elif transition_nonlin == "relu":
        self.transition_nonlin = nn.ReLU()
    else:
        raise ValueError(
            "Transition nonlinearity {!r} not accepted. use tanh, relu".format(
                transition_nonlin))

    if self.transition_fun_name == "two_layer":
        self.transition_fun1, self.transition_fun2 = build_transition_fn(
            transition_fun_name, embedding_dim, nonlin=self.transition_nonlin,
            num_actions=self.num_actions)
    else:
        self.transition_fun = build_transition_fn(
            transition_fun_name, embedding_dim, nonlin=self.transition_nonlin,
            num_actions=self.num_actions)

    if self.predict_rewards:
        self.tree_reward_fun = MLPRewardFn(embedding_dim, self.num_actions)
    self.tree_depth = tree_depth
def __init__(self, ob_space, ac_space, nenv, nsteps, nstack, use_actor_critic=False,
             transition_fun_name="matrix", transition_nonlin="tanh", normalise_state=True,
             residual_transition=True, tree_depth=2, embedding_dim=512, predict_rewards=True,
             gamma=0.99, td_lambda=0.8, input_mode="atari", value_aggregation="softmax",
             output_tree=False):
    """Build the policy's submodules: encoder/embedding, value head(s), transition and reward.

    Args:
        ob_space: observation space; ``ob_space.shape`` is unpacked as ``(nh, nw, nc)``.
        ac_space: action space; only ``ac_space.n`` (number of actions) is read.
        nenv: number of environments, stored as ``self.nenv``.
        nsteps: not needed to build the network (it only ever fed a batch size
            that was never read); kept for interface compatibility.
        nstack: number of stacked frames; the encoder sees ``nc * nstack`` channels.
        use_actor_critic: if true, also build ``self.ac_value_fn``.
        transition_fun_name: forwarded to ``build_transition_fn``. ``"two_layer"``
            yields two modules and forces ``self.residual_transition = False``.
        transition_nonlin: ``"tanh"`` or ``"relu"``.
        normalise_state, residual_transition, tree_depth, gamma, td_lambda,
        value_aggregation, output_tree: stored on ``self``; not used in this method.
        embedding_dim: width of the state embedding.
        predict_rewards: if true, build ``self.tree_reward_fun``.
        input_mode: ``"atari"``, ``"push"`` or ``"blocksworld"``; selects the
            convolutional encoder.

    Raises:
        ValueError: if ``input_mode`` or ``transition_nonlin`` is not recognised.
    """
    super(TreeQNPolicy, self).__init__()
    nh, nw, nc = ob_space.shape
    # Per-sample encoder input: frames are stacked along the channel axis.
    in_shape = (nc * nstack, nh, nw)

    self.nenv = nenv
    self.num_actions = ac_space.n
    self.embedding_dim = embedding_dim
    self.use_actor_critic = use_actor_critic
    self.obs_scale = 1.0
    self.eps_threshold = 0
    self.predict_rewards = predict_rewards
    self.gamma = gamma
    self.output_tree = output_tree
    self.td_lambda = td_lambda
    self.residual_transition = residual_transition
    if transition_fun_name == "two_layer":
        # Residuals are introduced manually inside the two-layer transition
        # function, so the outer residual connection is turned off.
        self.residual_transition = False
    self.normalise_state = normalise_state
    self.value_aggregation = value_aggregation

    if input_mode == "atari":
        encoder = atari_encoder(in_shape[0])
        # NOTE(review): presumably observations are divided by obs_scale elsewhere.
        self.obs_scale = 255.0
    elif input_mode == "push":
        encoder = push_encoder(in_shape[0])
    elif input_mode == "blocksworld":
        encoder = blocksworld_encoder(in_shape[0])
    else:
        # The message must list every mode accepted above, including blocksworld.
        raise ValueError(
            "Input mode {!r} not accepted. use atari, push, blocksworld".format(input_mode))

    # Run one all-zero observation through the encoder to discover the size of
    # its output, which fixes the input width of the embedding Linear layer.
    dummy = Variable(torch.zeros(1, *in_shape))
    conv_dim_out = tuple(encoder(dummy).size())[1:]
    print("CONV DIM OUT", conv_dim_out)
    flat_conv_dim_out = int(np.prod(conv_dim_out))

    self.embed = nn.Sequential(
        encoder,
        View(-1, flat_conv_dim_out),
        nn_init(nn.Linear(flat_conv_dim_out, self.embedding_dim), w_scale=np.sqrt(2)),
        nn.ReLU(True),
    )
    self.value_fn = nn_init(nn.Linear(embedding_dim, 1), w_scale=.01)
    if self.use_actor_critic:
        self.ac_value_fn = nn_init(nn.Linear(embedding_dim, 1), w_scale=1.0)

    self.transition_fun_name = transition_fun_name
    if transition_nonlin == "tanh":
        self.transition_nonlin = nn.Tanh()
    elif transition_nonlin == "relu":
        self.transition_nonlin = nn.ReLU()
    else:
        raise ValueError(
            "Transition nonlinearity {!r} not accepted. use tanh, relu".format(
                transition_nonlin))

    if self.transition_fun_name == "two_layer":
        self.transition_fun1, self.transition_fun2 = build_transition_fn(
            transition_fun_name, embedding_dim, nonlin=self.transition_nonlin,
            num_actions=self.num_actions)
    else:
        self.transition_fun = build_transition_fn(
            transition_fun_name, embedding_dim, nonlin=self.transition_nonlin,
            num_actions=self.num_actions)

    if self.predict_rewards:
        self.tree_reward_fun = MLPRewardFn(embedding_dim, self.num_actions)
    self.tree_depth = tree_depth