Exemple #1
0
    def __init__(self, game, args):
        """Assemble and compile the policy/value model.

        :param game: Object exposing getDimensions() (unpacked into x, y,
                     planes) and getActionSize().
        :param args: Configuration object. This method reads
                     observation_length, support_size and optimizer.lr_init;
                     the object is also handed to Crafter.
        """
        self.x, self.y, self.planes = game.getDimensions()

        self.action_size = game.getActionSize()
        self.args = args
        self.crafter = Crafter(args)

        # Channel count of the stacked observation and its flattened length.
        stacked_planes = self.planes * self.args.observation_length
        flat_length = self.x * self.y * stacked_planes

        # Observation input: (x, y, planes * observation_length) per sample.
        self.observation_history = Input(shape=(self.x, self.y,
                                                stacked_planes))
        # Action input: a vector with action_size entries per sample.
        self.action_tensor = Input(shape=(self.action_size, ))

        # Flatten the observation to a single vector for the predictor.
        flatten = Reshape((flat_length, ))
        observations = flatten(self.observation_history)

        self.pi, self.v = self.build_predictor(observations)

        self.model = Model(inputs=self.observation_history,
                           outputs=[self.pi, self.v])

        opt = Adam(args.optimizer.lr_init)
        # The value head uses a categorical loss when a support is configured
        # and a squared-error loss otherwise; the policy loss is fixed.
        if self.args.support_size > 0:
            value_loss = 'categorical_crossentropy'
        else:
            value_loss = 'mean_squared_error'
        self.model.compile(loss=['categorical_crossentropy', value_loss],
                           optimizer=opt)
Exemple #2
0
    def __init__(self, game, args):
        """Wire the encoder, dynamics and predictor graphs into Keras models.

        :param game: Object exposing getDimensions() (unpacked into board_x,
                     board_y, planes) and getActionSize().
        :param args: Configuration object; observation_length is read here and
                     the object is also handed to Crafter.
        :raises AssertionError: If the action size differs from
                                latent_x * latent_y.
        """
        self.board_x, self.board_y, self.planes = game.getDimensions()
        self.latent_x, self.latent_y = (6, 6)
        self.action_size = game.getActionSize()
        self.args = args
        self.crafter = Crafter(args)

        # Actions are reshaped onto the latent grid below, so sizes must match.
        assert self.action_size == self.latent_x * self.latent_y, \
            "The action space should be the same size as the latent space"

        latent_grid = (self.latent_x, self.latent_y, 1)

        # Observation input: (board_x, board_y, planes * observation_length).
        self.observation_history = Input(
            shape=(self.board_x, self.board_y,
                   self.planes * self.args.observation_length))
        # Action input: a vector with action_size entries per sample.
        self.action_plane = Input(shape=(self.action_size,))
        # Latent-state input: (latent_x, latent_y, 1) per sample.
        self.latent_state = Input(shape=latent_grid)

        # Bring both dynamics inputs onto the same (latent_x, latent_y, 1) grid.
        action_plane = Reshape(latent_grid)(self.action_plane)
        latent_state = Reshape(latent_grid)(self.latent_state)

        self.s = self.build_encoder(self.observation_history)
        self.r, self.s_next = self.build_dynamics(latent_state, action_plane)

        self.pi, self.v = self.build_predictor(latent_state)

        # Stand-alone models for each of the three sub-networks.
        self.encoder = Model(inputs=self.observation_history, outputs=self.s)
        self.dynamics = Model(inputs=[self.latent_state, self.action_plane],
                              outputs=[self.r, self.s_next])
        self.predictor = Model(inputs=self.latent_state,
                               outputs=[self.pi, self.v])

        # Composite models: the predictor is re-applied to the encoder output
        # and to the dynamics output respectively.
        pi_initial, v_initial = self.predictor(self.s)
        self.forward = Model(inputs=self.observation_history,
                             outputs=[self.s, pi_initial, v_initial])

        pi_next, v_next = self.predictor(self.s_next)
        self.recurrent = Model(inputs=[self.latent_state, self.action_plane],
                               outputs=[self.r, self.s_next, pi_next, v_next])
