Exemplo n.º 1
0
    def act_and_train(self, obs, reward):
        """Pick an action for ``obs`` while recording n-step training data.

        Args:
            obs: Current observation; it is wrapped into a batch of one via
                ``self.batch_states``.
            reward: Reward value. It is stored at index ``self.t - 1``, i.e.
                attributed to the step recorded on the previous call.

        Returns:
            The action chosen by ``self.explorer``.

        Side effects: fills ``past_rewards`` / ``past_states`` /
        ``past_action_values``, may call ``self.update``, advances the local
        step ``self.t`` and the shared counter ``self.t_global``, updates
        ``self.average_q``, and periodically copies ``q_function`` parameters
        into ``target_q_function``.
        """
        state = self.batch_states([obs], np, self.phi)

        self.past_rewards[self.t - 1] = reward

        # Run an update once t_max steps have accumulated since t_start.
        # This must happen before the current state is written below.
        steps_since_start = self.t - self.t_start
        if steps_since_start == self.t_max:
            self.update(state)

        self.past_states[self.t] = state
        if isinstance(self.target_q_function, Recurrent):
            # Evaluate it to update states
            self.target_q_function(state)
        action_values = self.q_function(state)

        def greedy_action():
            return action_values.greedy_actions.data[0]

        action = self.explorer.select_action(
            self.t_global.value, greedy_action, action_value=action_values)
        chosen_q = action_values.evaluate_actions(np.asarray([action]))
        self.past_action_values[self.t] = chosen_q
        self.t += 1

        # Exponential moving average of the chosen action's Q-value.
        q_delta = float(chosen_q.data[0]) - self.average_q
        self.average_q += (1 - self.average_q_decay) * q_delta

        # The shared counter is incremented and read under its own lock.
        with self.t_global.get_lock():
            self.t_global.value += 1
            global_step = self.t_global.value

        if global_step % self.i_target == 0:
            self.logger.debug('target synchronized t_global:%s t_local:%s',
                              global_step, self.t)
            copy_param.copy_param(self.target_q_function, self.q_function)

        return action
Exemplo n.º 2
0
 def load(self, dirname):
     """Load parameters from ``dirname`` and mirror them into shared links.

     The parent class's ``load`` runs first. Afterwards, the local generator
     and discriminator parameters are copied into ``shared_generator`` and
     ``shared_discriminator``.
     """
     logger.debug('Load parameters from %s', dirname)
     super().load(dirname)
     link_pairs = (
         (self.generator, self.shared_generator),
         (self.discriminator, self.shared_discriminator),
     )
     for local_link, shared_link in link_pairs:
         copy_param.copy_param(source_link=local_link,
                               target_link=shared_link)
Exemplo n.º 3
0
    def test_copy_param_type_check(self):
        """copy_param must raise TypeError for an uninitialized target."""
        # The target link is created with input size None, so its parameters
        # are not initialized yet; the source link is fully sized.
        uninitialized_target = L.Linear(None, 5)
        initialized_source = L.Linear(1, 5)

        with self.assertRaises(TypeError):
            # Copying into a link whose parameters are uninitialized fails.
            copy_param.copy_param(uninitialized_target, initialized_source)
Exemplo n.º 4
0
    def test_copy_param_scalar(self):
        """copy_param also handles 0-dimensional (scalar) parameters."""
        def chain_with_scalar(value):
            # A Chain holding a single scalar Parameter named `p`.
            chain = chainer.Chain()
            with chain.init_scope():
                chain.p = chainer.Parameter(np.array(value))
            return chain

        target = chain_with_scalar(1)
        source = chain_with_scalar(2)

        self.assertNotEqual(target.p.array, source.p.array)

        # Overwrite the target's parameters with the source's.
        copy_param.copy_param(target, source)

        self.assertEqual(target.p.array, source.p.array)
Exemplo n.º 5
0
    def test_copy_param(self):
        """After copy_param, the target reproduces the source's outputs."""
        target = L.Linear(1, 5)
        source = L.Linear(1, 5)

        x = chainer.Variable(np.random.rand(1, 1).astype(np.float32))

        def outputs_of(link):
            # Flattened forward-pass output for the fixed input x.
            return list(link(x).array.ravel())

        target_before = outputs_of(target)
        source_before = outputs_of(source)
        # The two freshly built links should disagree before the copy.
        self.assertNotEqual(target_before, source_before)

        # Overwrite the target's parameters with the source's.
        copy_param.copy_param(target, source)

        target_after = outputs_of(target)
        source_after = outputs_of(source)
        # The target now matches the source, and the source is untouched.
        self.assertEqual(target_after, source_before)
        self.assertEqual(source_after, source_before)
Exemplo n.º 6
0
 def load(self, dirname):
     """Load parameters from ``dirname``; mirror them when training async.

     The parent class's ``load`` always runs. When ``self.train_async`` is
     truthy, the local ``model`` parameters are also copied into
     ``shared_model``.
     """
     super().load(dirname)
     if not self.train_async:
         return
     copy_param.copy_param(source_link=self.model,
                           target_link=self.shared_model)
Exemplo n.º 7
0
 def sync_parameters(self):
     """Overwrite the local ``model`` with ``shared_model``'s parameters."""
     shared, local = self.shared_model, self.model
     copy_param.copy_param(source_link=shared, target_link=local)
Exemplo n.º 8
0
 def sync_parameters(self):
     """Pull shared parameters locally, then soft-copy onto the average model.

     First, ``shared_model`` is copied into the local ``model``. Then
     ``copy_param.soft_copy_param`` is applied from ``model`` onto
     ``shared_average_model`` with ``tau = 1 - trust_region_alpha``.
     """
     copy_param.copy_param(source_link=self.shared_model,
                           target_link=self.model)
     copy_param.soft_copy_param(source_link=self.model,
                                target_link=self.shared_average_model,
                                tau=1 - self.trust_region_alpha)
Exemplo n.º 9
0
 def sync_parameters(self):
     """Overwrite the local ``q_function`` with ``shared_q_function``."""
     shared, local = self.shared_q_function, self.q_function
     copy_param.copy_param(source_link=shared, target_link=local)
Exemplo n.º 10
0
 def load(self, dirname):
     """Load parameters from ``dirname`` and mirror them into the shared copy.

     After the parent class's ``load``, the local ``q_function`` parameters
     are copied into ``shared_q_function``.
     """
     super().load(dirname)
     local, shared = self.q_function, self.shared_q_function
     copy_param.copy_param(source_link=local, target_link=shared)
Exemplo n.º 11
0
 def sync_parameters(self):
     """Overwrite the local generator/discriminator with the shared copies."""
     link_pairs = (
         (self.shared_generator, self.generator),
         (self.shared_discriminator, self.discriminator),
     )
     for shared_link, local_link in link_pairs:
         copy_param.copy_param(source_link=shared_link,
                               target_link=local_link)