def __init__(self, envs):
    """Wrap each raw env in an EnvContainer and build one softmax state
    encoder per (env, abstract MDP) pair, all trained by a single Adam
    optimizer.

    :param envs: iterable of raw environments to wrap.
    """
    self.envs = [EnvContainer(e) for e in envs]
    self.n_envs = len(self.envs)
    self.n_abstract_mdps = 2
    self.abstract_dim = 4   # number of abstract states the encoder maps onto
    self.state_dim = 4      # dimensionality of the raw state fed to each encoder
    self.states = []
    self.state_to_idx = None

    # Nested ModuleList: outer index = env, inner index = abstract MDP.
    # Appending after assignment still registers every sub-module on self.
    self.all_encoder_lst = nn.ModuleList()
    for _ in range(self.n_envs):
        per_env = nn.ModuleList()
        for _ in range(self.n_abstract_mdps):
            enc = Mlp(
                (128, 128, 128),
                output_size=self.abstract_dim,
                input_size=self.state_dim,
                output_activation=F.softmax,
                layer_norm=True,
            )
            enc.apply(init_weights)
            per_env.append(enc)
        self.all_encoder_lst.append(per_env)

    # One optimizer over every encoder's parameters.
    self.optimizer = optim.Adam(self.all_encoder_lst.parameters(), lr=1e-4)