Exemple #3
0
class AlphaZeroAtariNetwork:
    """Convolutional policy/value network built with the Keras functional API.

    The constructor feeds an image-shaped input through build_model and
    compiles the resulting two-headed model (policy ``pi`` and value ``v``).
    """

    def __init__(self, game, args):
        """Build, compile and summarize the model.

        :param game: Object exposing getDimensions() (unpacked into board_x,
                     board_y, depth) and getActionSize().
        :param args: Configuration object. This class reads
                     observation_length, support_size, optimizer.lr_init,
                     num_channels, num_towers, residual_left and
                     residual_right; the object is also handed to Crafter.
        """
        self.board_x, self.board_y, depth = game.getDimensions()
        self.action_size = game.getActionSize()
        self.args = args
        self.crafter = Crafter(args)

        # Input: (board_x, board_y, depth * observation_length) per sample.
        self.input_boards = Input(shape=(self.board_x, self.board_y,
                                         depth * self.args.observation_length))

        self.pi, self.v = self.build_model(self.input_boards)

        self.model = Model(inputs=self.input_boards, outputs=[self.pi, self.v])

        opt = Adam(args.optimizer.lr_init)
        # Categorical value head when a support is configured, scalar otherwise.
        if self.args.support_size > 0:
            self.model.compile(loss=['categorical_crossentropy'] * 2,
                               optimizer=opt)
        else:
            self.model.compile(
                loss=['categorical_crossentropy', 'mean_squared_error'],
                optimizer=opt)

        # Model.summary() prints the table itself and returns None; wrapping it
        # in print() only appended a stray "None" line to the output.
        self.model.summary()

    def build_model(self, x_image):
        """Create the convolutional trunk and the policy/value heads.

        :param x_image: Keras tensor holding the stacked observation planes.
        :return: Tuple (pi, v) of output tensors.
        """
        # Stage 1: stride-2 convolution followed by a residual tower.
        conv = Conv2D(self.args.num_channels, kernel_size=3,
                      strides=(2, 2))(x_image)
        res = self.crafter.conv_residual_tower(self.args.num_towers, conv,
                                               self.args.residual_left,
                                               self.args.residual_right)

        # Stage 2: another stride-2 convolution and residual tower.
        conv = Conv2D(self.args.num_channels, kernel_size=3,
                      strides=(2, 2))(res)
        res = self.crafter.conv_residual_tower(self.args.num_towers, conv,
                                               self.args.residual_left,
                                               self.args.residual_right)

        # Stage 3: stride-2 average pooling followed by a residual tower.
        pooled = AveragePooling2D(3, strides=(2, 2))(res)
        res = self.crafter.conv_residual_tower(self.args.num_towers, pooled,
                                               self.args.residual_left,
                                               self.args.residual_right)

        # Final reduction: pooling, then a convolution with half the channels.
        pooled = AveragePooling2D(3, strides=(2, 2))(res)
        conv = Conv2D(self.args.num_channels // 2,
                      kernel_size=3,
                      strides=(2, 2))(pooled)
        flat = Flatten()(conv)

        fc = self.crafter.dense_sequence(1, flat)

        pi = Dense(self.action_size, activation='softmax', name='pi')(fc)
        # Scalar tanh value when support_size is 0; otherwise a softmax over
        # 2 * support_size + 1 bins.
        v = Dense(1, activation='tanh', name='v')(fc) \
            if self.args.support_size == 0 else \
            Dense(self.args.support_size * 2 + 1, activation='softmax', name='v')(fc)

        return pi, v
Exemple #4
0
    def __init__(self, game, args):
        """Wire encoder, dynamics, predictor and decoder graphs into models.

        :param game: Object exposing getDimensions() (unpacked into board_x,
                     board_y, planes) and getActionSize().
        :param args: Configuration object; observation_length and latent_depth
                     are read here and the object is also handed to Crafter.
        """
        # Network arguments
        self.board_x, self.board_y, self.planes = game.getDimensions()
        self.action_size = game.getActionSize()
        self.args = args
        self.crafter = Crafter(args)

        # Observation input: (board_x, board_y, planes * observation_length).
        self.observation_history = Input(shape=(self.board_x, self.board_y,
                                                self.planes *
                                                self.args.observation_length))
        # Action input: a vector with action_size entries per sample.
        self.action_plane = Input(shape=(self.action_size, ))
        # Latent-state input: (board_x, board_y, latent_depth) per sample.
        self.encoded_state = Input(shape=(self.board_x, self.board_y,
                                          self.args.latent_depth))

        # Format action vector to plane: drop the last action entry
        # (presumably the resign move, going by the variable name) so the
        # remaining board_x * board_y entries can be reshaped onto the board.
        omit_resign = Lambda(lambda x: x[..., :-1],
                             output_shape=(self.board_x * self.board_y, ),
                             input_shape=(self.action_size, ))(
                                 self.action_plane)
        action_plane = Reshape((self.board_x, self.board_y, 1))(omit_resign)

        self.s = self.build_encoder(self.observation_history)
        self.r, self.s_next = self.build_dynamics(self.encoded_state,
                                                  action_plane)
        self.pi, self.v = self.build_predictor(self.encoded_state)

        self.encoder = Model(inputs=self.observation_history,
                             outputs=self.s,
                             name='h')
        self.dynamics = Model(inputs=[self.encoded_state, self.action_plane],
                              outputs=[self.r, self.s_next],
                              name='g')
        self.predictor = Model(inputs=self.encoded_state,
                               outputs=[self.pi, self.v],
                               name='f')

        # Composite models: the predictor is re-applied to the encoder output
        # and to the dynamics output respectively.
        self.forward = Model(inputs=self.observation_history,
                             outputs=[self.s, *self.predictor(self.s)])
        self.recurrent = Model(
            inputs=[self.encoded_state, self.action_plane],
            outputs=[self.r, self.s_next, *self.predictor(self.s_next)])

        # Decoder functionality.
        self.decoded_observations = self.build_decoder(self.encoded_state)
        self.decoder = Model(inputs=self.encoded_state,
                             outputs=self.decoded_observations,
                             name='decoder')

        # Model.summary() prints the table itself and returns None; wrapping it
        # in print() only appended a stray "None" line after each summary.
        self.encoder.summary()
        self.dynamics.summary()
        self.predictor.summary()
        self.decoder.summary()
