Beispiel #1
0
class Runner:
    """Compare the DKM and VBM memory writers on the same generated episode.

    At construction time a prior memory ``pM`` is built from a matrix ``R0``
    of shape ``(K, C)`` and ``U0``, the ``K x K`` identity.  ``R0`` is drawn
    from N(0, R0_init_scale**2) using numpy's global RNG (``np.random``), so
    seeding ``np.random`` beforehand makes the prior reproducible.

    Args:
        T: episode length, passed to ``Oracle.generate_episode`` as
            ``episode_len``.
        K: first dimension of the memory (rows of ``R0``, size of ``U0``).
        C: second dimension of the memory (columns of ``R0``).
        sigma2_w: forwarded to the DKM writer as ``qw_sigma2``.
        opt_iters: forwarded to both writers' ``write_episode`` calls.
        R0_init_scale: standard deviation used when sampling ``R0``.
    """

    def __init__(self, T, K, C, sigma2_w, opt_iters, R0_init_scale):
        # Store the raw hyper-parameters as attributes.
        self.T, self.K, self.C = T, K, C
        self.sigma2_w = sigma2_w
        self.opt_iters = opt_iters
        self.R0_init_scale = R0_init_scale

        # The evaluator and both writers share the same memory dimensions.
        dims = dict(K=K, C=C)
        self.evaluator = Evaluator(**dims)
        self.DKM = MemoryWriterDKMBatchIterative(qw_sigma2=sigma2_w, **dims)
        self.VBM = MemoryWriterVBMBatchIterative(**dims)

        # Prior memory: random R0 (global numpy RNG) and identity U0.
        self.R0 = np.random.normal(0.0, R0_init_scale, (K, C))
        self.U0 = np.eye(K)
        self.pM = DistributionalMemory(R=self.R0, U=self.U0)

        # Built only after R0 has been sampled; keep this order in case the
        # oracle also draws from the global RNG.
        self.orcl = Oracle(K=K, C=C)

    def run(self, with_info=False):
        """Generate one episode, write it with both writers, and score them.

        Args:
            with_info: when truthy, also report the sum of the diagonal of
                each writer's posterior ``qM.U`` (its trace, assuming ``U``
                is a square matrix -- TODO confirm).

        Returns:
            A tuple ``(dkm_elbo_per_frame, vbm_elbo_per_frame, info)``.
            ``info`` is ``{}`` unless ``with_info`` is truthy, in which case
            it holds the keys ``'DKM_tr_Uf'`` and ``'VBM_tr_Uf'``.
        """
        prior = self.pM
        iters = self.opt_iters
        episode = self.orcl.generate_episode(episode_len=self.T,
                                             memory_state=prior)

        # Both posteriors are written before either is evaluated, giving the
        # fixed call order: DKM write, VBM write, DKM eval, VBM eval.
        posteriors = []
        for writer in (self.DKM, self.VBM):
            q_w, q_m = writer.write_episode(Z=episode,
                                            pM=prior,
                                            opt_iters=iters)
            posteriors.append((q_w, q_m))

        elbos = [
            self.evaluator.compute_elbo_per_frame(Z=episode,
                                                  qW=q_w,
                                                  qM=q_m,
                                                  pM=prior)
            for q_w, q_m in posteriors
        ]

        info = {}
        if with_info:
            diag_sums = [np.sum(np.diag(q_m.U)) for _, q_m in posteriors]
            info.update(zip(('DKM_tr_Uf', 'VBM_tr_Uf'), diag_sums))

        dkm_elbo, vbm_elbo = elbos
        return dkm_elbo, vbm_elbo, info