def call(self, state):
    # Build a zero tensor of shape (batch_size, stack_size) and feed it
    # through the layer; `stack_size` is defined elsewhere in this module.
    inputs = tf.constant(np.zeros((state.shape[0], stack_size)),
                         dtype=tf.float32)
    return atari_lib.DQNNetworkType(self.layer(inputs))
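`stack_size` and `self.layer` come from the surrounding class, which is not shown. A minimal self-contained sketch of the same pattern (hypothetical class name and layer size, assuming Dopamine's atari_lib is available):

import numpy as np
import tensorflow as tf
from dopamine.discrete_domains import atari_lib

stack_size = 4  # assumed frame-stack depth

class ZeroInputNetwork(tf.keras.Model):
    """Hypothetical network that, like the snippet above, ignores the
    observation contents and feeds zeros through a single layer."""

    def __init__(self, num_actions):
        super().__init__(name='zero_input_network')
        self.layer = tf.keras.layers.Dense(num_actions)

    def call(self, state):
        inputs = tf.constant(np.zeros((state.shape[0], stack_size)),
                             dtype=tf.float32)
        return atari_lib.DQNNetworkType(self.layer(inputs))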
Example no. 2
def call(self, state):
    """Creates the output tensor/op given the state tensor as input."""
    x = self.net(state)
    return atari_lib.DQNNetworkType(x)
Example no. 3
    def __call__(self, x, rng):
        if self.net_conf == 'minatar':
            x = x.squeeze(3)
            x = x.astype(jnp.float32)
            # 2-D convolution over the (H, W, C) MinAtar observation.
            x = nn.Conv(features=16,
                        kernel_size=(3, 3),
                        strides=(1, 1),
                        kernel_init=self.initzer)(x)
            x = jax.nn.relu(x)
            x = x.reshape(-1)  # flatten

        elif self.net_conf == 'atari':
            # vmap has removed the true batch dimension, so this network
            # operates on a single unbatched observation.
            x = x.astype(jnp.float32) / 255.
            x = nn.Conv(features=32,
                        kernel_size=(8, 8),
                        strides=(4, 4),
                        kernel_init=self.initzer)(x)
            x = jax.nn.relu(x)
            x = nn.Conv(features=64,
                        kernel_size=(4, 4),
                        strides=(2, 2),
                        kernel_init=self.initzer)(x)
            x = jax.nn.relu(x)
            x = nn.Conv(features=64,
                        kernel_size=(3, 3),
                        strides=(1, 1),
                        kernel_init=self.initzer)(x)
            x = jax.nn.relu(x)
            x = x.reshape(-1)  # flatten

        elif self.net_conf == 'classic':
            # Classic-control environments: flatten the raw observation.
            x = x.astype(jnp.float32)
            x = x.reshape(-1)

        if self.env is not None and self.env in env_inf:
            x = x - env_inf[self.env]['MIN_VALS']
            x /= env_inf[self.env]['MAX_VALS'] - env_inf[self.env]['MIN_VALS']
            x = 2.0 * x - 1.0

        if self.noisy:

            def net(x, features, rng):
                return NoisyNetwork(features, rng=rng, bias_in=True)(x)
        else:

            def net(x, features, rng):
                return nn.Dense(features, kernel_init=self.initzer)(x)

        for _ in range(self.hidden_layer):
            x = net(x, features=self.neurons, rng=rng)
            x = jax.nn.relu(x)

        adv = net(x, features=self.num_actions, rng=rng)
        val = net(x, features=1, rng=rng)
        dueling_q = val + (adv - (jnp.mean(adv, -1, keepdims=True)))
        non_dueling_q = net(x, features=self.num_actions, rng=rng)

        q_values = jnp.where(self.dueling, dueling_q, non_dueling_q)

        return atari_lib.DQNNetworkType(q_values)
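As the comment notes, this `__call__` sees a single unbatched observation; the batch dimension is reintroduced with `jax.vmap` at the call site. A minimal sketch of that wrapping (hypothetical module and shapes, Flax linen):

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyQNet(nn.Module):
    num_actions: int

    @nn.compact
    def __call__(self, x):  # x is one observation, no batch dimension
        x = nn.Dense(32)(x.astype(jnp.float32).reshape(-1))
        x = jax.nn.relu(x)
        return nn.Dense(self.num_actions)(x)

net = TinyQNet(num_actions=4)
obs = jnp.zeros((8, 8, 1))                       # a single observation
params = net.init(jax.random.PRNGKey(0), obs)
batch = jnp.zeros((32, 8, 8, 1))                 # a batch of observations
q_values = jax.vmap(lambda o: net.apply(params, o))(batch)  # shape (32, 4)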
Example no. 4
    def apply(self,
              x,
              num_actions,
              net_conf,
              env,
              normalize_obs,
              noisy,
              dueling,
              hidden_layer=2,
              neurons=512):
        del normalize_obs

        if net_conf == 'minatar':
            x = x.squeeze(3)
            x = x[None, ...]
            x = x.astype(jnp.float32)
            # 2-D convolution over the (H, W, C) MinAtar observation.
            x = nn.Conv(x,
                        features=16,
                        kernel_size=(3, 3),
                        strides=(1, 1),
                        kernel_init=arq_inf["initializers_layers"][TODO])
            # Apply the configured activation to the conv output.
            x = arq_inf['conv_act_layers'][TODO](x)
            x = x.reshape((x.shape[0], -1))

        elif net_conf == 'atari':
            # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
            # have removed the true batch dimension.
            x = x[None, ...]
            x = x.astype(jnp.float32) / 255.
            x = nn.Conv(x,
                        features=32,
                        kernel_size=(8, 8),
                        strides=(4, 4),
                        kernel_init=arq_inf["initializers_layers"][TODO])
            x = arq_inf['conv_act_layers'][TODO](x)
            x = nn.Conv(x,
                        features=64,
                        kernel_size=(4, 4),
                        strides=(2, 2),
                        kernel_init=arq_inf["initializers_layers"][TODO])
            x = arq_inf['conv_act_layers'][TODO](x)
            x = nn.Conv(x,
                        features=64,
                        kernel_size=(3, 3),
                        strides=(1, 1),
                        kernel_init=arq_inf["initializers_layers"][TODO])
            x = arq_inf['conv_act_layers'][TODO](x)
            x = x.reshape((x.shape[0], -1))  # flatten

        elif net_conf == 'classic':
            #classic environments
            x = x[None, ...]
            x = x.astype(jnp.float32)
            x = x.reshape((x.shape[0], -1))

        if env is not None and env in env_inf:
            x = x - env_inf[env]['MIN_VALS']
            x /= env_inf[env]['MAX_VALS'] - env_inf[env]['MIN_VALS']
            x = 2.0 * x - 1.0

        if noisy:

            def net(x, features):
                return NoisyNetwork(x, features)
        else:

            def net(x, features):
                return nn.Dense(
                    x,
                    features,
                    kernel_init=arq_inf["initializers_layers"][TODO])

        for _ in range(hidden_layer):
            x = net(x, features=neurons)
            x = arq_inf['dens_act_layers'][TODO](x)

        adv = net(x, features=num_actions)
        val = net(x, features=1)
        dueling_q = val + (adv - (jnp.mean(adv, -1, keepdims=True)))
        non_dueling_q = net(x, features=num_actions)

        q_values = jnp.where(dueling, dueling_q, non_dueling_q)

        return atari_lib.DQNNetworkType(q_values)
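Both API styles share the dueling aggregation `val + (adv - mean(adv))`, which centers the advantages so the value/advantage split is identifiable. A quick standalone check of the arithmetic (plain jax.numpy):

import jax.numpy as jnp

val = jnp.array([1.0])              # state value V(s)
adv = jnp.array([3.0, 1.0, -1.0])   # advantages A(s, a); mean is 1.0
q = val + (adv - jnp.mean(adv, -1, keepdims=True))
print(q)  # [ 3.  1. -1.] -- any constant shift in adv is absorbed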
Example no. 5
def call(self, state):
    x = self.net(state)
    return atari_lib.DQNNetworkType(x)
Example no. 6
def call(self, state):
    """Creates the output tensor/op given the state tensor as input."""
    x = tf.cast(state, tf.float32)
    x = self.forward_fn(x)
    return atari_lib.DQNNetworkType(x)
Example no. 7
    def __call__(self, x, rng):
        if self.net_conf == 'minatar':
            x = x.squeeze(3)
            x = x.astype(jnp.float32)
            for _ in range(self.hidden_conv):
                x = nn.Conv(features=16,
                            kernel_size=(3, 3),
                            strides=(1, 1),
                            padding='SAME',
                            kernel_init=self.initzer)(x)
                x = layer_funct_inf[self.layer_funct](x)
            x = nn.Conv(features=16,
                        kernel_size=(3, 3),
                        strides=(1, 1),
                        kernel_init=self.initzer)(x)
            x = layer_funct_inf[self.layer_funct](x)
            x = x.reshape(-1)

        elif self.net_conf == 'atari':
            # vmap has removed the true batch dimension, so this network
            # operates on a single unbatched observation.
            ks = [8, 4, 3]
            fts = [32, 64, 64]
            sts = [4, 2, 1]
            x = x.astype(jnp.float32) / 255.
            for i in range(self.hidden_conv):
                x = nn.Conv(features=fts[i],
                            kernel_size=(ks[i], ks[i]),
                            strides=(sts[i], sts[i]),
                            kernel_init=self.initzer)(x)
                x = layer_funct_inf[self.layer_funct](x)
            x = x.reshape(-1)  # flatten

        elif self.net_conf == 'classic':
            # Classic-control environments: flatten the raw observation.
            x = x.astype(jnp.float32)
            x = x.reshape(-1)

        if self.env is not None and self.env in env_inf:
            x = x - env_inf[self.env]['MIN_VALS']
            x /= env_inf[self.env]['MAX_VALS'] - env_inf[self.env]['MIN_VALS']
            x = 2.0 * x - 1.0

        if self.noisy:

            def net(x, features, rng):
                return NoisyNetwork(features, rng=rng, bias_in=True)(x)
        else:

            def net(x, features, rng):
                return nn.Dense(features, kernel_init=self.initzer)(x)

        for _ in range(self.hidden_layer):
            x = net(x, features=self.neurons, rng=rng)
            if self.normalization == 'non_normalization':
                if self.layer_funct != 'non_activation':
                    x = layer_funct_inf[self.layer_funct](x)
            elif self.normalization == 'BatchNorm':
                x = nn.BatchNorm(use_running_average=True)(x)
                if self.layer_funct != 'non_activation':
                    x = layer_funct_inf[self.layer_funct](x)
            elif self.normalization == 'LayerNorm':
                if self.layer_funct != 'non_activation':
                    x = layer_funct_inf[self.layer_funct](x)
                x = nn.LayerNorm()(x)
            else:
                raise ValueError(
                    f'Unknown normalization module: {self.normalization}')

        adv = net(x, features=self.num_actions, rng=rng)
        val = net(x, features=1, rng=rng)

        dueling_q = val + (adv - (jnp.mean(adv, -1, keepdims=True)))
        non_dueling_q = net(x, features=self.num_actions, rng=rng)

        q_values = jnp.where(self.dueling, dueling_q, non_dueling_q)
        return atari_lib.DQNNetworkType(q_values)
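The `env_inf` lookup used throughout these examples rescales observations from [MIN_VALS, MAX_VALS] to [-1, 1]. A standalone sketch of that transform (the CartPole-like bounds below are an assumption, not taken from `env_inf`):

import jax.numpy as jnp

MIN_VALS = jnp.array([-2.4, -5.0, -0.21, -5.0])  # assumed lower bounds
MAX_VALS = jnp.array([2.4, 5.0, 0.21, 5.0])      # assumed upper bounds

def normalize(x):
    x = (x - MIN_VALS) / (MAX_VALS - MIN_VALS)   # -> [0, 1]
    return 2.0 * x - 1.0                         # -> [-1, 1]

print(normalize(MIN_VALS))  # [-1. -1. -1. -1.]
print(normalize(MAX_VALS))  # [1. 1. 1. 1.]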