Exemple #5
0
class AlphaZeroHexNetwork:
    """Convolutional policy/value network built with the Keras functional API.

    The constructor feeds a board-shaped input through build_model and
    compiles the resulting two-headed model (policy ``pi`` and value ``v``).
    """

    def __init__(self, game, args):
        """Build, compile and summarize the model.

        :param game: Object exposing getDimensions() (unpacked into board_x,
                     board_y, depth) and getActionSize().
        :param args: Configuration object. This class reads
                     observation_length, support_size, optimizer.lr_init,
                     num_convs, num_towers, residual_left and residual_right;
                     the object is also handed to Crafter.
        """
        # game params
        self.board_x, self.board_y, depth = game.getDimensions()
        self.action_size = game.getActionSize()
        self.args = args
        self.crafter = Crafter(args)

        # Input: (board_x, board_y, depth * observation_length) per sample.
        self.input_boards = Input(shape=(self.board_x, self.board_y,
                                         depth * self.args.observation_length))

        self.pi, self.v = self.build_model(self.input_boards)

        self.model = Model(inputs=self.input_boards, outputs=[self.pi, self.v])

        opt = Adam(args.optimizer.lr_init)
        # Categorical value head when a support is configured, scalar otherwise.
        if self.args.support_size > 0:
            self.model.compile(loss=['categorical_crossentropy'] * 2,
                               optimizer=opt)
        else:
            self.model.compile(
                loss=['categorical_crossentropy', 'mean_squared_error'],
                optimizer=opt)

        # Model.summary() prints the table itself and returns None; wrapping it
        # in print() only appended a stray "None" line to the output.
        self.model.summary()

    def build_model(self, x_image):
        """Create the convolutional trunk and the policy/value heads.

        :param x_image: Keras tensor holding the stacked board planes.
        :return: Tuple (pi, v) of output tensors.
        """
        # Trunk: convolution tower then residual tower, both batch-normalized.
        conv = self.crafter.conv_tower(self.args.num_convs,
                                       x_image,
                                       use_bn=True)
        res = self.crafter.conv_residual_tower(self.args.num_towers,
                                               conv,
                                               self.args.residual_left,
                                               self.args.residual_right,
                                               use_bn=True)

        # 32-filter convolution -> batch norm -> activation before flattening.
        small = self.crafter.activation()(BatchNormalization()(Conv2D(
            32, 3, padding='same', use_bias=False)(res)))

        flat = Flatten()(small)

        fc = self.crafter.dense_sequence(1, flat)

        pi = Dense(self.action_size, activation='softmax', name='pi')(fc)
        # Scalar tanh value when support_size is 0; otherwise a softmax over
        # 2 * support_size + 1 bins.
        v = Dense(1, activation='tanh', name='v')(fc) \
            if self.args.support_size == 0 else \
            Dense(self.args.support_size * 2 + 1, activation='softmax', name='v')(fc)

        return pi, v
