    def __init__(self, envs):
        self.envs = [EnvContainer(env) for env in envs]

        self.n_envs = len(self.envs)
        self.n_abstract_mdps = 2
        self.abstract_dim = 4
        self.state_dim = 4
        self.states = []
        self.state_to_idx = None

        all_encoder_lst = nn.ModuleList()
        for i in range(self.n_envs):
            encoder_lst = nn.ModuleList()
            for j in range(self.n_abstract_mdps):
                encoder = Mlp((128, 128, 128),
                              output_size=self.abstract_dim,
                              input_size=self.state_dim,
                              output_activation=lambda x: F.softmax(x, dim=-1),
                              layer_norm=True)
                encoder.apply(init_weights)
                encoder_lst.append(encoder)

            all_encoder_lst.append(encoder_lst)
        self.all_encoder_lst = all_encoder_lst

        self.optimizer = optim.Adam(self.all_encoder_lst.parameters(), lr=1e-4)
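
        # Note: this variant keeps a separate encoder for every (environment, abstract MDP)
        # pair; the complete example below shares a single encoder across all environments.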
Example #2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from collections import OrderedDict
import matplotlib.pyplot as plt

# Project-specific helpers assumed to be importable from the surrounding codebase
# (rlkit-style): EnvContainer, Mlp, from_numpy, init_weights.

class AbstractMDPsContrastive:
    def __init__(self, envs):
        self.envs = [EnvContainer(env) for env in envs]

        self.n_abstract_mdps = 2
        self.abstract_dim = 4
        self.state_dim = 4
        self.states = []
        self.state_to_idx = None

        self.encoder = Mlp((128, 128, 128), output_size=self.abstract_dim, input_size=self.state_dim,
                           output_activation=lambda x: F.softmax(x, dim=-1), layer_norm=True)

        self.encoder.apply(init_weights)
        self.transitions = nn.Parameter(torch.zeros((self.abstract_dim, self.abstract_dim)))

        self.optimizer = optim.Adam(self.encoder.parameters(), lr=1e-4)

    def train(self, max_epochs=100):

        # Stack transitions from every environment, appending the environment index
        # as a final column so each row can be traced back to its source env.
        data_lst = []
        for i, env in enumerate(self.envs):
            d = np.array(env.transitions)
            d = np.concatenate([d, np.zeros((d.shape[0], 1)) + i], 1)
            data_lst.append(d)

        all_data = from_numpy(np.concatenate(data_lst, 0))

        dataset = data.TensorDataset(all_data)
        dataloader = data.DataLoader(dataset, batch_size=32, shuffle=True)

        # Start from a uniform mixture over abstract MDPs and uniform abstract transition matrices.
        mixture = from_numpy(np.ones((len(self.envs), self.n_abstract_mdps)) / self.n_abstract_mdps)
        all_abstract_t = from_numpy(np.ones((self.n_abstract_mdps, self.abstract_dim, self.abstract_dim)) / self.abstract_dim)
        for epoch in range(1, max_epochs + 1):
            stats, abstract_t, y1 = self.train_epoch(dataloader, epoch, mixture, all_abstract_t)
            #if stats['Loss'] < 221.12:
            #   break
            print(stats)
        print(y1[:5])
        print(abstract_t)

    def kl(self, dist1, dist2):
        # KL(dist1 || dist2) per row, with epsilon smoothing for numerical stability.
        return (dist1 * (torch.log(dist1 + 1e-8) - torch.log(dist2 + 1e-8))).sum(1)

    def entropy(self, dist):
        # Shannon entropy along the last dimension, with epsilon smoothing.
        return -(dist * torch.log(dist + 1e-8)).sum(-1)

    def compute_abstract_t(self, env, hardcounts=False):
        trans = env.transitions_np
        # Each row is a transition (s1, s2); split into current and next state.
        s1 = trans[:, :self.state_dim]
        s2 = trans[:, self.state_dim:]
        all_states = self.encoder(from_numpy(env.all_states()))
        y1 = self.encoder(from_numpy(s1))
        y2 = self.encoder(from_numpy(s2))
        y3 = self.encoder(from_numpy(env.sample_states(s1.shape[0])))

        # Optionally hard-code y1 and y2 to ground-truth abstractions (for debugging);
        # pick one of the options below and uncomment the two lines that follow.
        options = ['optimal', 'onestate', 'uniform', 'onestate_uniform']
        option = options[3]
        #y1 = env.true_values(s1, option=option)
        #y2 = env.true_values(s2, option=option)

        # Soft (or hard) transition counts between abstract states:
        # a_t[i, j] accumulates q(i | s1) * q(j | s2) over all observed transitions.
        a_t = from_numpy(np.zeros((self.abstract_dim, self.abstract_dim)))
        for i in range(self.abstract_dim):
            for j in range(self.abstract_dim):
                if hardcounts:
                    a_t[i, j] += ((y1.max(-1)[1] == i).float() * (y2.max(-1)[1] == j).float()).sum()
                else:
                    a_t[i, j] += (y1[:, i] * y2[:, j]).sum(0)

        # Row-normalize the counts into an abstract transition matrix.
        n_a_t = from_numpy(np.zeros((self.abstract_dim, self.abstract_dim)))
        for i in range(self.abstract_dim):
            n_a_t[i, :] += a_t[i, :] / (a_t[i, :].sum() + 1e-8)

        return n_a_t, y1, y2, y3, all_states
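
    # Sketch (an assumption, not part of the original class): with soft assignments,
    # the nested counting loops above collapse to a single matrix product.
    def compute_abstract_t_soft(self, env):
        trans = env.transitions_np
        y1 = self.encoder(from_numpy(trans[:, :self.state_dim]))
        y2 = self.encoder(from_numpy(trans[:, self.state_dim:]))
        a_t = y1.t() @ y2                               # soft co-occurrence counts
        return a_t / (a_t.sum(1, keepdim=True) + 1e-8)  # row-normalized abstract transitions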


    def train_epoch(self, dataloader, epoch, mixture, all_abstract_t):
        stats = OrderedDict([('Loss', 0),
                             ('Converge', 0),
                             ('Diverge', 0),
                             ('Entropy1', 0),
                             ('Entropy2', 0),
                             ('Dev', 0)])

        # One (abstract_t, y1, y2, y3, all_states) tuple per environment; renamed to
        # avoid shadowing the torch.utils.data module imported as `data`.
        env_outputs = [self.compute_abstract_t(env, hardcounts=False) for env in self.envs]
        abstract_t = [x[0] for x in env_outputs]
        y1 = torch.cat([x[1] for x in env_outputs], 0)
        y2 = torch.cat([x[2] for x in env_outputs], 0)
        # Note: index 4 is the encoding of all states (index 3 would be the sampled states).
        y3 = torch.cat([x[4] for x in env_outputs], 0)

        # Expected log-likelihood of the observed transitions under the soft state
        # assignments and the (detached) abstract transition matrix of the first env.
        a_loss = from_numpy(np.zeros(1))
        for i in range(self.abstract_dim):
            for j in range(self.abstract_dim):
                a_loss += (y1[:, i] * y2[:, j] * torch.log(abstract_t[0][i, j].detach() + 1e-8)).sum()

        # Marginal entropy of the abstract-state distribution over all data points (maximized in the loss).
        entropy1 = self.entropy(y3.sum(0) / y3.sum())
        # Conditional entropy for a single data point (intended to be minimized; currently logged only).
        entropy2 = self.entropy(y3).mean()
        loss = -a_loss - 1000 * entropy1

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.encoder.parameters(), 5.0)
        self.optimizer.step()

        stats['Loss'] += loss.item()
        stats['Entropy1'] += entropy1.item()
        stats['Entropy2'] += entropy2.item()
        return stats, abstract_t[0], y1

    def gen_plot(self):
        plots = [env.gen_plot(self.encoder) for env in self.envs]

        plots = np.concatenate(plots, 1)

        plt.imshow(plots)
        #plt.savefig('/home/jcoreyes/abstract/rlkit/examples/abstractmdp/fig1.png')
        plt.show()
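
A minimal usage sketch (not from the original listing): build_envs is a hypothetical, project-specific factory producing the raw environments that EnvContainer wraps; everything else uses the class exactly as defined above.

envs = build_envs()                   # hypothetical helper, stands in for env construction
model = AbstractMDPsContrastive(envs)
model.train(max_epochs=100)           # prints a stats OrderedDict per epoch
model.gen_plot()                      # shows the learned abstraction for each environment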