def get_non_linear_results(
        ob_space,
        encoder,
        latent_dim,
        batch_size=128,
        num_batches=10000,
) -> NonLinearResults:
    """Train a small decoder on top of the (frozen) encoder and report how well
    the latent code supports non-linear state reconstruction."""
    state_dim = ob_space.low.size
    decoder = Mlp(
        hidden_sizes=[64, 64],
        output_size=state_dim,
        input_size=latent_dim,
    )
    decoder.to(ptu.device)
    optimizer = optim.Adam(decoder.parameters())
    initial_loss = last_10_percent_loss = 0
    for i in range(num_batches):
        states = get_batch(ob_space, batch_size)
        x = ptu.from_numpy(states)
        z = encoder(x)
        x_hat = decoder(z)
        loss = ((x - x_hat) ** 2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i == 0:
            initial_loss = ptu.get_numpy(loss)
        if i == int(num_batches * 0.9):
            last_10_percent_loss = ptu.get_numpy(loss)
    # final evaluation on a large held-out batch
    eval_states = get_batch(ob_space, batch_size=2 ** 15)
    x = ptu.from_numpy(eval_states)
    z = encoder(x)
    x_hat = decoder(z)
    reconstruction = ptu.get_numpy(x_hat)
    loss = ((eval_states - reconstruction) ** 2).mean()
    last_10_percent_contribution = (
        (last_10_percent_loss - loss) / (initial_loss - loss)
    )
    del decoder, optimizer
    return NonLinearResults(
        loss=loss,
        initial_loss=initial_loss,
        last_10_percent_contribution=last_10_percent_contribution,
    )
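# --- Hedged usage sketch (not in the original file) ---
# Illustrates how get_non_linear_results might be called, assuming
# NonLinearResults is a simple results container, get_batch samples states
# uniformly from the observation space, and ptu is rlkit's pytorch_util.
# Any gym environment with a Box observation space would do; the untrained
# encoder below is only a stand-in for a learned one.
if __name__ == '__main__':
    import gym
    env = gym.make('Pendulum-v0')
    latent_dim = 2
    encoder = Mlp(
        hidden_sizes=[64],
        output_size=latent_dim,
        input_size=env.observation_space.low.size,
    )
    encoder.to(ptu.device)
    results = get_non_linear_results(env.observation_space, encoder,
                                     latent_dim, num_batches=1000)
    print(results.loss, results.initial_loss,
          results.last_10_percent_contribution)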
logger = SummaryWriter(comment='_' + args.env + '_rnd')

# prepare networks
M = args.layer_size
network = Mlp(
    input_size=obs_dim + action_dim,
    output_size=1,
    hidden_sizes=[M, M],
).to(device)
target_network = Mlp(
    input_size=obs_dim + action_dim,
    output_size=1,
    hidden_sizes=[M, M],
).to(device)
# the target network is a fixed, randomly initialized function;
# only `network` is trained to match it
for param in target_network.parameters():
    param.requires_grad = False

optimizer = optim.Adam(network.parameters(), lr=args.lr)

best_loss = np.inf
for epoch in range(args.epochs):
    t_loss = train(network, target_network, dataloader, optimizer, epoch,
                   use_cuda)
    if use_tb:
        logger.add_scalar(log_dir + '/train-loss', t_loss, epoch)

    if t_loss < best_loss:
        best_loss = t_loss
        file_name = 'models/{}_{}.pt'.format(timestamp, args.env)
        print('Writing model checkpoint, loss: {:.2g}'.format(t_loss))
        print('Writing model checkpoint : {}'.format(file_name))
        # save keys mirror those read back by the --load_model path
        torch.save({
            'network_state_dict': network.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': t_loss,
            'epoch': epoch,
        }, file_name)
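# --- Hedged sketch (assumption, not the original implementation) ---
# One plausible shape for the train() epoch the loop above expects: the
# trainable `network` is distilled toward the frozen, randomly initialized
# `target_network` (RND-style). The dataloader is assumed to yield
# (state, action) batches.
def train(network, target_network, dataloader, optimizer, epoch, use_cuda):
    network.train()
    total_loss = 0.0
    for states, actions in dataloader:
        if use_cuda:
            states, actions = states.cuda(), actions.cuda()
        inputs = torch.cat([states, actions], dim=1)
        with torch.no_grad():
            targets = target_network(inputs)
        preds = network(inputs)
        loss = ((preds - targets) ** 2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / max(1, len(dataloader))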
class AbstractMDPsContrastive:
    def __init__(self, envs):
        self.envs = [EnvContainer(env) for env in envs]
        self.n_abstract_mdps = 2
        self.abstract_dim = 4
        self.state_dim = 4
        self.states = []
        self.state_to_idx = None
        self.encoder = Mlp((128, 128, 128),
                           output_size=self.abstract_dim,
                           input_size=self.state_dim,
                           output_activation=F.softmax,
                           layer_norm=True)
        self.encoder.apply(init_weights)
        self.transitions = nn.Parameter(
            torch.zeros((self.abstract_dim, self.abstract_dim)))
        self.optimizer = optim.Adam(self.encoder.parameters(), lr=1e-4)

    def train(self, max_epochs=100):
        data_lst = []
        for i, env in enumerate(self.envs):
            d = np.array(env.transitions)
            # append the env index as an extra column
            d = np.concatenate([d, np.zeros((d.shape[0], 1)) + i], 1)
            data_lst.append(d)
        all_data = from_numpy(np.concatenate(data_lst, 0))
        dataset = data.TensorDataset(all_data)
        dataloader = data.DataLoader(dataset, batch_size=32, shuffle=True)
        mixture = from_numpy(
            np.ones((len(self.envs), self.n_abstract_mdps))
            / self.n_abstract_mdps)
        all_abstract_t = from_numpy(
            np.ones((self.n_abstract_mdps, self.abstract_dim,
                     self.abstract_dim)) / self.abstract_dim)
        for epoch in range(1, max_epochs + 1):
            stats, abstract_t, y1 = self.train_epoch(dataloader, epoch,
                                                     mixture, all_abstract_t)
            #if stats['Loss'] < 221.12:
            #    break
            print(stats)
            print(y1[:5])
            print(abstract_t)

    def kl(self, dist1, dist2):
        return (dist1 *
                (torch.log(dist1 + 1e-8) - torch.log(dist2 + 1e-8))).sum(1)

    def entropy(self, dist):
        return -(dist * torch.log(dist + 1e-8)).sum(-1)

    def compute_abstract_t(self, env, hardcounts=False):
        trans = env.transitions_np
        s1 = trans[:, :4]
        s2 = trans[:, 4:]
        all_states = self.encoder(from_numpy(env.all_states()))
        y1 = self.encoder(from_numpy(s1))
        y2 = self.encoder(from_numpy(s2))
        y3 = self.encoder(from_numpy(env.sample_states(s1.shape[0])))

        # Optionally hard-code y1 and y2 to ground-truth assignments (debugging)
        options = ['optimal', 'onestate', 'uniform', 'onestate_uniform']
        option = options[3]
        #y1 = env.true_values(s1, option=option)
        #y2 = env.true_values(s2, option=option)

        a_t = from_numpy(np.zeros((self.abstract_dim, self.abstract_dim)))
        for i in range(self.abstract_dim):
            for j in range(self.abstract_dim):
                if hardcounts:
                    a_t[i, j] += ((y1.max(-1)[1] == i).float() *
                                  (y2.max(-1)[1] == j).float()).sum()
                else:
                    a_t[i, j] += (y1[:, i] * y2[:, j]).sum(0)
        n_a_t = from_numpy(np.zeros((self.abstract_dim, self.abstract_dim)))
        for i in range(self.abstract_dim):
            n_a_t[i, :] += a_t[i, :] / (a_t[i, :].sum() + 1e-8)
        return n_a_t, y1, y2, y3, all_states

    def train_epoch(self, dataloader, epoch, mixture, all_abstract_t):
        stats = OrderedDict([('Loss', 0), ('Converge', 0), ('Diverge', 0),
                             ('Entropy1', 0), ('Entropy2', 0), ('Dev', 0)])

        data = [self.compute_abstract_t(env, hardcounts=False)
                for env in self.envs]
        abstract_t = [x[0] for x in data]
        y1 = torch.cat([x[1] for x in data], 0)
        y2 = torch.cat([x[2] for x in data], 0)
        y3 = torch.cat([x[4] for x in data], 0)

        a_loss = from_numpy(np.zeros(1))
        for i in range(self.abstract_dim):
            for j in range(self.abstract_dim):
                a_loss += (y1[:, i] * y2[:, j] *
                           torch.log(abstract_t[0][i, j].detach() + 1e-8)).sum()

        # maximize entropy of the marginal over all data points
        entropy1 = self.entropy(y3.sum(0) / y3.sum())
        # minimize conditional entropy of a single data point
        entropy2 = self.entropy(y3).mean()

        loss = -a_loss - 1000 * entropy1
        self.optimizer.zero_grad()  # clear gradients from the previous epoch
        loss.backward()
        nn.utils.clip_grad_norm_(self.encoder.parameters(), 5.0)
        self.optimizer.step()

        stats['Loss'] += loss.item()
        stats['Entropy1'] += entropy1.item()
        stats['Entropy2'] += entropy2.item()
        return stats, abstract_t[0], y1

    def gen_plot(self):
        plots = [env.gen_plot(self.encoder) for env in self.envs]
        plots = np.concatenate(plots, 1)
        plt.imshow(plots)
        #plt.savefig('/home/jcoreyes/abstract/rlkit/examples/abstractmdp/fig1.png')
        plt.show()
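# --- Hedged sketch (hypothetical standalone helper, not in the original) ---
# The double loop in compute_abstract_t accumulates soft co-occurrence counts;
# in numpy it is simply the outer product a_t = y1^T @ y2 followed by row
# normalization. y1 and y2 are (N, abstract_dim) arrays of softmax outputs for
# consecutive states, and np is the module-level numpy import.
def soft_transition_counts(y1, y2, eps=1e-8):
    a_t = y1.T @ y2  # (abstract_dim, abstract_dim) soft counts
    return a_t / (a_t.sum(axis=1, keepdims=True) + eps)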
# Logger
use_tb = args.log_dir is not None
log_dir = args.log_dir
if use_tb:
    logger = SummaryWriter(comment='_' + args.env + '_bc')

# prepare networks
M = args.layer_size
network = Mlp(
    input_size=obs_dim,
    output_size=action_dim,
    hidden_sizes=[M, M],
    output_activation=torch.tanh,
).to(device)
optimizer = optim.Adam(network.parameters(), lr=args.lr)

epch = 0
if args.load_model is not None:
    if os.path.isfile(args.load_model):
        checkpoint = torch.load(args.load_model)
        network.load_state_dict(checkpoint['network_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        t_loss = checkpoint['train_loss']
        epch = checkpoint['epoch']
        print('Loading model: {}. Resuming from epoch: {}'.format(
            args.load_model, epch))
    else:
        print('Model: {} not found'.format(args.load_model))

best_loss = np.inf
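# --- Hedged sketch (assumption, not the original implementation) ---
# One plausible behavioral-cloning train() epoch for the setup above: the tanh
# policy network regresses onto the dataset actions with MSE. The dataloader
# is assumed to yield (observation, action) pairs.
def train(network, dataloader, optimizer, epoch, use_cuda):
    network.train()
    total_loss = 0.0
    for obs, actions in dataloader:
        if use_cuda:
            obs, actions = obs.cuda(), actions.cuda()
        pred_actions = network(obs)
        loss = ((pred_actions - actions) ** 2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / max(1, len(dataloader))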
class ToolsEnv(FoodEnvBase): """ Empty grid environment, no obstacles, sparse reward """ class Actions(IntEnum): # Absolute directions west = 0 east = 1 north = 2 south = 3 # collect (but don't consume) an item mine = 4 # consume a stored food item to boost health (does nothing if no stored food) eat = 5 # place objects down place = 6 def __init__( self, grid_size=32, health_cap=100, food_rate=4, max_pantry_size=50, obs_vision=False, food_rate_decay=0.0, init_resources=None, gen_resources=True, # adjust lifespans to hold the number of resources const, based on resource gen probs. used for env sweeps. fixed_expected_resources=False, resource_prob=None, resource_prob_decay=None, resource_prob_min=None, place_schedule=None, make_rtype='sparse', rtype='default', lifespans=None, default_lifespan=0, task=None, rnd=False, cbe=False, seed_val=1, fixed_reset=False, end_on_task_completion=True, can_die=False, include_health=False, include_num_objs=False, replenish_empty_resources=None, replenish_low_resources=None, # nonzero for reset case, 0 for reset free time_horizon=0, reset_hitting=True, **kwargs): assert task is not None, 'Must specify task of form "make berry", "navigate 3 5", "pickup axe", etc.' # a step count that isn't reset across resets self.env_shaping_step_count = 0 self.init_resources = init_resources or {} self.lifespans = lifespans or {} self.default_lifespan = default_lifespan self.food_rate_decay = food_rate_decay self.interactions = { # the 2 ingredients must be in alphabetical order ('metal', 'wood'): 'axe', ('axe', 'tree'): 'berry', } self.ingredients = {v: k for k, v in self.interactions.items()} self.gen_resources = gen_resources self.resource_prob = resource_prob or {} self.resource_prob_decay = resource_prob_decay or {} self.resource_prob_min = resource_prob_min or {} self.lifelong = time_horizon == 0 # tuple of (bump, schedule), giving place_radius at time t = (t + bump) // schedule self.place_schedule = place_schedule # store whether the place_schedule has reached full grid radius yet, at which point it'll stop calling each time self.full_grid = False self.human_radius = None # human controlled radius for placing resources. Will do nothing if None self.include_health = include_health self.include_num_objs = include_num_objs self.replenish_empty_resources = replenish_empty_resources or [] self.replenish_low_resources = replenish_low_resources or {} self.time_horizon = time_horizon # adjust lifespans if needed: if fixed_expected_resources: for type, num in self.init_resources.items(): if type not in self.lifespans: if not self.resource_prob.get(type, 0): self.lifespans[type] = self.default_lifespan else: self.lifespans[type] = int(num / self.resource_prob[type]) self.seed_val = seed_val self.fixed_reset = fixed_reset if not hasattr(self, 'object_to_idx'): self.object_to_idx = { 'empty': 0, 'wall': 1, 'food': 2, 'wood': 3, 'metal': 4, 'tree': 5, 'axe': 6, 'berry': 7 } # TASK stuff self.task = task self.task = task.split( ) # e.g. 
'pickup axe', 'navigate 3 5', 'make berry', 'make_lifelong axe' self.make_sequence = self.get_make_sequence() self.onetime_reward_sequence = [ False for _ in range(len(self.make_sequence)) ] self.make_rtype = make_rtype self.max_make_idx = -1 self.last_idx = -1 # most recent obj type made self.made_obj_type = None # obj type made in the last step, if any self.just_made_obj_type = None # obj type most recently placed on, if any self.last_placed_on = None # obj type placed on in the last step, if any self.just_placed_on = None # obj type just picked up, if any self.just_mined_type = None # used for task 'make_lifelong' self.num_solves = 0 self.end_on_task_completion = end_on_task_completion self.end_on_task_completion = not self.lifelong self.reset_hitting = reset_hitting self.hitting_time = 0 # Exploration! assert not (cbe and rnd), "can't have both CBE and RND" # CBE self.cbe = cbe # RND self.rnd = rnd self.obs_count = {} # below two variables are to keep running count of stdev for RND normalization self.sum_rnd_loss = 0 self.sum_square_rnd_loss = 0 self.sum_rnd_obs = 0 self.sum_square_rnd_obs = 0 self.rnd_loss = MSELoss() # food self.pantry = [] self.max_pantry_size = max_pantry_size self.actions = ToolsEnv.Actions # stores info about picked up items self.info_last = { 'pickup_%s' % k: 0 for k in self.object_to_idx.keys() if k not in ['empty', 'wall', 'tree'] } self.info_last.update( {'made_%s' % v: 0 for v in self.interactions.values()}) # stores visited locations for heat map self.visit_count = np.zeros((grid_size, grid_size), dtype=np.uint32) super().__init__(grid_size=grid_size, health_cap=health_cap, food_rate=food_rate, obs_vision=obs_vision, can_die=can_die, **kwargs) shape = None # TODO some of these shapes are wrong. fix by running through each branch and getting empirical obs shape if self.only_partial_obs: if self.agent_view_size == 5: shape = (273, ) elif self.agent_view_size == 7: shape = (465, ) elif self.grid_size == 32: if self.obs_vision: shape = (58569, ) else: if self.fully_observed: shape = (2067, ) else: shape = (2163, ) elif self.grid_size == 16: if not self.obs_vision: if self.fully_observed: shape = (631, ) elif self.grid_size == 7: if not self.obs_vision and self.fully_observed: shape = (60, ) elif self.grid_size == 8: if not self.obs_vision: if self.fully_observed: shape = (1091, ) # remove health component if not used if not include_health: shape = (shape[0] - 1, ) if not include_num_objs: shape = (shape[0] - 8, ) self.observation_space = spaces.Box(low=0, high=255, shape=shape, dtype='uint8') if self.rnd: self.rnd_network = Mlp([128, 128], 32, self.observation_space.low.size) self.rnd_target_network = Mlp([128, 128], 32, self.observation_space.low.size) self.rnd_optimizer = Adam(self.rnd_target_network.parameters(), lr=3e-4) def get_make_sequence(self): def add_ingredients(obj, seq): for ingredient in self.ingredients.get(obj, []): add_ingredients(ingredient, seq) seq.append(obj) make_sequence = [] goal_obj = self.task[1] add_ingredients(goal_obj, make_sequence) return make_sequence def human_set_place_radius(self, r): self.human_radius = r def place_radius(self): assert self.place_schedule is not None, \ '`place_schedule` must be specified as (bump, period), giving radius(t) = (t + bump) // period' if self.human_radius is not None: return self.human_radius else: return (self.env_shaping_step_count + self.place_schedule[0]) // self.place_schedule[1] def place_items(self): if not self.gen_resources: return counts = self.count_all_types() placed = set() for type, 
thresh in self.replenish_low_resources.items(): if counts.get(type, 0) < thresh: self.place_prob( TYPE_TO_CLASS_ABS[type](lifespan=self.lifespans.get( type, self.default_lifespan)), 1) placed.add(type) for type, prob in self.resource_prob.items(): place_prob = max( self.resource_prob_min.get(type, 0), prob - self.resource_prob_decay.get(type, 0) * self.step_count) if type in placed: place_prob = 0 elif type in self.init_resources and not counts.get( type, 0) and type in self.replenish_empty_resources: # replenish resource if gone and was initially provided, or if resource is below the specified threshold place_prob = 1 elif counts.get(type, 0) > (self.grid_size - 2)**2 // max( 8, len(self.resource_prob)): # don't add more if it's already taking up over 1/8 of the space (lower threshold if >8 diff obj types being generated) place_prob = 0 if self.place_schedule and not self.full_grid: diam = self.place_radius() if diam >= 2 * self.grid_size: self.full_grid = True self.place_prob(TYPE_TO_CLASS_ABS[type]( lifespan=self.lifespans.get(type, self.default_lifespan)), place_prob, top=(np.clip(self.agent_pos - diam // 2, 0, self.grid_size - 1)), size=(diam, diam)) else: self.place_prob( TYPE_TO_CLASS_ABS[type](lifespan=self.lifespans.get( type, self.default_lifespan)), place_prob) def extra_gen_grid(self): for type, count in self.init_resources.items(): if self.task and self.task[0] == 'pickup' and type == self.task[1]: for _ in range(count): self.place_obj(TYPE_TO_CLASS_ABS[type]()) else: for _ in range(count): self.place_obj( TYPE_TO_CLASS_ABS[type](lifespan=self.lifespans.get( type, self.default_lifespan))) def extra_step(self, action, matched): if matched: return matched agent_cell = self.grid.get(*self.agent_pos) matched = True # Collect resources. Add to shelf. if action == self.actions.mine: if agent_cell and agent_cell.can_mine(self): mined = False # check if food or other resource, which we're storing separately if self.include_health and agent_cell.food_value() > 0: self.add_health(agent_cell.food_value()) mined = True self.just_eaten_type = agent_cell.type else: mined = self.add_to_shelf(agent_cell) if mined: self.just_mined_type = agent_cell.type self.info_last['pickup_%s' % agent_cell.type] = self.info_last[ 'pickup_%s' % agent_cell.type] + 1 self.grid.set(*self.agent_pos, None) # Consume stored food. 
elif action == self.actions.eat: pass # actions to use each element of inventory elif action == self.actions.place: self.place_act() else: matched = False return matched def place_act(self): agent_cell = self.grid.get(*self.agent_pos) if self.carrying is None: # there's nothing to place return elif agent_cell is None: # there's nothing to combine it with, so just place it on the grid self.grid.set(*self.agent_pos, self.carrying) else: # let's try to combine the placed object with the existing object interact_tup = tuple(sorted([self.carrying.type, agent_cell.type])) new_type = self.interactions.get(interact_tup, None) if not new_type: # the objects cannot be combined, no-op return else: self.last_placed_on = agent_cell self.just_placed_on = agent_cell # replace existing obj with new obj new_obj = TYPE_TO_CLASS_ABS[new_type]( lifespan=self.lifespans.get(new_type, self.default_lifespan)) self.grid.set(*self.agent_pos, new_obj) self.made_obj_type = new_obj.type self.just_made_obj_type = new_obj.type self.info_last['made_%s' % new_type] = self.info_last['made_%s' % new_type] + 1 # remove placed object from inventory self.carrying = None def add_to_shelf(self, obj): """ Returns whether adding to shelf succeeded """ if self.carrying is None: self.carrying = obj return True return False def gen_shelf_obs(self): """ Return one-hot encoding of carried object type. """ shelf_obs = np.zeros((1, len(self.object_to_idx)), dtype=np.uint8) if self.carrying is not None: shelf_obs[0, self.object_to_idx[self.carrying.type]] = 1 return shelf_obs def step(self, action): self.env_shaping_step_count += 1 self.just_made_obj_type = None self.just_eaten_type = None self.just_placed_on = None self.just_mined_type = None obs, reward, done, info = super().step(action, incl_health=self.include_health) shelf_obs = self.gen_shelf_obs() """ Generate obs """ obs_grid_string = obs.tostring() extra_obs = shelf_obs.flatten() # magic number repeating shelf 8 times to fill up more of the obs extra_obs = np.repeat(extra_obs, 8) num_objs = np.repeat(self.info_last['pickup_%s' % self.task[1]], 8) obs = np.concatenate( (obs, extra_obs, num_objs)) if self.include_num_objs else np.concatenate( (obs, extra_obs)) """ Generate reward """ solved = self.solved_task() if 'make' in self.task[0]: reward = self.get_make_reward() if self.task[0] == 'make': info.update({ 'progress': (self.max_make_idx + 1) / len(self.make_sequence) }) else: reward = int(solved) """ Generate info """ info.update({'health': self.health}) info.update(self.info_last) if solved: if self.end_on_task_completion: done = True info.update({'solved': True}) if self.lifelong: # remove obj so can keep making self.carrying = None else: info.update({'solved': False}) if self.time_horizon and self.step_count % self.time_horizon == 0: done = True """ Exploration bonuses """ self.obs_count[obs_grid_string] = self.obs_count.get( obs_grid_string, 0) + 1 if self.cbe: reward += 1 / np.sqrt(self.obs_count[obs_grid_string]) elif self.rnd: self.sum_rnd_obs += obs torch_obs = torch_ify(obs) true_rnd = self.rnd_network(torch_obs) pred_rnd = self.rnd_target_network(torch_obs) loss = self.rnd_loss(true_rnd, pred_rnd) self.rnd_optimizer.zero_grad() loss.backward() self.rnd_optimizer.step() # RND exploration bonus self.sum_rnd_loss += loss self.sum_square_rnd_loss += loss**2 mean = self.sum_rnd_loss / self.step_count stdev = (self.sum_square_rnd_loss / self.step_count) - mean**2 try: bonus = np.clip((loss / stdev).detach().numpy(), -1, 1) except ZeroDivisionError: # stdev is 0, which should 
occur only in the first timestep bonus = 1 reward += bonus if self.hitting_time == 0 and reward > 0: self.hitting_time = self.step_count # funny ordering because otherwise we'd get the transpose due to how the grid indices work self.visit_count[self.agent_pos[1], self.agent_pos[0]] += 1 return obs, reward, done, info def reset(self, seed=None, return_seed=False): prev_step_count = self.step_count if self.fixed_reset: self.seed(self.seed_val) else: if seed is None: seed = self._rand_int(0, 100000) self.seed(seed) obs = super().reset(incl_health=self.include_health) extra_obs = np.repeat(self.gen_shelf_obs(), 8) num_objs = np.repeat(self.info_last['pickup_%s' % self.task[1]], 8) obs = np.concatenate( (obs, extra_obs.flatten(), num_objs)) if self.include_num_objs else np.concatenate( (obs, extra_obs.flatten())) self.pantry = [] self.made_obj_type = None self.last_placed_on = None self.max_make_idx = -1 self.last_idx = -1 if self.reset_hitting: self.hitting_time = 0 else: # to be used for measuring hitting time in episodic setting self.step_count = prev_step_count self.obs_count = {} self.info_last = { 'pickup_%s' % k: 0 for k in self.object_to_idx.keys() if k not in ['empty', 'wall', 'tree'] } self.info_last.update( {'made_%s' % v: 0 for v in self.interactions.values()}) return (obs, seed) if return_seed else obs def solved_task(self): if 'make' in self.task[0]: return self.carrying is not None and self.carrying.type == self.task[ 1] elif self.task[0] == 'navigate': pos = np.array(self.task[1:]) return np.array_equal(pos, self.agent_pos) elif self.task[0] == 'pickup': return self.carrying is not None and (self.carrying.type == self.task[1]) else: raise NotImplementedError def get_make_reward(self): reward = 0 if self.make_rtype == 'sparse': reward = POS_RWD * int(self.solved_task()) if reward and self.lifelong: self.carrying = None self.num_solves += 1 elif self.make_rtype == 'sparse_negstep': reward = POS_RWD * int(self.solved_task()) or -0.01 if reward > 0 and self.lifelong: self.carrying = None self.num_solves += 1 elif self.make_rtype == 'dense': carry_idx = self.make_sequence.index( self.carrying.type ) if self.carrying and self.carrying.type in self.make_sequence else -1 just_place_idx = self.make_sequence.index( self.just_placed_on.type ) if self.just_placed_on and self.just_placed_on.type in self.make_sequence else -1 just_made_idx = self.make_sequence.index( self.just_made_obj_type ) if self.just_made_obj_type in self.make_sequence else -1 idx = max(carry_idx, just_place_idx) true_idx = max(idx, self.max_make_idx - 1) cur_idx = max(carry_idx, just_made_idx) # print('carry: %d, place: %d, made: %d, j_made: %d, idx: %d, true: %d, cur: %d' # % (carry_idx, just_place_idx, just_made_idx, just_made_idx, idx, true_idx, cur_idx)) if carry_idx == len(self.make_sequence) - 1: reward = POS_RWD self.max_make_idx = -1 self.num_solves += 1 self.last_idx = -1 elif just_made_idx > self.max_make_idx: reward = MED_RWD self.max_make_idx = just_made_idx elif idx == self.max_make_idx + 1: reward = MED_RWD self.max_make_idx = idx if cur_idx < self.last_idx: reward = NEG_RWD else: next_pos = self.get_closest_obj_pos( self.make_sequence[true_idx + 1]) if next_pos is not None: dist = np.linalg.norm(next_pos - self.agent_pos, ord=1) reward = -0.01 * dist # else there is no obj of that type, so 0 reward if carry_idx != len(self.make_sequence) - 1: self.last_idx = cur_idx elif self.make_rtype == 'waypoint': just_mined_idx = self.make_sequence.index( self.just_mined_type ) if self.just_mined_type in 
self.make_sequence else -1 just_place_idx = self.make_sequence.index( self.just_placed_on.type ) if self.just_placed_on and self.just_placed_on.type in self.make_sequence else -1 just_made_idx = self.make_sequence.index( self.just_made_obj_type ) if self.just_made_obj_type in self.make_sequence else -1 idx = max(just_mined_idx, just_place_idx) if idx >= 0: reward = POS_RWD**(idx // 2) elif self.make_rtype in ['one-time', 'dense-fixed']: carry_idx = self.make_sequence.index( self.carrying.type ) if self.carrying and self.carrying.type in self.make_sequence else -1 just_place_idx = self.make_sequence.index( self.just_placed_on.type ) if self.just_placed_on and self.just_placed_on.type in self.make_sequence else -1 just_made_idx = self.make_sequence.index( self.just_made_obj_type ) if self.just_made_obj_type in self.make_sequence else -1 max_idx = max(carry_idx, just_place_idx) # print('carry: %d, j_place: %d, j_made: %d, max: %d, last: %d' % (carry_idx, just_place_idx, just_made_idx, max_idx, self.last_idx)) if carry_idx == len(self.make_sequence) - 1: # exponent reasoning: 3rd obj in list should yield 100, 5th yields 10000, etc. reward = POS_RWD**(carry_idx // 2) self.onetime_reward_sequence = [ False for _ in range(len(self.make_sequence)) ] self.num_solves += 1 # remove the created goal object self.carrying = None self.last_idx = -1 if self.lifelong: # otherwise messes with progress metric self.max_make_idx = -1 elif max_idx != -1 and not self.onetime_reward_sequence[max_idx]: # exponent reasoning: 3rd obj in list should yield 100, 5th yields 10000, etc. reward = POS_RWD**(max_idx // 2) self.onetime_reward_sequence[max_idx] = True elif max(max_idx, just_made_idx) < self.last_idx: reward = -np.abs(NEG_RWD**(self.last_idx // 2 + 1)) elif self.make_rtype == 'dense-fixed': next_pos = self.get_closest_obj_pos(self.make_sequence[ self.onetime_reward_sequence.index(False)]) if next_pos is not None: dist = np.linalg.norm(next_pos - self.agent_pos, ord=1) reward = -0.01 * dist if max_idx > self.max_make_idx: self.max_make_idx = max_idx # only do this if it didn't just solve the task if carry_idx != len(self.make_sequence) - 1: self.last_idx = max_idx else: raise TypeError('Make reward type "%s" not recognized' % self.make_rtype) return reward def get_closest_obj_pos(self, type=None): def test_point(point, type): try: obj = self.grid.get(*point) if obj and (type is None or obj.type == type): return True except AssertionError: # OOB grid access return False corner = np.array([0, -1]) # range of max L1 distance on a grid of length self.grid_size - 2 (2 because of the borders) for i in range(0, 2 * self.grid_size - 5): # width of the fixed distance level set (diamond shape centered at agent pos) width = i + 1 test_pos = self.agent_pos + corner * i for j in range(width): if test_point(test_pos, type): return test_pos test_pos += np.array([1, 1]) test_pos -= np.array([1, 1]) for j in range(width): if test_point(test_pos, type): return test_pos test_pos += np.array([-1, 1]) test_pos -= np.array([-1, 1]) for j in range(width): if test_point(test_pos, type): return test_pos test_pos += np.array([-1, -1]) test_pos -= np.array([-1, -1]) for j in range(width): if test_point(test_pos, type): return test_pos test_pos += np.array([1, -1]) return None def decay_health(self): if self.include_health: super().decay_health()
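# --- Hedged sketch (hypothetical standalone helper, not in the original) ---
# Illustrates the make-sequence expansion performed by get_make_sequence above:
# ingredients are listed (recursively) before the object that consumes them,
# so the goal 'berry' expands to metal, wood, axe, tree, berry.
INTERACTIONS = {('metal', 'wood'): 'axe', ('axe', 'tree'): 'berry'}
INGREDIENTS = {v: k for k, v in INTERACTIONS.items()}

def make_sequence(goal):
    seq = []
    def add(obj):
        for ingredient in INGREDIENTS.get(obj, ()):
            add(ingredient)
        seq.append(obj)
    add(goal)
    return seq

assert make_sequence('berry') == ['metal', 'wood', 'axe', 'tree', 'berry']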
class AbstractMDPVI: def __init__(self, env): self.env = env self.width = self.env.grid.height self.height = self.env.grid.height self.abstract_dim = 4 self.state_dim = 2 self.states = [] self.state_to_idx = None self.encoder = Mlp((64, 64, 64), output_size=self.abstract_dim, input_size=self.state_dim, output_activation=F.softmax, layer_norm=False) states = [] for j in range(self.env.grid.height): for i in range(self.env.grid.width): if self.env.grid.get(i, j) == None: states.append((i, j)) self.states = states self.states_np = np.array(states) self.state_to_idx = {s: i for i, s in enumerate(self.states)} self.next_states = [] for i, state in enumerate(states): next_states = self._gen_transitions(state) self.next_states.append(next_states) self.next_states = np.array(self.next_states) self.encoder.cuda() self.optimizer = optim.Adam(self.encoder.parameters(), lr=1e-4) def _gen_transitions(self, state): actions = np.array([[1, 0], [-1, 0], [0, 1], [0, -1]]) next_states = [] for action in actions: ns = np.array(state) + action if ns[0] >= 0 and ns[1] >= 0 and ns[0] < self.width and ns[1] < self.height and \ self.env.grid.get(*ns) == None: next_states.append(self.state_to_idx[tuple(ns.tolist())]) else: next_states.append(-1) return next_states def train_vi(self): dataset = data.TensorDataset(from_numpy(np.arange(len(self.states)))) dataloader = data.DataLoader(dataset, batch_size=32, shuffle=True) self.qvalues = from_numpy( np.zeros((len(self.states), len(self.states), 4))) self.values = from_numpy(np.zeros( (len(self.states), len(self.states)))) states = from_numpy(self.states_np) next_states = from_numpy(self.next_states).long() def eval_reward(s1, s2): match = (s1 == s2).float() return (1 - match) * -1 + match * 0 # indexed as goal state, then state for goal_state in range(len(self.states)): for train_itr in range(20): for batch_idx, s1 in enumerate(dataloader): s1 = s1[0].long() s2 = next_states[s1] for i in range(s2.shape[1]): ns = s2[:, i] valid_trans = (ns != -1).float() update = eval_reward( ns, goal_state) + 1.0 * self.values[goal_state, ns] #update[ns==goal_state] = 0 # Overwrite invalid actions with high negative reward so not chosen by max in value update self.qvalues[goal_state, s1, i] = valid_trans * update + ( 1 - valid_trans) * -1000 self.values[goal_state, s1] = self.qvalues[goal_state, s1, :].max(-1)[0] self.values[goal_state, goal_state] = 0 print(float(goal_state) / (len(self.states))) #import pdb; pdb.set_trace() self.values -= 1 self.values[np.arange(len(self.states)), np.arange(len(self.states))] = 0 np.save( "/home/jcoreyes/abstract/rlkit/examples/abstractmdp/values.npy", get_numpy(self.values)) def test_vi(self): Z = np.zeros((self.width, self.height)) values = get_numpy(self.values) for i, state in enumerate(self.states): Z[state] = values[0, i] print(Z) Z += Z.min() Z /= Z.max() plt.imshow(Z) plt.show() def gen_plot(self): #X = np.arange(0, self.width) #Y = np.arange(0, self.height) #X, Y = np.meshgrid(X, Y) Z = np.zeros((self.width, self.height)) for state in self.states: dist = get_numpy( self.encoder(from_numpy(np.array(state)).unsqueeze(0))) Z[state] = np.argmax(dist) + 1 #fig = plt.figure() #ax = Axes3D(fig) #surf = ax.plot_surface(X, Y, Z) #cset = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm) #ax.clabel(cset) plt.imshow(Z) plt.show() def train(self, max_epochs=100): dataset = data.TensorDataset(from_numpy(np.arange(len(self.states)))) dataloader = data.DataLoader(dataset, batch_size=128, shuffle=True) values = np.load( 
"/home/jcoreyes/abstract/rlkit/examples/abstractmdp/values.npy") import pdb pdb.set_trace() values = from_numpy(np.abs(values)) for epoch in range(1, max_epochs + 1): stats = self.train_epoch(dataloader, epoch, values) print(stats) def kl(self, dist1, dist2): return (dist1 * (torch.log(dist1 + 1e-8) - torch.log(dist2 + 1e-8))).sum(1) def entropy(self, dist): return -(dist * torch.log(dist + 1e-8)).sum(1) def train_epoch(self, dataloader, epoch, values): stats = dict([('Loss', 0)]) states = from_numpy(self.states_np) for batch_idx, s1 in enumerate(dataloader): s1 = s1[0].long() bs = s1.shape[0] self.optimizer.zero_grad() s2 = from_numpy(np.random.randint(0, len(self.states), bs)).long() # Sample s2 where s1 is in same abstract state as s1 y1 = self.encoder(states[s1]) y2 = self.encoder(states[s2]) p1, a1 = y1.max(-1) p2, a2 = y2.max(-1) match = (a1 == a2) & (s1 != s2) distances = values[s2, s1] #reward = (distances < 7).float() * 1.0 reward = -distances * 5e-5 #import pdb; pdb.set_trace() surr_loss = -torch.log(y1[torch.arange(bs), a1] + 1e-8) * reward loss = (surr_loss - 1.0 * self.entropy(y1)) * match.float() loss = loss.sum() / bs loss.backward() nn.utils.clip_grad_norm(self.encoder.parameters(), 5.0) self.optimizer.step() stats['Loss'] += loss.item() stats['Loss'] /= (batch_idx + 1) return stats
class AbstractMDPsContrastive:
    def __init__(self, envs):
        self.envs = [EnvContainer(env) for env in envs]
        self.abstract_dim = 4
        self.state_dim = 4
        self.states = []
        self.state_to_idx = None
        self.encoder = Mlp((64, 64, 64),
                           output_size=self.abstract_dim,
                           input_size=self.state_dim,
                           output_activation=F.softmax,
                           layer_norm=True)
        self.transitions = nn.Parameter(
            torch.zeros((self.abstract_dim, self.abstract_dim)))
        self.optimizer = optim.Adam(self.encoder.parameters())

    def train(self, max_epochs=100):
        data_lst = []
        for i, env in enumerate(self.envs):
            d = np.array(env.transitions)
            d = np.concatenate([d, np.zeros((d.shape[0], 1)) + i], 1)
            data_lst.append(d)
        all_data = from_numpy(np.concatenate(data_lst, 0))
        dataset = data.TensorDataset(all_data)
        dataloader = data.DataLoader(dataset, batch_size=32, shuffle=True)

        for epoch in range(1, max_epochs + 1):
            stats = self.train_epoch(dataloader, epoch)
            print(stats)

    def kl(self, dist1, dist2):
        return (dist1 *
                (torch.log(dist1 + 1e-8) - torch.log(dist2 + 1e-8))).sum(1)

    def entropy(self, dist):
        return -(dist * torch.log(dist + 1e-8)).sum(-1)

    def compute_abstract_t(self, env):
        trans = env.transitions_np
        s1 = trans[:, :4]
        s2 = trans[:, 4:]
        y1 = self.encoder(from_numpy(s1))
        y2 = self.encoder(from_numpy(s2))
        y3 = self.encoder(from_numpy(env.sample_states(s1.shape[0])))
        a_t = from_numpy(np.zeros((self.abstract_dim, self.abstract_dim)))
        for i in range(self.abstract_dim):
            for j in range(self.abstract_dim):
                a_t[i, j] += (y1[:, i] * y2[:, j]).sum(0)
        n_a_t = from_numpy(np.zeros((self.abstract_dim, self.abstract_dim)))
        for i in range(self.abstract_dim):
            n_a_t[i, :] = a_t[i, :] / (a_t[i, :].sum() + 1e-8)
        return n_a_t, y1, y2, y3

    def train_epoch(self, dataloader, epoch):
        stats = OrderedDict([('Loss', 0), ('Converge', 0), ('Diverge', 0),
                             ('Entropy1', 0), ('Entropy2', 0), ('Dev', 0)])

        data = [self.compute_abstract_t(env) for env in self.envs]
        abstract_t = [x[0] for x in data]
        y1 = torch.cat([x[1] for x in data], 0)
        y2 = torch.cat([x[2] for x in data], 0)
        y3 = torch.cat([x[3] for x in data], 0)

        sample = from_numpy(np.random.randint(0, y1.shape[0], 128)).long()
        y1 = y1[sample]
        y2 = y2[sample]
        y3 = y3[sample]

        mean_t = sum(abstract_t) / len(abstract_t)
        self.mean_t = mean_t
        dev = [torch.pow(x[0] - mean_t, 2).mean() for x in abstract_t]

        bs = y1.shape[0]
        converge = (self.kl(y1, y2) + self.kl(y2, y1)).sum() / bs
        diverge = (-self.kl(y1, y3) - self.kl(y2, y3)).sum() / bs

        self.spread = y1.sum(0) / y1.sum()
        # maximize entropy of the marginal over all data points
        entropy1 = -self.entropy(y1.sum(0) / y1.sum())
        # minimize entropy of a single data point
        entropy2 = self.entropy(y1).mean()
        #l4 = -self.kl(y3, y1)
        #l5 = self.entropy(y1) * 0.1
        #l6 = sum(dev) / len(dev)

        loss = converge + diverge + entropy1 + entropy2 + 100
        #loss = l5

        self.optimizer.zero_grad()  # clear gradients from the previous epoch
        loss.backward()
        nn.utils.clip_grad_norm_(self.encoder.parameters(), 5.0)
        self.optimizer.step()

        stats['Loss'] += loss.item()
        stats['Converge'] += converge.item()
        stats['Diverge'] += diverge.item()
        stats['Entropy1'] += entropy1.item()
        stats['Entropy2'] += entropy2.item()
        #stats['Dev'] += l6.item()
        return stats

    def gen_plot(self):
        plots = [env.gen_plot(self.encoder) for env in self.envs]
        plots = np.concatenate(plots, 1)
        plt.imshow(plots)
        #plt.savefig('/home/jcoreyes/abstract/rlkit/examples/abstractmdp/fig1.png')
        plt.show()
class OnlineVaeRelabelingBuffer(ObsDictRelabelingBuffer): def __init__( self, vae, *args, decoded_obs_key='image_observation', decoded_achieved_goal_key='image_achieved_goal', decoded_desired_goal_key='image_desired_goal', exploration_rewards_type='None', exploration_rewards_scale=1.0, vae_priority_type='None', start_skew_epoch=0, power=1.0, internal_keys=None, exploration_schedule_kwargs=None, priority_function_kwargs=None, exploration_counter_kwargs=None, relabeling_goal_sampling_mode='vae_prior', decode_vae_goals=False, **kwargs ): if internal_keys is None: internal_keys = [] for key in [ decoded_obs_key, decoded_achieved_goal_key, decoded_desired_goal_key ]: if key not in internal_keys: internal_keys.append(key) super().__init__(internal_keys=internal_keys, *args, **kwargs) # assert isinstance(self.env, VAEWrappedEnv) self.vae = vae self.decoded_obs_key = decoded_obs_key self.decoded_desired_goal_key = decoded_desired_goal_key self.decoded_achieved_goal_key = decoded_achieved_goal_key self.exploration_rewards_type = exploration_rewards_type self.exploration_rewards_scale = exploration_rewards_scale self.start_skew_epoch = start_skew_epoch self.vae_priority_type = vae_priority_type self.power = power self._relabeling_goal_sampling_mode = relabeling_goal_sampling_mode self.decode_vae_goals = decode_vae_goals if exploration_schedule_kwargs is None: self.explr_reward_scale_schedule = \ ConstantSchedule(self.exploration_rewards_scale) else: self.explr_reward_scale_schedule = \ PiecewiseLinearSchedule(**exploration_schedule_kwargs) self._give_explr_reward_bonus = ( exploration_rewards_type != 'None' and exploration_rewards_scale != 0. ) self._exploration_rewards = np.zeros((self.max_size, 1), dtype=np.float32) self._prioritize_vae_samples = ( vae_priority_type != 'None' and power != 0. 
) self._vae_sample_priorities = np.zeros((self.max_size, 1), dtype=np.float32) self._vae_sample_probs = None self.use_dynamics_model = ( self.exploration_rewards_type == 'forward_model_error' ) if self.use_dynamics_model: self.initialize_dynamics_model() type_to_function = { 'reconstruction_error': self.reconstruction_mse, 'bce': self.binary_cross_entropy, 'latent_distance': self.latent_novelty, 'latent_distance_true_prior': self.latent_novelty_true_prior, 'forward_model_error': self.forward_model_error, 'gaussian_inv_prob': self.gaussian_inv_prob, 'bernoulli_inv_prob': self.bernoulli_inv_prob, 'vae_prob': self.vae_prob, 'hash_count': self.hash_count_reward, 'None': self.no_reward, } self.exploration_reward_func = ( type_to_function[self.exploration_rewards_type] ) self.vae_prioritization_func = ( type_to_function[self.vae_priority_type] ) if priority_function_kwargs is None: self.priority_function_kwargs = dict() else: self.priority_function_kwargs = priority_function_kwargs if self.exploration_rewards_type == 'hash_count': if exploration_counter_kwargs is None: exploration_counter_kwargs = dict() self.exploration_counter = CountExploration(env=self.env, **exploration_counter_kwargs) self.epoch = 0 def add_path(self, path): if self.decode_vae_goals: self.add_decoded_vae_goals_to_path(path) super().add_path(path) def add_decoded_vae_goals_to_path(self, path): # decoding the self-sampled vae images should be done in batch (here) # rather than in the env for efficiency desired_goals = combine_dicts( path['observations'], [self.desired_goal_key] )[self.desired_goal_key] desired_decoded_goals = self.env._decode(desired_goals) desired_decoded_goals = desired_decoded_goals.reshape( len(desired_decoded_goals), -1 ) for idx, next_obs in enumerate(path['observations']): path['observations'][idx][self.decoded_desired_goal_key] = \ desired_decoded_goals[idx] path['next_observations'][idx][self.decoded_desired_goal_key] = \ desired_decoded_goals[idx] def random_batch(self, batch_size): batch = super().random_batch(batch_size) exploration_rewards_scale = float(self.explr_reward_scale_schedule.get_value(self.epoch)) if self._give_explr_reward_bonus: batch_idxs = batch['indices'].flatten() batch['exploration_rewards'] = self._exploration_rewards[batch_idxs] batch['rewards'] += exploration_rewards_scale * batch['exploration_rewards'] return batch def get_diagnostics(self): if self._vae_sample_probs is None or self._vae_sample_priorities is None: stats = create_stats_ordered_dict( 'VAE Sample Weights', np.zeros(self._size), ) stats.update(create_stats_ordered_dict( 'VAE Sample Probs', np.zeros(self._size), )) else: vae_sample_priorities = self._vae_sample_priorities[:self._size] vae_sample_probs = self._vae_sample_probs[:self._size] stats = create_stats_ordered_dict( 'VAE Sample Weights', vae_sample_priorities, ) stats.update(create_stats_ordered_dict( 'VAE Sample Probs', vae_sample_probs, )) return stats def refresh_latents(self, epoch): self.epoch = epoch self.skew = (self.epoch > self.start_skew_epoch) batch_size = 512 next_idx = min(batch_size, self._size) if self.exploration_rewards_type == 'hash_count': # you have to count everything then compute exploration rewards cur_idx = 0 next_idx = min(batch_size, self._size) while cur_idx < self._size: idxs = np.arange(cur_idx, next_idx) normalized_imgs = self._next_obs[self.decoded_obs_key][idxs] self.update_hash_count(normalized_imgs) cur_idx = next_idx next_idx += batch_size next_idx = min(next_idx, self._size) cur_idx = 0 obs_sum = 
np.zeros(self.vae.representation_size) obs_square_sum = np.zeros(self.vae.representation_size) while cur_idx < self._size: idxs = np.arange(cur_idx, next_idx) self._obs[self.observation_key][idxs] = \ self.env._encode(self._obs[self.decoded_obs_key][idxs]) self._next_obs[self.observation_key][idxs] = \ self.env._encode(self._next_obs[self.decoded_obs_key][idxs]) # WARNING: we only refresh the desired/achieved latents for # "next_obs". This means that obs[desired/achieve] will be invalid, # so make sure there's no code that references this. # TODO: enforce this with code and not a comment self._next_obs[self.desired_goal_key][idxs] = \ self.env._encode(self._next_obs[self.decoded_desired_goal_key][idxs]) self._next_obs[self.achieved_goal_key][idxs] = \ self.env._encode(self._next_obs[self.decoded_achieved_goal_key][idxs]) normalized_imgs = self._next_obs[self.decoded_obs_key][idxs] if self._give_explr_reward_bonus: rewards = self.exploration_reward_func( normalized_imgs, idxs, **self.priority_function_kwargs ) self._exploration_rewards[idxs] = rewards.reshape(-1, 1) if self._prioritize_vae_samples: if ( self.exploration_rewards_type == self.vae_priority_type and self._give_explr_reward_bonus ): self._vae_sample_priorities[idxs] = ( self._exploration_rewards[idxs] ) else: self._vae_sample_priorities[idxs] = ( self.vae_prioritization_func( normalized_imgs, idxs, **self.priority_function_kwargs ).reshape(-1, 1) ) obs_sum+= self._obs[self.observation_key][idxs].sum(axis=0) obs_square_sum+= np.power(self._obs[self.observation_key][idxs], 2).sum(axis=0) cur_idx = next_idx next_idx += batch_size next_idx = min(next_idx, self._size) self.vae.dist_mu = obs_sum/self._size self.vae.dist_std = np.sqrt(obs_square_sum/self._size - np.power(self.vae.dist_mu, 2)) if self._prioritize_vae_samples: """ priority^power is calculated in the priority function for image_bernoulli_prob or image_gaussian_inv_prob and directly here if not. """ if self.vae_priority_type == 'vae_prob': self._vae_sample_priorities[:self._size] = relative_probs_from_log_probs( self._vae_sample_priorities[:self._size] ) self._vae_sample_probs = self._vae_sample_priorities[:self._size] else: self._vae_sample_probs = self._vae_sample_priorities[:self._size] ** self.power p_sum = np.sum(self._vae_sample_probs) assert p_sum > 0, "Unnormalized p sum is {}".format(p_sum) self._vae_sample_probs /= np.sum(self._vae_sample_probs) self._vae_sample_probs = self._vae_sample_probs.flatten() def sample_weighted_indices(self, batch_size): if ( self._prioritize_vae_samples and self._vae_sample_probs is not None and self.skew ): indices = np.random.choice( len(self._vae_sample_probs), batch_size, p=self._vae_sample_probs, ) assert ( np.max(self._vae_sample_probs) <= 1 and np.min(self._vae_sample_probs) >= 0 ) else: indices = self._sample_indices(batch_size) return indices def _sample_goals_from_env(self, batch_size): self.env.goal_sampling_mode = self._relabeling_goal_sampling_mode return self.env.sample_goals(batch_size) def sample_buffer_goals(self, batch_size): """ Samples goals from weighted replay buffer for relabeling or exploration. Returns None if replay buffer is empty. 
Example of what might be returned: dict( image_desired_goals: image_achieved_goals[weighted_indices], latent_desired_goals: latent_desired_goals[weighted_indices], ) """ if self._size == 0: return None weighted_idxs = self.sample_weighted_indices( batch_size, ) next_image_obs = self._next_obs[self.decoded_obs_key][weighted_idxs] next_latent_obs = self._next_obs[self.achieved_goal_key][weighted_idxs] return { self.decoded_desired_goal_key: next_image_obs, self.desired_goal_key: next_latent_obs } def random_vae_training_data(self, batch_size, epoch): # epoch no longer needed. Using self.skew in sample_weighted_indices # instead. weighted_idxs = self.sample_weighted_indices( batch_size, ) next_image_obs = self._next_obs[self.decoded_obs_key][weighted_idxs] observations = ptu.from_numpy(next_image_obs) return dict( observations=observations, ) def reconstruction_mse(self, next_vae_obs, indices): torch_input = ptu.from_numpy(next_vae_obs) recon_next_vae_obs, _, _ = self.vae(torch_input) error = torch_input - recon_next_vae_obs mse = torch.sum(error ** 2, dim=1) return ptu.get_numpy(mse) def gaussian_inv_prob(self, next_vae_obs, indices): return np.exp(self.reconstruction_mse(next_vae_obs, indices)) def binary_cross_entropy(self, next_vae_obs, indices): torch_input = ptu.from_numpy(next_vae_obs) recon_next_vae_obs, _, _ = self.vae(torch_input) error = - torch_input * torch.log( torch.clamp( recon_next_vae_obs, min=1e-30, # corresponds to about -70 ) ) bce = torch.sum(error, dim=1) return ptu.get_numpy(bce) def bernoulli_inv_prob(self, next_vae_obs, indices): torch_input = ptu.from_numpy(next_vae_obs) recon_next_vae_obs, _, _ = self.vae(torch_input) prob = ( torch_input * recon_next_vae_obs + (1 - torch_input) * (1 - recon_next_vae_obs) ).prod(dim=1) return ptu.get_numpy(1 / prob) def vae_prob(self, next_vae_obs, indices, **kwargs): return compute_p_x_np_to_np( self.vae, next_vae_obs, power=self.power, **kwargs ) def forward_model_error(self, next_vae_obs, indices): obs = self._obs[self.observation_key][indices] next_obs = self._next_obs[self.observation_key][indices] actions = self._actions[indices] state_action_pair = ptu.from_numpy(np.c_[obs, actions]) prediction = self.dynamics_model(state_action_pair) mse = self.dynamics_loss(prediction, ptu.from_numpy(next_obs)) return ptu.get_numpy(mse) def latent_novelty(self, next_vae_obs, indices): distances = ((self.env._encode(next_vae_obs) - self.vae.dist_mu) / self.vae.dist_std) ** 2 return distances.sum(axis=1) def latent_novelty_true_prior(self, next_vae_obs, indices): distances = self.env._encode(next_vae_obs) ** 2 return distances.sum(axis=1) def _kl_np_to_np(self, next_vae_obs, indices): torch_input = ptu.from_numpy(next_vae_obs) mu, log_var = self.vae.encode(torch_input) return ptu.get_numpy( - torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1) ) def update_hash_count(self, next_vae_obs): torch_input = ptu.from_numpy(next_vae_obs) mus, log_vars = self.vae.encode(torch_input) mus = ptu.get_numpy(mus) self.exploration_counter.increment_counts(mus) return None def hash_count_reward(self, next_vae_obs, indices): obs = self.env._encode(next_vae_obs) return self.exploration_counter.compute_count_based_reward(obs) def no_reward(self, next_vae_obs, indices): return np.zeros((len(next_vae_obs), 1)) def initialize_dynamics_model(self): obs_dim = self._obs[self.observation_key].shape[1] self.dynamics_model = Mlp( hidden_sizes=[128, 128], output_size=obs_dim, input_size=obs_dim + self._action_dim, ) self.dynamics_model.to(ptu.device) 
self.dynamics_optimizer = Adam(self.dynamics_model.parameters()) self.dynamics_loss = MSELoss() def train_dynamics_model(self, batches=50, batch_size=100): if not self.use_dynamics_model: return for _ in range(batches): indices = self._sample_indices(batch_size) self.dynamics_optimizer.zero_grad() obs = self._obs[self.observation_key][indices] next_obs = self._next_obs[self.observation_key][indices] actions = self._actions[indices] if self.exploration_rewards_type == 'inverse_model_error': obs, next_obs = next_obs, obs state_action_pair = ptu.from_numpy(np.c_[obs, actions]) prediction = self.dynamics_model(state_action_pair) mse = self.dynamics_loss(prediction, ptu.from_numpy(next_obs)) mse.backward() self.dynamics_optimizer.step() def log_loss_under_uniform(self, model, data, batch_size, rl_logger, priority_function_kwargs): import torch.nn.functional as F log_probs_prior = [] log_probs_biased = [] log_probs_importance = [] kles = [] mses = [] for i in range(0, data.shape[0], batch_size): img = data[i:min(data.shape[0], i + batch_size), :] torch_img = ptu.from_numpy(img) reconstructions, obs_distribution_params, latent_distribution_params = self.vae(torch_img) priority_function_kwargs['sampling_method'] = 'true_prior_sampling' log_p, log_q, log_d = compute_log_p_log_q_log_d(model, img, **priority_function_kwargs) log_prob_prior = log_d.mean() priority_function_kwargs['sampling_method'] = 'biased_sampling' log_p, log_q, log_d = compute_log_p_log_q_log_d(model, img, **priority_function_kwargs) log_prob_biased = log_d.mean() priority_function_kwargs['sampling_method'] = 'importance_sampling' log_p, log_q, log_d = compute_log_p_log_q_log_d(model, img, **priority_function_kwargs) log_prob_importance = (log_p - log_q + log_d).mean() kle = model.kl_divergence(latent_distribution_params) mse = F.mse_loss(torch_img, reconstructions, reduction='elementwise_mean') mses.append(mse.item()) kles.append(kle.item()) log_probs_prior.append(log_prob_prior.item()) log_probs_biased.append(log_prob_biased.item()) log_probs_importance.append(log_prob_importance.item()) rl_logger["Uniform Data Log Prob (Prior)"] = np.mean(log_probs_prior) rl_logger["Uniform Data Log Prob (Biased)"] = np.mean(log_probs_biased) rl_logger["Uniform Data Log Prob (Importance)"] = np.mean(log_probs_importance) rl_logger["Uniform Data KL"] = np.mean(kles) rl_logger["Uniform Data MSE"] = np.mean(mses) def _get_sorted_idx_and_train_weights(self): idx_and_weights = zip(range(len(self._vae_sample_probs)), self._vae_sample_probs) return sorted(idx_and_weights, key=lambda x: x[1])
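# --- Hedged sketch (hypothetical standalone helper, not in the original) ---
# The skewed sampling weights computed in refresh_latents above, in isolation:
# per-sample priorities are raised to `power` and normalized into the
# categorical distribution consumed by sample_weighted_indices.
def skewed_sample_probs(priorities, power):
    probs = priorities ** power
    p_sum = probs.sum()
    assert p_sum > 0, "Unnormalized p sum is {}".format(p_sum)
    return (probs / p_sum).flatten()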
class AbstractMDPsContrastive:
    def __init__(self, envs):
        self.envs = [EnvContainer(env) for env in envs]
        self.abstract_dim = 4
        self.state_dim = len(self.envs[0].states) + 2
        self.states = []
        self.state_to_idx = None
        self.encoder = Mlp((128, 128, 128),
                           output_size=self.abstract_dim,
                           input_size=self.state_dim,
                           output_activation=F.softmax,
                           layer_norm=True)
        self.transitions = nn.Parameter(
            torch.zeros((self.abstract_dim, self.abstract_dim)))
        self.optimizer = optim.Adam(self.encoder.parameters())

    def train(self, max_epochs=100):
        for epoch in range(1, max_epochs + 1):
            stats = self.train_epoch(epoch)
            print(stats)

    def kl(self, dist1, dist2):
        return (dist1 *
                (torch.log(dist1 + 1e-8) - torch.log(dist2 + 1e-8))).sum(1)

    def entropy(self, dist):
        return -(dist * torch.log(dist + 1e-8)).sum(1)

    def compute_abstract_t(self, env):
        trans = env.transitions_onehot
        s1 = trans[:, :self.state_dim]
        s2 = trans[:, self.state_dim:]
        y1 = self.encoder(from_numpy(s1))
        y2 = self.encoder(from_numpy(s2))
        y3 = self.encoder(from_numpy(env.sample_states(s1.shape[0])))
        a_t = from_numpy(np.zeros((self.abstract_dim, self.abstract_dim)))
        for i in range(self.abstract_dim):
            for j in range(self.abstract_dim):
                a_t[i, j] += (y1[:, i] * y2[:, j]).sum(0)
        # normalize each row into a distribution over next abstract states
        a_t = a_t / a_t.sum(1, keepdim=True)
        return a_t, y1, y2, y3

    def train_epoch(self, epoch):
        stats = OrderedDict([('Loss', 0), ('Converge', 0), ('Diverge', 0),
                             ('Entropy', 0), ('Dev', 0)])

        data = [self.compute_abstract_t(env) for env in self.envs]
        abstract_t = [x[0] for x in data]
        y1 = torch.cat([x[1] for x in data], 0)
        y2 = torch.cat([x[2] for x in data], 0)
        y3 = torch.cat([x[3] for x in data], 0)

        mean_t = sum(abstract_t) / len(abstract_t)
        dev = [torch.pow(x[0] - mean_t, 2).mean() for x in abstract_t]

        bs = y1.shape[0]
        l1 = self.kl(y1, y2)
        l2 = self.kl(y2, y1)
        l3 = (-self.kl(y1, y3) - self.kl(y3, y1)) * 20
        #l4 = -self.kl(y3, y1)
        l5 = -self.entropy(y1) * 1
        l6 = sum(dev) / len(dev) * 0

        loss = (l1 + l2) + l3 + l5
        #loss = l5
        loss = loss.sum() / bs + l6

        self.optimizer.zero_grad()  # clear gradients from the previous epoch
        loss.backward()
        nn.utils.clip_grad_norm_(self.encoder.parameters(), 5.0)
        self.optimizer.step()

        stats['Loss'] += loss.item()
        stats['Converge'] += ((l1 + l2).sum() / bs).item()
        stats['Diverge'] += (l3.sum() / bs).item()
        stats['Entropy'] += (l5.sum() / bs).item()
        stats['Dev'] += l6.item()
        self.y1 = y1
        return stats

    def gen_plot(self):
        plots = [env.gen_plot(self.encoder) for env in self.envs]
        plots = np.concatenate(plots, 1)
        plt.imshow(plots)
        plt.savefig(
            '/home/jcoreyes/abstract/rlkit/examples/abstractmdp/fig1.png')
        plt.show()
class AbstractMDPContrastive: def __init__(self, env): self.env = env self.width = self.env.grid.height self.height = self.env.grid.height self.abstract_dim = 4 self.state_dim = 2 self.states = [] self.state_to_idx = None self.encoder = Mlp((64, 64, 64), output_size=self.abstract_dim, input_size=self.state_dim, output_activation=F.softmax, layer_norm=True) states = [] for j in range(self.env.grid.height): for i in range(self.env.grid.width): if self.env.grid.get(i, j) == None: states.append((i, j)) state_to_idx = {s: i for i, s in enumerate(states)} self.states = states self.state_to_idx = state_to_idx transitions = [] for i, state in enumerate(states): next_states = self._gen_transitions(state) for ns in next_states: transitions.append(list(state) + list(ns)) self.transitions = transitions self.optimizer = optim.Adam(self.encoder.parameters()) def _gen_transitions(self, state): actions = np.array([[1, 0], [-1, 0], [0, 1], [0, -1]]) next_states = [] for action in actions: ns = np.array(state) + action if ns[0] >= 0 and ns[1] >= 0 and ns[0] < self.width and ns[1] < self.height and \ self.env.grid.get(*ns) == None: next_states.append(ns) return next_states def gen_plot(self): #X = np.arange(0, self.width) #Y = np.arange(0, self.height) #X, Y = np.meshgrid(X, Y) Z = np.zeros((self.width, self.height)) for state in self.states: dist = get_numpy( self.encoder(from_numpy(np.array(state)).unsqueeze(0))) Z[state] = np.argmax(dist) + 1 #fig = plt.figure() #ax = Axes3D(fig) #surf = ax.plot_surface(X, Y, Z) #cset = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm) #ax.clabel(cset) plt.imshow(Z) plt.show() def train(self, max_epochs=100): transitions = from_numpy(np.array(self.transitions)) dataset = data.TensorDataset(transitions[:, :2], transitions[:, 2:]) dataloader = data.DataLoader(dataset, batch_size=32, shuffle=True) for epoch in range(1, max_epochs + 1): stats = self.train_epoch(dataloader, epoch) print(stats) def kl(self, dist1, dist2): return (dist1 * (torch.log(dist1 + 1e-8) - torch.log(dist2 + 1e-8))).sum(1) def entropy(self, dist): return -(dist * torch.log(dist + 1e-8)).sum(1) def train_epoch(self, dataloader, epoch): stats = dict([('Loss', 0)]) for batch_idx, (s1, s2) in enumerate(dataloader): bs = s1.shape[0] self.optimizer.zero_grad() s3 = np.array([ self.states[x] for x in np.random.randint(0, len(self.states), bs) ]) y1 = self.encoder(s1) y2 = self.encoder(s2) y3 = self.encoder(from_numpy(s3)) l1 = self.kl(y1, y2) l2 = self.kl(y2, y1) l3 = -self.kl(y1, y3) #l4 = -self.kl(y3, y1) l5 = -self.entropy(y1) loss = (l1 + l2) + 0.8 * l3 + 0.3 * l5 loss = loss.sum() / bs loss.backward() nn.utils.clip_grad_norm(self.encoder.parameters(), 5.0) self.optimizer.step() stats['Loss'] += loss.item() stats['Loss'] /= (batch_idx + 1) return stats
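# --- Hedged usage sketch (assumption, not in the original file) ---
# How AbstractMDPContrastive might be driven, assuming `make_env()` is a
# hypothetical constructor for a gridworld whose `grid.get(i, j)` returns None
# for free cells (as the class expects) and that matplotlib can open a window
# for gen_plot().
if __name__ == '__main__':
    env = make_env()  # hypothetical four-rooms style grid environment
    model = AbstractMDPContrastive(env)
    model.train(max_epochs=100)
    model.gen_plot()  # colors each free cell by its argmax abstract state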