Exemple #6
0
    def __init__(self, game, args):
        """Wire encoder, dynamics, predictor and decoder graphs into models.

        :param game: Object exposing getDimensions() (unpacked into x, y,
                     planes) and getActionSize().
        :param args: Configuration object; latent_depth is read here and the
                     object is also handed to Crafter.
        """
        self.x, self.y, self.planes = game.getDimensions()
        self.latents = args.latent_depth
        self.action_size = game.getActionSize()
        self.args = args
        self.crafter = Crafter(args)

        # Observation input: (x, y, planes) per sample.
        self.observation_history = Input(shape=(self.x, self.y, self.planes))
        # Action input: a vector with action_size entries per sample.
        self.action_tensor = Input(shape=(self.action_size, ))
        # Latent-state input: (latents, 1) per sample.
        self.latent_state = Input(shape=(self.latents, 1))

        # Flatten both inputs into plain vectors for the sub-networks.
        flat_observation_size = self.x * self.y * self.planes
        observations = Reshape((flat_observation_size, ))(
            self.observation_history)
        latent_vector = Reshape((self.latents, ))(self.latent_state)

        # Build the three sub-network graphs.
        self.s = self.build_encoder(observations)
        self.r, self.s_next = self.build_dynamics(latent_vector,
                                                  self.action_tensor)
        self.pi, self.v = self.build_predictor(latent_vector)

        # Stand-alone models for each sub-network.
        self.encoder = Model(name="r",
                             inputs=self.observation_history,
                             outputs=self.s)
        self.dynamics = Model(name='d',
                              inputs=[self.latent_state, self.action_tensor],
                              outputs=[self.r, self.s_next])
        self.predictor = Model(name='p',
                               inputs=self.latent_state,
                               outputs=[self.pi, self.v])

        # Composite models: the predictor is re-applied to the encoder output
        # and to the dynamics output respectively.
        pi_initial, v_initial = self.predictor(self.s)
        self.forward = Model(name='initial',
                             inputs=self.observation_history,
                             outputs=[self.s, pi_initial, v_initial])

        pi_next, v_next = self.predictor(self.s_next)
        self.recurrent = Model(
            name='recurrent',
            inputs=[self.latent_state, self.action_tensor],
            outputs=[self.r, self.s_next, pi_next, v_next])

        # Decoder from the flattened latent vector to an observation tensor.
        self.decoded_observations = self.build_decoder(latent_vector)
        self.decoder = Model(name='decoder',
                             inputs=self.latent_state,
                             outputs=self.decoded_observations)
Exemple #7
0
    def __init__(self, game, args):
        """Assemble and compile the policy/value model.

        :param game: Object exposing getDimensions() (unpacked into board_x,
                     board_y, depth) and getActionSize().
        :param args: Configuration object. This method reads
                     observation_length, support_size and optimizer.lr_init;
                     the object is also handed to Crafter.
        """
        self.board_x, self.board_y, depth = game.getDimensions()
        self.action_size = game.getActionSize()
        self.args = args
        self.crafter = Crafter(args)

        # Input: (board_x, board_y, depth * observation_length) per sample.
        stacked_depth = depth * self.args.observation_length
        self.input_boards = Input(
            shape=(self.board_x, self.board_y, stacked_depth))

        self.pi, self.v = self.build_model(self.input_boards)

        self.model = Model(inputs=self.input_boards,
                           outputs=[self.pi, self.v])

        opt = Adam(args.optimizer.lr_init)
        # The value head uses a categorical loss when a support is configured
        # and a squared-error loss otherwise; the policy loss is fixed.
        if self.args.support_size > 0:
            value_loss = 'categorical_crossentropy'
        else:
            value_loss = 'mean_squared_error'
        self.model.compile(loss=['categorical_crossentropy', value_loss],
                           optimizer=opt)
Exemple #8
0
class AlphaZeroGymNetwork:
    """Fully connected policy/value network built with the Keras functional API."""

    def __init__(self, game, args):
        """Assemble and compile the policy/value model.

        :param game: Object exposing getDimensions() (unpacked into x, y,
                     planes) and getActionSize().
        :param args: Configuration object. This class reads
                     observation_length, support_size, optimizer.lr_init and
                     num_dense; the object is also handed to Crafter.
        """
        self.x, self.y, self.planes = game.getDimensions()

        self.action_size = game.getActionSize()
        self.args = args
        self.crafter = Crafter(args)

        # Channel count of the stacked observation and its flattened length.
        stacked_planes = self.planes * self.args.observation_length
        flat_length = self.x * self.y * stacked_planes

        # Observation input: (x, y, planes * observation_length) per sample.
        self.observation_history = Input(shape=(self.x, self.y,
                                                stacked_planes))
        # Action input: a vector with action_size entries per sample.
        self.action_tensor = Input(shape=(self.action_size, ))

        # Flatten the observation to a single vector for the dense layers.
        flatten = Reshape((flat_length, ))
        observations = flatten(self.observation_history)

        self.pi, self.v = self.build_predictor(observations)

        self.model = Model(inputs=self.observation_history,
                           outputs=[self.pi, self.v])

        opt = Adam(args.optimizer.lr_init)
        # The value head uses a categorical loss when a support is configured
        # and a squared-error loss otherwise; the policy loss is fixed.
        if self.args.support_size > 0:
            value_loss = 'categorical_crossentropy'
        else:
            value_loss = 'mean_squared_error'
        self.model.compile(loss=['categorical_crossentropy', value_loss],
                           optimizer=opt)

    def build_predictor(self, observations):
        """Create the dense trunk and the policy/value heads.

        :param observations: Keras tensor with the flattened observations.
        :return: Tuple (pi, v) of output tensors.
        """
        hidden = self.crafter.dense_sequence(self.args.num_dense,
                                             observations)

        pi = Dense(self.action_size, activation='softmax', name='pi')(hidden)

        # Scalar linear value when support_size is 0; otherwise a softmax over
        # 2 * support_size + 1 bins.
        if self.args.support_size == 0:
            v = Dense(1, activation='linear', name='v')(hidden)
        else:
            v = Dense(self.args.support_size * 2 + 1,
                      activation='softmax', name='v')(hidden)

        return pi, v
