# Example #1
# 0
class PolicyNN():
    """Neural net for policy approximation function.

    Policy parameterized by Gaussian means and variances. NN outputs mean
    action based on observation. Trainable variables hold log-variances
    for each action dimension (i.e. variances not determined by NN).

    NOTE(review): this class calls ``self.add_loss`` (in ``call``) and
    ``super().__init__()``, which only make sense if it subclasses a Keras
    ``Layer``/``Model`` -- the base class is missing here; confirm the
    intended base against the file's imports.
    """

    def __init__(self, obs_dim, act_dim, hid1_mult, kl_targ, init_logvar, eta):
        """Store hyper-parameters and create layers.

        The underlying Keras ``Model`` is built lazily on the first call
        to ``call_polNN`` (guarded by ``self._build_model_flag``).

        Args:
            obs_dim: dimension of the observation space.
            act_dim: dimension of the action space.
            hid1_mult: width of hidden layer 1 as a multiple of obs_dim.
            kl_targ: target KL divergence for the squared-hinge penalty.
            init_logvar: constant added to the summed trainable log-vars.
            eta: coefficient of the squared-hinge KL penalty (loss3).
        """
        super().__init__()
        self.kl_targ = kl_targ
        self.eta = eta
        self.obs_dim = obs_dim
        self.hid1_mult = hid1_mult
        self.act_dim = act_dim
        # KL penalty coefficient; created as a non-trainable model weight
        # at build time in call_polNN.
        self.beta = None
        self.logprob = LogProb()
        self.kl_entropy = KLEntropy()

        self.batch_sz = 1
        self.init_logvar = init_logvar
        # Hidden-layer sizing heuristic: h1 scales with obs_dim, h3 with
        # act_dim, h2 is their geometric mean.
        self.hid1_units = obs_dim * hid1_mult
        # NOTE(review): original comment said "10 empirically determined"
        # but the multiplier is 40 -- confirm which value is intended.
        self.hid3_units = act_dim * 40
        self.hid2_units = int(np.sqrt(self.hid1_units * self.hid3_units))
        # Heuristic to set learning rate based on NN size (tuned on
        # 'Hopper-v1'); 9e-4 empirically determined.
        self.lr = 9e-4 / np.sqrt(self.hid2_units)

        self.dense1 = Dense(self.hid1_units,
                            activation='tanh',
                            input_shape=(self.obs_dim, ))
        self.dense2 = Dense(self.hid2_units,
                            activation='tanh',
                            input_shape=(self.hid1_units, ))
        self.dense3 = Dense(self.hid3_units,
                            activation='tanh',
                            input_shape=(self.hid2_units, ))
        self.dense4 = Dense(self.act_dim, input_shape=(self.hid3_units, ))

        # Lazy-build state: model and logvars are created on first call.
        self._build_model_flag = 'build'
        self.model = None
        self.logvars = None

    def call(self, inputs):
        """Compute policy losses and return KL / entropy diagnostics.

        Args:
            inputs: sequence of (obs, act, adv, old_means, old_logvars,
                old_logp) tensors from the previous policy iteration.

        Returns:
            [kl, entropy] between the old and new policy distributions.
        """
        obs, act, adv, old_means, old_logvars, old_logp = inputs
        new_means, new_logvars = self.call_polNN(obs, self.hid1_units,
                                                 self.hid2_units,
                                                 self.hid3_units)
        new_logp = self.logprob.call_LogP([act, new_means, new_logvars])
        kl, entropy = self.kl_entropy.call_KLE(
            [old_means, old_logvars, new_means, new_logvars])
        # loss1: importance-sampled surrogate objective (negated for
        # gradient descent); loss2: adaptive KL penalty scaled by beta;
        # loss3: squared hinge that kicks in when KL exceeds 2 * kl_targ.
        loss1 = -K.mean(adv * K.exp(new_logp - old_logp))
        loss2 = K.mean(self.beta * kl)
        # TODO - Take mean before or after hinge loss?
        loss3 = self.eta * K.square(
            K.maximum(0.0,
                      K.mean(kl) - 2.0 * self.kl_targ))
        self.add_loss(loss1 + loss2 + loss3)

        return [kl, entropy]

    def call_polNN(self, inputs, hid1_units, hid2_units, hid3_units):
        """Run the policy net; build the Keras model on the first call.

        Args:
            inputs: observation tensor (ignored on the build pass, where a
                fresh ``Input`` placeholder is used instead).
            hid1_units, hid2_units, hid3_units: hidden-layer widths, used
                only for logvar sizing and the build-time print.

        Returns:
            [means, logvars]: mean actions and per-dimension log-variances
            tiled to ``self.batch_sz`` rows.
        """
        if self._build_model_flag == 'build':
            inputs = Input(shape=(self.obs_dim, ), dtype='float32')

        y = self.dense1(inputs)
        y = self.dense2(y)
        y = self.dense3(y)
        means = self.dense4(y)

        if self._build_model_flag == 'build':
            self.model = Model(inputs=inputs, outputs=means)
            optimizer = Adam(self.lr)
            # beta is non-trainable: it is adapted outside gradient descent.
            self.beta = self.model.add_weight('beta',
                                              initializer='zeros',
                                              trainable=False)
            self.model.compile(optimizer=optimizer, loss='mse')

            # Multiple logvar rows are summed below so the effective
            # logvar learns logvar_speed times faster than a single row.
            logvar_speed = (10 * hid3_units) // 48
            self.logvars = self.model.add_weight(shape=(logvar_speed,
                                                        self.act_dim),
                                                 trainable=True,
                                                 initializer='zeros')
            print(
                'Policy Params -- h1: {}, h2: {}, h3: {}, lr: {:.3g}, logvar_speed: {}'
                .format(hid1_units, hid2_units, hid3_units, self.lr,
                        logvar_speed))
            self._build_model_flag = 'run'

        logvars = K.sum(self.logvars, axis=0, keepdims=True) + self.init_logvar
        logvars = K.tile(logvars, (self.batch_sz, 1))

        return [means, logvars]

    def get_lr(self):
        """Return the size-based learning rate computed in __init__."""
        return self.lr

    def get_model(self):
        """Return the built Keras model (None until the first call)."""
        return self.model

    def get_hidden_layers(self):
        """Return the (h1, h2, h3) hidden-layer widths."""
        return self.hid1_units, self.hid2_units, self.hid3_units