Пример #1
0
    def train(self):
        """Run mini-batch training for the configured number of epochs.

        Reads ``batchsize``, ``statebatchsize`` and ``epochs`` from
        ``self.config`` and uses the first four entries of
        ``self.gridworld_data`` as training inputs and labels. Only full
        batches are trained on; trailing samples that do not fill a batch
        are skipped. After every epoch an ``'algorithm.train'`` event is
        sent carrying the epoch index and the error and accuracy averaged
        over that epoch's batches.

        Raises:
            ValueError: if at least one epoch is requested but the data set
                holds fewer samples than a single batch. Without this check
                the epoch would train on nothing and then fail with a
                ZeroDivisionError while averaging.
        """
        cfg = self.config
        batch_size = cfg.batchsize

        x, s1, s2, y, _, _, _, _ = self.gridworld_data

        if int(cfg.epochs) <= 0:
            # No epoch will run, so there is nothing to validate or train.
            return

        num_samples = x.shape[0]
        num_batches = num_samples // batch_size
        if num_batches < 1:
            raise ValueError(
                'cannot train: %d sample(s) do not fill a single batch of '
                'size %d' % (num_samples, batch_size))

        # The labels are sliced in units of ``statebatchsize`` entries per
        # input sample, so the label window is scaled accordingly.
        labels_per_sample = cfg.statebatchsize

        for epoch in range(int(cfg.epochs)):
            err, acc = 0.0, 0.0

            # Stepping up to num_batches * batch_size visits full batches
            # only; the incomplete tail (if any) is never reached.
            for i in range(0, num_batches * batch_size, batch_size):
                j = i + batch_size
                feeder = {
                    self.image: x[i:j],
                    self.path_ver: s1[i:j],
                    self.path_hor: s2[i:j],
                    self.true:
                    y[i * labels_per_sample:j * labels_per_sample],
                }

                _, e, a = self._session.run([self.opt, self.err, self.acc],
                                            feed_dict=feeder)
                err += e
                acc += a

            EventSystem.send(
                'algorithm.train', {
                    'epoch': epoch,
                    'train_err': err / num_batches,
                    'train_acc': acc / num_batches,
                })
Пример #2
0
 def _send_train_event(e, state, reward, done, qmax):
     """Publish an ``'algorithm.train_episode'`` event.

     The payload carries ``e`` under ``'episode'``, the ``state``,
     ``reward`` and ``done`` values unchanged, and the mean of ``qmax``
     under ``'qmax'``.
     """
     payload = {
         'episode': e,
         'reward': reward,
         'qmax': np.mean(qmax),
         'state': state,
         'done': done,
     }
     EventSystem.send('algorithm.train_episode', payload)
Пример #3
0
    def run_experiment(self, cfg):
        """Train a ``VinAlgorithm`` built from ``cfg`` and check its accuracy.

        Runs inside an ``EventTimer``, a ``VinTrainLogger`` and a TensorFlow
        session. An evaluation is triggered on every ``'algorithm.train'``
        event; after training the final accuracy is sent as a
        ``'train.summary'`` event and asserted to be greater than 0.9.
        """
        def evaluate_on_train(_):
            # ``alg`` is looked up when the event fires, which is after it
            # has been assigned below.
            return alg.eval()

        with EventTimer('algorithm.train'), VinTrainLogger(), tf.Session():
            EventSystem.subscribe('algorithm.train', evaluate_on_train)
            alg = VinAlgorithm(cfg)
            alg.train()

            accuracy = alg.eval()
            accuracy_row = ['Accuracy', "%.2f%%" % (accuracy * 100)]
            summary = ["\n", "-" * 32, fields([accuracy_row], -6)]
            EventSystem.send('train.summary', summary)

            self.assertGreater(accuracy, .9)
Пример #4
0
    def eval(self):
        """Compute accuracy on the last four entries of ``gridworld_data``.

        Feeds those arrays to the ``acc`` tensor, sends the result as an
        ``'algorithm.eval'`` event under the key ``'acc'`` and returns it.
        Presumably these four entries are the held-out split (``train``
        consumes the first four) -- confirm against the data loader.
        """
        _, _, _, _, images, ver_coords, hor_coords, labels = \
            self.gridworld_data

        feed = {
            self.image: images,
            self.path_ver: ver_coords,
            self.path_hor: hor_coords,
            self.true: labels,
        }
        accuracy = self.acc.eval(feed)

        EventSystem.send('algorithm.eval', {'acc': accuracy})
        return accuracy