Exemple #9
0
class MuZeroAtariNetwork:
    """Encoder, dynamics and predictor networks on a fixed 6x6 latent grid.

    The three sub-networks are exposed individually (encoder, dynamics,
    predictor) and as the composites ``forward`` (encoder + predictor) and
    ``recurrent`` (dynamics + predictor).
    """

    def __init__(self, game, args):
        """Wire the encoder, dynamics and predictor graphs into Keras models.

        :param game: Object exposing getDimensions() (unpacked into board_x,
                     board_y, planes) and getActionSize().
        :param args: Configuration object. This class reads
                     observation_length, num_channels, num_towers,
                     residual_left, residual_right, latent_depth and
                     support_size; the object is also handed to Crafter.
        :raises AssertionError: If the action size differs from
                                latent_x * latent_y.
        """
        # Network arguments
        self.board_x, self.board_y, self.planes = game.getDimensions()
        self.latent_x, self.latent_y = (6, 6)
        self.action_size = game.getActionSize()
        self.args = args
        self.crafter = Crafter(args)

        # Actions are reshaped onto the latent grid below, so sizes must match.
        assert self.action_size == self.latent_x * self.latent_y, \
            "The action space should be the same size as the latent space"

        # Observation input: (board_x, board_y, planes * observation_length).
        self.observation_history = Input(shape=(self.board_x, self.board_y, self.planes * self.args.observation_length))
        # Action input: a vector with action_size entries per sample.
        self.action_plane = Input(shape=(self.action_size,))
        # Latent-state input: (latent_x, latent_y, 1) per sample.
        self.latent_state = Input(shape=(self.latent_x, self.latent_y, 1))

        action_plane = Reshape((self.latent_x, self.latent_y, 1))(self.action_plane)
        latent_state = Reshape((self.latent_x, self.latent_y, 1))(self.latent_state)

        self.s = self.build_encoder(self.observation_history)
        self.r, self.s_next = self.build_dynamics(latent_state, action_plane)

        self.pi, self.v = self.build_predictor(latent_state)

        self.encoder = Model(inputs=self.observation_history, outputs=self.s)
        self.dynamics = Model(inputs=[self.latent_state, self.action_plane], outputs=[self.r, self.s_next])
        self.predictor = Model(inputs=self.latent_state, outputs=[self.pi, self.v])

        # Composite models: the predictor is re-applied to the encoder output
        # and to the dynamics output respectively.
        self.forward = Model(inputs=self.observation_history, outputs=[self.s, *self.predictor(self.s)])
        self.recurrent = Model(inputs=[self.latent_state, self.action_plane],
                               outputs=[self.r, self.s_next, *self.predictor(self.s_next)])

    def build_encoder(self, observations):
        """Down-sample stacked observations into a latent state.

        :param observations: Keras tensor with the stacked observation planes.
        :return: Latent-state tensor after the project's MinMaxScaler layer.
        """
        # Two rounds of: stride-2 convolution (+ pooling) and a residual tower.
        down_sampled = self.crafter.activation()(Conv2D(self.args.num_channels, 3, 2)(observations))
        down_sampled = self.crafter.conv_residual_tower(self.args.num_towers, down_sampled,
                                                        self.args.residual_left, self.args.residual_right, use_bn=False)
        down_sampled = self.crafter.activation()(Conv2D(self.args.num_channels, 3, 2)(down_sampled))
        down_sampled = AveragePooling2D(3, 2)(down_sampled)
        down_sampled = self.crafter.conv_residual_tower(self.args.num_towers, down_sampled,
                                                        self.args.residual_left, self.args.residual_right, use_bn=False)
        down_sampled = AveragePooling2D(3, 2)(down_sampled)

        # Project to latent_depth channels, then normalize.
        latent_state = self.crafter.activation()((
            Conv2D(self.args.latent_depth, 3, padding='same', use_bias=False)(down_sampled)))
        latent_state = MinMaxScaler()(latent_state)

        return latent_state  # 2-dimensional 1-time step latent state. (Encodes history of images into one state).

    def build_dynamics(self, encoded_state, action_plane):
        """Predict the reward and next latent state for a state/action pair.

        :param encoded_state: Latent-state tensor on the latent grid.
        :param action_plane: Action tensor reshaped onto the latent grid.
        :return: Tuple (r, latent_state) of output tensors.
        """
        # Stack state and action along the channel axis.
        stacked = Concatenate(axis=-1)([encoded_state, action_plane])
        reshaped = Reshape((self.latent_x, self.latent_y, -1))(stacked)
        down_sampled = self.crafter.conv_residual_tower(2 * self.args.num_towers, reshaped,
                                                        self.args.residual_left, self.args.residual_right, use_bn=False)

        latent_state = self.crafter.activation()((
            Conv2D(self.args.latent_depth, 3, padding='same', use_bias=False)(down_sampled)))
        # The reward head reads the latent state before it is rescaled.
        flat = Flatten()(latent_state)
        latent_state = MinMaxScaler()(latent_state)

        r = Dense(self.args.support_size * 2 + 1, name='r')(flat)
        # With a support, the reward is a distribution over
        # 2 * support_size + 1 bins and needs a softmax. Without one it is a
        # single linear unit; a softmax there would pin the output to 1.0.
        if self.args.support_size:
            r = Activation('softmax')(r)

        return r, latent_state

    def build_predictor(self, latent_state):
        """Create the policy and value heads on top of a latent state.

        :param latent_state: Latent-state tensor on the latent grid.
        :return: Tuple (pi, v) of output tensors.
        """
        out_tensor = self.crafter.build_conv_block(latent_state, use_bn=False)

        pi = Dense(self.action_size, activation='softmax', name='pi')(out_tensor)
        v = Dense(self.args.support_size * 2 + 1, name='v')(out_tensor)
        # Softmax over the support bins when configured, tanh scalar otherwise.
        v = Activation('softmax')(v) if self.args.support_size else Activation('tanh')(v)

        return pi, v