class AbstractMDPsContrastive:
    """Learns a soft mapping from raw env states to a small set of abstract
    states with a single shared encoder, trained contrastively so that the
    induced abstract transition matrix is predictive while the abstract-state
    marginal stays high-entropy.
    """

    def __init__(self, envs):
        """Build the shared state encoder and its optimizer.

        :param envs: iterable of raw environments; each is wrapped in an
            EnvContainer.
        """
        self.envs = [EnvContainer(env) for env in envs]
        self.n_abstract_mdps = 2
        self.abstract_dim = 4   # number of abstract states
        self.state_dim = 4      # raw state dimensionality fed to the encoder
        self.states = []
        self.state_to_idx = None
        # Single encoder shared across all envs; softmax output gives a
        # distribution over abstract states for each raw state.
        self.encoder = Mlp(
            (128, 128, 128),
            output_size=self.abstract_dim,
            input_size=self.state_dim,
            output_activation=F.softmax,
            layer_norm=True,
        )
        self.encoder.apply(init_weights)
        # NOTE(review): this parameter is created but never passed to the
        # optimizer below and never read elsewhere in this class — confirm
        # whether it is vestigial or meant to be trained.
        self.transitions = nn.Parameter(
            torch.zeros((self.abstract_dim, self.abstract_dim)))
        self.optimizer = optim.Adam(self.encoder.parameters(), lr=1e-4)

    def train(self, max_epochs=100):
        """Run up to ``max_epochs`` training epochs, printing stats each epoch.

        :param max_epochs: maximum number of epochs to train for.
        """
        # Tag each env's transition rows with the env index in a final column.
        data_lst = []
        for i, env in enumerate(self.envs):
            d = np.array(env.transitions)
            d = np.concatenate([d, np.zeros((d.shape[0], 1)) + i], 1)
            data_lst.append(d)
        all_data = from_numpy(np.concatenate(data_lst, 0))
        dataset = data.TensorDataset(all_data)
        dataloader = data.DataLoader(dataset, batch_size=32, shuffle=True)
        # Uniform initial mixture over abstract MDPs and uniform abstract
        # transition matrices (currently unused by train_epoch — see note there).
        mixture = from_numpy(
            np.ones((len(self.envs), self.n_abstract_mdps))
            / self.n_abstract_mdps)
        all_abstract_t = from_numpy(
            np.ones((self.n_abstract_mdps, self.abstract_dim, self.abstract_dim))
            / self.abstract_dim)
        for epoch in range(1, max_epochs + 1):
            stats, abstract_t, y1 = self.train_epoch(
                dataloader, epoch, mixture, all_abstract_t)
            print(stats)
            print(y1[:5])
            print(abstract_t)

    def kl(self, dist1, dist2):
        """Elementwise KL(dist1 || dist2) summed over dim 1; 1e-8 guards log(0)."""
        return (dist1 * (torch.log(dist1 + 1e-8)
                         - torch.log(dist2 + 1e-8))).sum(1)

    def entropy(self, dist):
        """Shannon entropy of ``dist`` along its last dim; 1e-8 guards log(0)."""
        return -(dist * torch.log(dist + 1e-8)).sum(-1)

    def compute_abstract_t(self, env, hardcounts=False):
        """Encode one env's transitions and accumulate an abstract
        transition matrix.

        :param env: an EnvContainer whose ``transitions_np`` rows are
            [s1 (4 dims), s2 (4 dims)] pairs — assumes state_dim == 4,
            matching the hard-coded column split below.
        :param hardcounts: if True, count argmax abstract-state transitions;
            otherwise accumulate soft (outer-product) counts.
        :return: (row-normalized abstract transition matrix, y1, y2, y3,
            all_states) where y1/y2 encode the transition endpoints, y3
            encodes uniformly sampled states, and all_states encodes every
            state of the env.
        """
        trans = env.transitions_np
        s1 = trans[:, :4]
        s2 = trans[:, 4:]
        all_states = self.encoder(from_numpy(env.all_states()))
        y1 = self.encoder(from_numpy(s1))
        y2 = self.encoder(from_numpy(s2))
        y3 = self.encoder(from_numpy(env.sample_states(s1.shape[0])))
        a_t = from_numpy(np.zeros((self.abstract_dim, self.abstract_dim)))
        for i in range(self.abstract_dim):
            for j in range(self.abstract_dim):
                if hardcounts:
                    # Hard assignment: count argmax(i) -> argmax(j) transitions.
                    a_t[i, j] += ((y1.max(-1)[1] == i).float()
                                  * (y2.max(-1)[1] == j).float()).sum()
                else:
                    # Soft assignment: expected co-occurrence mass.
                    a_t[i, j] += (y1[:, i] * y2[:, j]).sum(0)
        # Row-normalize so each row is a distribution over next abstract states.
        n_a_t = from_numpy(np.zeros((self.abstract_dim, self.abstract_dim)))
        for i in range(self.abstract_dim):
            n_a_t[i, :] += a_t[i, :] / (a_t[i, :].sum() + 1e-8)
        return n_a_t, y1, y2, y3, all_states

    def train_epoch(self, dataloader, epoch, mixture, all_abstract_t):
        """One full-batch gradient step over all envs' transitions.

        NOTE(review): ``dataloader``, ``epoch``, ``mixture`` and
        ``all_abstract_t`` are accepted but not used — the update runs on
        each env's full transition set via compute_abstract_t. Confirm
        whether mini-batch training was intended.

        :return: (stats dict, first env's abstract transition matrix, y1).
        """
        stats = OrderedDict([('Loss', 0), ('Converge', 0), ('Diverge', 0),
                             ('Entropy1', 0), ('Entropy2', 0), ('Dev', 0)])
        per_env = [self.compute_abstract_t(env, hardcounts=False)
                   for env in self.envs]
        abstract_t = [x[0] for x in per_env]
        y1 = torch.cat([x[1] for x in per_env], 0)
        y2 = torch.cat([x[2] for x in per_env], 0)  # currently unused below
        # NOTE(review): index 4 is ``all_states`` (index 3 would be the
        # sampled-states encoding) — confirm which one the entropy terms
        # were meant to use.
        y3 = torch.cat([x[4] for x in per_env], 0)
        # Expected log-likelihood of observed transitions under the
        # (detached) abstract transition matrix of env 0.
        a_loss = from_numpy(np.zeros(1))
        for i in range(self.abstract_dim):
            for j in range(self.abstract_dim):
                a_loss += (y1[:, i] * y2[:, j]
                           * torch.log(abstract_t[0][i, j].detach()
                                       + 1e-8)).sum()
        # Marginal entropy over all states: maximized to spread mass across
        # abstract states.
        entropy1 = self.entropy(y3.sum(0) / y3.sum())
        # Mean per-state conditional entropy (tracked but not in the loss).
        entropy2 = self.entropy(y3).mean()
        loss = -a_loss - 1000 * entropy1
        # BUG FIX: gradients were never cleared, so every epoch's .backward()
        # accumulated on top of all previous epochs' gradients.
        self.optimizer.zero_grad()
        loss.backward()
        # clip_grad_norm is the deprecated alias; clip_grad_norm_ is the
        # in-place replacement with identical semantics.
        nn.utils.clip_grad_norm_(self.encoder.parameters(), 5.0)
        self.optimizer.step()
        stats['Loss'] += loss.item()
        stats['Entropy1'] += entropy1.item()
        stats['Entropy2'] += entropy2.item()
        return stats, abstract_t[0], y1

    def gen_plot(self):
        """Render each env's encoder visualization side by side and show it."""
        plots = [env.gen_plot(self.encoder) for env in self.envs]
        plots = np.concatenate(plots, 1)
        plt.imshow(plots)
        plt.show()