Пример #5
0
    def run_experiment(self, cfg):
        """Train the algorithm described by ``cfg`` and check its results.

        ``cfg`` supplies the world and algorithm classes plus the number of
        training episodes and steps. A short 10-episode evaluation runs on
        every ``'algorithm.train_episode'`` event; after training, a
        1000-episode evaluation is sent as a ``'train.summary'`` event and
        its average reward and done ratio are asserted to exceed 100 and
        0.10 respectively.
        """
        with Timer(), TrainLogger(), tf.Session():
            episodes = cfg['train.episodes']
            steps = cfg['train.steps']

            def evaluate_on_episode(_):
                # ``alg`` is looked up when the event fires, which is after
                # it has been assigned below.
                return alg.eval(10, steps)

            EventSystem.subscribe('algorithm.train_episode',
                                  evaluate_on_episode)
            world = cfg['world.class'](cfg)
            alg = cfg['algorithm.class'](cfg, world)
            alg.train(episodes, steps)

            ave_reward, done_ratio = alg.eval(1000, steps)
            rows = [
                ['Reward', "%.2f" % ave_reward],
                ['Done', "%.0f%%" % (done_ratio*100)],
            ]
            EventSystem.send('train.summary',
                             ["\n", "-" * 32, fields(rows, -6)])

            self.assertGreater(ave_reward, 100)
            self.assertGreater(done_ratio, .10)
Пример #6
0
    def eval(self, episodes, steps):
        """Roll out the current policy and report its average performance.

        Each episode resets ``self.world`` and then takes at most ``steps``
        actions chosen by ``self._predict``, stopping early when the world
        reports the episode as done. The result is sent as an
        ``'algorithm.eval'`` event together with the last observed state.

        Args:
            episodes: number of episodes to run; must be at least 1.
            steps: maximum number of steps per episode.

        Returns:
            A ``(ave_reward, ave_done)`` tuple: the total reward divided by
            the number of episodes, and the fraction of episodes that
            reported done within the step limit.

        Raises:
            ValueError: if ``episodes`` is smaller than 1. Averaging over
                zero episodes used to fail with a ZeroDivisionError, and a
                negative count silently produced meaningless averages.
        """
        if episodes < 1:
            raise ValueError(
                'episodes must be at least 1, got %r' % (episodes,))

        total_reward = 0
        finished = 0
        for _ in range(episodes):
            state = self.world.reset()
            for _ in range(steps):
                action = self._predict(state)
                state, r, d = self.world.step(action)
                total_reward += r
                if d:
                    finished += 1
                    break

        ave_reward = total_reward / float(episodes)
        ave_done = finished / float(episodes)
        EventSystem.send('algorithm.eval', {
            'ave_reward': ave_reward,
            'ave_done': ave_done,
            'state': state,
        })
        return ave_reward, ave_done
Пример #7
0
 def test_lambda_subscription(self):
     # Start from a clean slate, subscribe a single lambda handler, and
     # expect send() to report 1 for the event.
     EventSystem.unsubscribe_all()
     EventSystem.subscribe('test_event', lambda data: print(data))

     payload = ([1, 2], "3")
     delivered = EventSystem.send('test_event', payload)

     self.assertEqual(delivered, 1)
Пример #8
0
 def test_object_subscription(self):
     # Start from a clean slate and create one TestSubscriber; presumably
     # its constructor subscribes it to 'test_event' (verify against the
     # TestSubscriber definition). The reference is held for the duration
     # of the test, as in the original.
     EventSystem.unsubscribe_all()
     subscriber = TestSubscriber()

     payload = ([1, 2], "3")
     delivered = EventSystem.send('test_event', payload)

     self.assertEqual(delivered, 1)
Пример #9
0
 def _on_train(self, _):
     """Record a ``'train'`` occurrence and publish its last interval.

     Adds ``'train'`` to the history, then sends a ``'timer'`` event with
     the value returned by ``_get_last_interval`` unless that value is
     ``None``. The event argument itself is ignored.
     """
     event_name = 'train'
     self._add_history(event_name)

     interval = self._get_last_interval(event_name)
     if interval is None:
         return

     EventSystem.send('timer', {event_name: interval})
Пример #10
0
 def _unsubcribe_all(self):
     """Call ``EventSystem.unsubscribe`` with this object as subscriber."""
     EventSystem.unsubscribe(subscriber=self)
Пример #11
0
 def _subscribe(self, event: str, method: classmethod) -> None:
     """Subscribe ``method`` to ``event`` on behalf of this object.

     Forwards to ``EventSystem.subscribe`` with ``self`` as the third
     argument.
     """
     # NOTE(review): the ``classmethod`` annotation looks narrower than
     # intended -- presumably any callable handler is accepted; confirm.
     EventSystem.subscribe(event, method, self)