Exemple #10
0
class MuZeroGymNetwork:
    """Fully connected encoder, dynamics, predictor and decoder networks.

    The sub-networks are exposed individually and as the composites
    ``forward`` (encoder + predictor) and ``recurrent`` (dynamics + predictor).
    """

    def __init__(self, game, args):
        """Wire encoder, dynamics, predictor and decoder graphs into models.

        :param game: Object exposing getDimensions() (unpacked into x, y,
                     planes) and getActionSize().
        :param args: Configuration object. This class reads latent_depth,
                     num_dense and support_size; the object is also handed to
                     Crafter.
        """
        self.x, self.y, self.planes = game.getDimensions()
        self.latents = args.latent_depth
        self.action_size = game.getActionSize()
        self.args = args
        self.crafter = Crafter(args)

        # Observation input: (x, y, planes) per sample.
        self.observation_history = Input(shape=(self.x, self.y, self.planes))
        # Action input: a vector with action_size entries per sample.
        self.action_tensor = Input(shape=(self.action_size, ))
        # Latent-state input: (latents, 1) per sample.
        self.latent_state = Input(shape=(self.latents, 1))

        # Flatten both inputs into plain vectors for the dense sub-networks.
        flat_observation_size = self.x * self.y * self.planes
        observations = Reshape((flat_observation_size, ))(
            self.observation_history)
        latent_vector = Reshape((self.latents, ))(self.latent_state)

        # Build the three sub-network graphs.
        self.s = self.build_encoder(observations)
        self.r, self.s_next = self.build_dynamics(latent_vector,
                                                  self.action_tensor)
        self.pi, self.v = self.build_predictor(latent_vector)

        # Stand-alone models for each sub-network.
        self.encoder = Model(name="r",
                             inputs=self.observation_history,
                             outputs=self.s)
        self.dynamics = Model(name='d',
                              inputs=[self.latent_state, self.action_tensor],
                              outputs=[self.r, self.s_next])
        self.predictor = Model(name='p',
                               inputs=self.latent_state,
                               outputs=[self.pi, self.v])

        # Composite models: the predictor is re-applied to the encoder output
        # and to the dynamics output respectively.
        pi_initial, v_initial = self.predictor(self.s)
        self.forward = Model(name='initial',
                             inputs=self.observation_history,
                             outputs=[self.s, pi_initial, v_initial])

        pi_next, v_next = self.predictor(self.s_next)
        self.recurrent = Model(
            name='recurrent',
            inputs=[self.latent_state, self.action_tensor],
            outputs=[self.r, self.s_next, pi_next, v_next])

        # Decoder from the flattened latent vector to an observation tensor.
        self.decoded_observations = self.build_decoder(latent_vector)
        self.decoder = Model(name='decoder',
                             inputs=self.latent_state,
                             outputs=self.decoded_observations)

    def _normalize_latent(self, raw_latent):
        """Bound a raw latent vector and reshape it to (latents, 1).

        Uses tanh when there are at most three latent units and the project's
        MinMaxScaler layer otherwise.
        """
        if self.latents <= 3:
            bounded = Activation('tanh')(raw_latent)
        else:
            bounded = MinMaxScaler()(raw_latent)
        return Reshape((self.latents, 1))(bounded)

    def build_encoder(self, observations):
        """Encode flattened observations into a latent state.

        :param observations: Keras tensor with the flattened observations.
        :return: Latent-state tensor of shape (latents, 1) per sample.
        """
        hidden = self.crafter.dense_sequence(self.args.num_dense,
                                             observations)
        raw_latent = Dense(self.latents, activation='linear',
                           name='s_0')(hidden)
        return self._normalize_latent(raw_latent)

    def build_dynamics(self, encoded_state, action_plane):
        """Predict the reward and next latent state for a state/action pair.

        :param encoded_state: Flattened latent-state tensor.
        :param action_plane: Action tensor.
        :return: Tuple (r, latent_state) of output tensors.
        """
        joined = Concatenate()([encoded_state, action_plane])
        hidden = self.crafter.dense_sequence(self.args.num_dense, joined)

        raw_latent = Dense(self.latents, activation='linear',
                           name='s_next')(hidden)
        next_latent = self._normalize_latent(raw_latent)

        # Scalar linear reward when support_size is 0; otherwise a softmax
        # over 2 * support_size + 1 bins.
        if self.args.support_size == 0:
            r = Dense(1, activation='linear', name='r')(hidden)
        else:
            r = Dense(self.args.support_size * 2 + 1,
                      activation='softmax', name='r')(hidden)

        return r, next_latent

    def build_predictor(self, latent_state):
        """Create the policy and value heads on top of a latent state.

        :param latent_state: Flattened latent-state tensor.
        :return: Tuple (pi, v) of output tensors.
        """
        hidden = self.crafter.dense_sequence(self.args.num_dense,
                                             latent_state)

        pi = Dense(self.action_size, activation='softmax', name='pi')(hidden)

        # Scalar linear value when support_size is 0; otherwise a softmax over
        # 2 * support_size + 1 bins.
        if self.args.support_size == 0:
            v = Dense(1, activation='linear', name='v')(hidden)
        else:
            v = Dense(self.args.support_size * 2 + 1,
                      activation='softmax', name='v')(hidden)

        return pi, v

    def build_decoder(self, latent_state):
        """Decode a latent state back into an observation-shaped tensor.

        :param latent_state: Flattened latent-state tensor.
        :return: Tensor of shape (x, y, planes) per sample.
        """
        hidden = self.crafter.dense_sequence(self.args.num_dense,
                                             latent_state)

        flat_observation = Dense(self.x * self.y * self.planes,
                                 name='o_k')(hidden)
        return Reshape((self.x, self.y, self.planes))(flat_observation)
Exemple #11
0
class MuZeroHexNetwork:
    """Convolutional encoder, dynamics, predictor and decoder networks.

    The sub-networks are exposed individually and as the composites
    ``forward`` (encoder + predictor) and ``recurrent`` (dynamics + predictor).
    """

    def __init__(self, game, args):
        """Wire encoder, dynamics, predictor and decoder graphs into models.

        :param game: Object exposing getDimensions() (unpacked into board_x,
                     board_y, planes) and getActionSize().
        :param args: Configuration object. This class reads
                     observation_length, latent_depth, num_convs, num_towers,
                     residual_left, residual_right and support_size; the
                     object is also handed to Crafter.
        """
        self.board_x, self.board_y, self.planes = game.getDimensions()
        self.action_size = game.getActionSize()
        self.args = args
        self.crafter = Crafter(args)

        board_cells = self.board_x * self.board_y

        # Observation input: (board_x, board_y, planes * observation_length).
        self.observation_history = Input(
            shape=(self.board_x, self.board_y,
                   self.planes * self.args.observation_length))
        # Action input: a vector with action_size entries per sample.
        self.action_plane = Input(shape=(self.action_size, ))
        # Latent-state input: (board_x, board_y, latent_depth) per sample.
        self.encoded_state = Input(
            shape=(self.board_x, self.board_y, self.args.latent_depth))

        # Drop the last action entry (presumably the resign move, going by the
        # original variable name) and lay the remainder out as a board plane.
        drop_last = Lambda(lambda x: x[..., :-1],
                           output_shape=(board_cells, ),
                           input_shape=(self.action_size, ))
        board_actions = drop_last(self.action_plane)
        action_plane = Reshape((self.board_x, self.board_y, 1))(board_actions)

        # Build the three sub-network graphs.
        self.s = self.build_encoder(self.observation_history)
        self.r, self.s_next = self.build_dynamics(self.encoded_state,
                                                  action_plane)
        self.pi, self.v = self.build_predictor(self.encoded_state)

        # Stand-alone models for each sub-network.
        self.encoder = Model(name='h',
                             inputs=self.observation_history,
                             outputs=self.s)
        self.dynamics = Model(name='g',
                              inputs=[self.encoded_state, self.action_plane],
                              outputs=[self.r, self.s_next])
        self.predictor = Model(name='f',
                               inputs=self.encoded_state,
                               outputs=[self.pi, self.v])

        # Composite models: the predictor is re-applied to the encoder output
        # and to the dynamics output respectively.
        pi_initial, v_initial = self.predictor(self.s)
        self.forward = Model(inputs=self.observation_history,
                             outputs=[self.s, pi_initial, v_initial])

        pi_next, v_next = self.predictor(self.s_next)
        self.recurrent = Model(
            inputs=[self.encoded_state, self.action_plane],
            outputs=[self.r, self.s_next, pi_next, v_next])

        # Decoder from a latent state to an observation-shaped tensor.
        self.decoded_observations = self.build_decoder(self.encoded_state)
        self.decoder = Model(name='decoder',
                             inputs=self.encoded_state,
                             outputs=self.decoded_observations)

    def build_encoder(self, observations):
        """Encode stacked observations into a latent state.

        :param observations: Keras tensor with the stacked board planes.
        :return: Latent-state tensor after the project's MinMaxScaler layer.
        """
        tower = self.crafter.conv_tower(self.args.num_convs,
                                        observations,
                                        use_bn=False)
        residual = self.crafter.conv_residual_tower(self.args.num_towers,
                                                    tower,
                                                    self.args.residual_left,
                                                    self.args.residual_right,
                                                    use_bn=False)

        # Project to latent_depth channels, activate, then normalize.
        projected = Conv2D(self.args.latent_depth, 3, padding='same',
                           use_bias=False)(residual)
        activated = self.crafter.activation()(projected)
        return MinMaxScaler()(activated)

    def build_dynamics(self, encoded_state, action_plane):
        """Predict the reward and next latent state for a state/action pair.

        :param encoded_state: Latent-state tensor.
        :param action_plane: Action tensor laid out as a single board plane.
        :return: Tuple (r, latent_state) of output tensors.
        """
        # Stack state and action along the channel axis.
        joined = Concatenate(axis=-1)([encoded_state, action_plane])
        joined_shape = (self.board_x, self.board_y,
                        1 + self.args.latent_depth)
        reshaped = Reshape(joined_shape)(joined)

        tower = self.crafter.conv_tower(self.args.num_convs,
                                        reshaped,
                                        use_bn=False)
        residual = self.crafter.conv_residual_tower(self.args.num_towers,
                                                    tower,
                                                    self.args.residual_left,
                                                    self.args.residual_right,
                                                    use_bn=False)

        projected = Conv2D(self.args.latent_depth, 3, padding='same')(residual)
        activated = self.crafter.activation()(projected)
        next_latent = MinMaxScaler()(activated)

        flattened = Flatten()(next_latent)

        # Cancel gradient/ predictions as r is not trained in boardgames:
        # the reward head is multiplied by zero.
        reward_head = Dense(self.args.support_size * 2 + 1,
                            name='r')(flattened)
        r = Lambda(lambda x: x * 0)(reward_head)

        return r, next_latent

    def build_predictor(self, latent_state):
        """Create the policy and value heads on top of a latent state.

        :param latent_state: Latent-state tensor.
        :return: Tuple (pi, v) of output tensors.
        """
        tower = self.crafter.conv_tower(self.args.num_convs,
                                        latent_state,
                                        use_bn=False)

        # 32-filter convolution and activation before flattening.
        reduced = Conv2D(32, 3, padding='same', use_bias=False)(tower)
        small = self.crafter.activation()(reduced)

        flattened = Flatten()(small)

        hidden = self.crafter.dense_sequence(1, flattened)

        pi = Dense(self.action_size, activation='softmax', name='pi')(hidden)

        # Scalar tanh value when support_size is 0; otherwise a softmax over
        # 2 * support_size + 1 bins.
        if self.args.support_size == 0:
            v = Dense(1, activation='tanh', name='v')(hidden)
        else:
            v = Dense(self.args.support_size * 2 + 1,
                      activation='softmax', name='v')(hidden)

        return pi, v

    def build_decoder(self, latent_state):
        """Decode a latent state back into observation planes.

        :param latent_state: Latent-state tensor.
        :return: Tensor with planes * observation_length channels.
        """
        tower = self.crafter.conv_tower(self.args.num_convs,
                                        latent_state,
                                        use_bn=False)
        residual = self.crafter.conv_residual_tower(self.args.num_towers,
                                                    tower,
                                                    self.args.residual_left,
                                                    self.args.residual_right,
                                                    use_bn=False)

        output_planes = self.planes * self.args.observation_length
        return Conv2D(output_planes, 3, padding='same')(residual)