class PolicyVideoEval(BaseEval):
    """Generates a video of the policy running in the environment."""

    def __init__(self, env_cls, horizon, policy, tasks, n_trajs,
                 time_to_go=True, eval_freq=30):
        self.sampler = TrajectorySampler(env_cls=env_cls,
                                         policy=policy,
                                         horizon=horizon,
                                         tasks=tasks)
        self.n_trajs = n_trajs
        self.time_to_go = time_to_go
        self.eval_freq = eval_freq
        self._prefix = f'{policy.name}_{tasks.name}_policy_video'

    def _add_time_to_go(self, frame, prog):
        # Pad a bar onto the bottom of the frame and fill it in
        # proportion to the policy's reported time-to-go.
        frame = np.pad(frame, [[0, BAR_HEIGHT], [0, 0]], 'constant')
        frame[-BAR_HEIGHT:, :int(frame.shape[1] * prog)] = 255
        return frame

    def eval(self, model):
        print(f'Evaluating {self.prefix}')
        trajs = self.sampler.collect_trajectories(n_interactions=None,
                                                  n_trajs=self.n_trajs)
        frames = []
        for traj in trajs:
            task = traj.task
            for i in range(len(traj.obs)):
                obs = traj.obs[i]
                # Show the current observation and the goal side by side.
                frame = np.concatenate([obs[-1], task.obs[-1]], axis=1)
                if self.time_to_go:
                    frame = self._add_time_to_go(
                        frame, traj.policy_infos[i].time_to_go)
                frames.append(frame)
            # Freeze the video for 10 frames after a goal is achieved
            for _ in range(10):
                frames.append(frames[-1])
        ann_frames = []
        for i in range(len(frames)):
            frame_image = Image.fromarray(frames[i]).convert('RGBA')
            np_frame = np.array(frame_image)
            # wandb.Video expects channels-first (C, H, W) frames.
            np_frame = np.moveaxis(np_frame, -1, 0)
            ann_frames.append(np_frame)
        ann_frames = np.array(ann_frames)
        video = wandb.Video(ann_frames, fps=10, format='mp4')
        logs = {'policy_video': video}
        return logs
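# Usage (a minimal sketch; `MyEnv`, `my_policy`, and `my_tasks` are
# hypothetical stand-ins for an env constructor, a policy with a `.name`,
# and a task distribution with a `.name` — they are not part of this module):
#
#   video_eval = PolicyVideoEval(env_cls=MyEnv, horizon=100,
#                                policy=my_policy, tasks=my_tasks,
#                                n_trajs=4)
#   logs = video_eval.eval(model)  # logs['policy_video'] is a wandb.Video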
def generate_goals(env_cls, horizon, policy, n_goals, include_terminal,
                   random_horizon):
    tasks = NoneTaskDist()
    sampler = TrajectorySampler(env_cls=env_cls,
                                policy=policy,
                                horizon=horizon,
                                tasks=tasks,
                                lazy_labels=True,
                                random_horizon=random_horizon)
    goals = []
    prog = tqdm(total=n_goals)
    while len(goals) < n_goals:
        traj = sampler.collect_trajectories(n_interactions=None,
                                            n_trajs=1)[0]
        # Keep the terminal state as a goal unless early-terminated
        # trajectories are excluded.
        if include_terminal or len(traj.obs) == horizon:
            goals.append(Goal(obs=traj.obs[-1], info=traj.infos[-1]))
            prog.update(1)
    prog.close()
    return goals
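# `Goal` is used above and below but defined elsewhere. Judging from its
# usage here (`goal.obs`, `goal.info.labels`), a minimal sketch of a
# compatible definition (an assumption, not necessarily the actual one):
#
#   from collections import namedtuple
#   Goal = namedtuple('Goal', ['obs', 'info'])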
def generate_goals(env_cls, horizon, policy, n_goals, n_samples,
                   include_fn=None, random_horizon=False):
    tasks = NoneTaskDist()
    sampler = TrajectorySampler(env_cls=env_cls,
                                policy=policy,
                                horizon=horizon,
                                tasks=tasks,
                                lazy_labels=True,
                                random_horizon=random_horizon)
    # Gather an example label to figure out the feature dimension.
    traj = sampler.collect_trajectories(n_interactions=None, n_trajs=1)[0]
    labels = traj.infos[-1].labels
    feats = label_features(labels._asdict(), include_fn)
    feat_matrix = np.zeros([n_goals, feats.shape[0]])
    goals = []
    for _ in tqdm(range(n_samples)):
        traj = sampler.collect_trajectories(n_interactions=None,
                                            n_trajs=1)[0]
        if len(traj.obs) == horizon:
            goal = Goal(obs=traj.obs[-1], info=traj.infos[-1])
            if len(goals) < n_goals:
                # Fill the goal pool first.
                goals.append(goal)
                feats = label_features(goal.info.labels._asdict(),
                                       include_fn)
                feat_matrix[len(goals) - 1] = feats
            else:
                # Once the pool is full, let propose_goal decide whether
                # the new goal should replace an existing one.
                accept, idx = propose_goal(feat_matrix, goal, include_fn)
                if accept:
                    goals[idx] = goal
    return goals
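# `label_features` and `propose_goal` are referenced above but defined
# elsewhere. A minimal sketch of `label_features` consistent with its
# usage here (flattens a labels dict into a feature vector, optionally
# filtered by `include_fn`); this is an assumption, not the actual
# implementation:
#
#   def label_features(label_dict, include_fn=None):
#       keys = sorted(label_dict)
#       if include_fn is not None:
#           keys = [k for k in keys if include_fn(k)]
#       return np.array([float(label_dict[k]) for k in keys])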
class BatchTrainGLAMOR:
    """Trains a GLAMOR model.

    - env_cls: function that returns new environment instances
    - horizon: max steps in a trajectory before resetting
    - k_dist: distribution of k (length of trajectory segments on which to train)
    - n_interactions: total interactions during training
    - replay_ratio: how many times to use a data point before sampling a new one from the env
    - policy: policy to use when gathering new samples
    - task_dist: training task (goal) distribution
    - model: model object
    - evals: a list of evaluators which can test the model and return things to log
    - batch_size: sgd batch size
    - log_period: log every log_period interactions
    - snapshot_period: save a snapshot of the model every snapshot_period interactions
    - buffer_size: replay buffer size
    - min_step_learn: take this many steps before updating the model in the beginning
    - lr: sgd learning rate
    - clip_param: gradient clipping param
    - device: cuda/cpu
    - checkpoint_path: path where the checkpoint is stored, if any
    """

    def __init__(
            self,
            env_cls,
            horizon,
            k_dist,
            n_interactions,
            replay_ratio,  # How many times to use a data point before getting a new one
            policy,
            task_dist,
            model,
            evals,
            batch_size,
            log_period=int(1e5),  # log every 1e5 interactions
            snapshot_period=int(1e6),
            buffer_size=int(1e6),
            min_step_learn=int(1e4),
            lr=1e-4,
            clip_param=1,
            device='cuda',
            frame_buffer=True,
            checkpoint_path=None,
    ):
        if wandb_is_enabled():
            wandb.save('*.pt')
        self.n_interactions = n_interactions
        # Convert the requested replay ratio into a schedule of gradient
        # batches per collected trajectory (or trajectories per batch).
        r_prime = replay_ratio * (float(horizon) / float(batch_size))
        if r_prime > 1:
            self.n_trajs = 1
            self.n_batches = round(r_prime)
        else:
            self.n_trajs = round(1. / r_prime)
            self.n_batches = 1
        effective_r = float(self.n_batches * batch_size) / \
            float(self.n_trajs * horizon)
        print(f'Doing {self.n_batches} batches per {self.n_trajs} trajs.')
        print(f'Effective replay ratio: {effective_r}')
        self.model = model
        self.log_period = log_period
        self.snapshot_period = snapshot_period
        self.min_step_learn = min_step_learn
        self.evals = evals
        self.device = device
        self.action_space = env_cls().action_space
        self.n_actions = self.action_space.n
        optimizer = optim.AdamW(self.model.parameters(), lr=lr)
        if frame_buffer:
            self.replay_buffer = FrameBuffer(buffer_size=buffer_size,
                                             env_cls=env_cls,
                                             k_dist=k_dist,
                                             end_token=self.n_actions,
                                             action_seq=True)
        else:
            self.replay_buffer = ReplayBuffer(buffer_size=buffer_size,
                                              env_cls=env_cls,
                                              k_dist=k_dist,
                                              end_token=self.n_actions,
                                              action_seq=True)
        self.trainer = BatchSupervised(model=self.model,
                                       model_fn=self.model_fn,
                                       loss_fn=self.loss_fn,
                                       format_sample=self.format_sample,
                                       optimizer=optimizer,
                                       sampler=self.replay_buffer,
                                       batch_size=batch_size,
                                       clip_param=clip_param)
        self.policy = policy
        self.env_sampler = TrajectorySampler(env_cls=env_cls,
                                             policy=self.policy,
                                             horizon=horizon,
                                             tasks=task_dist,
                                             lazy_labels=False)
        # Padding positions (index n_actions + 1) are ignored by the loss.
        self.ce_loss = torch.nn.NLLLoss(reduction='mean',
                                        ignore_index=self.n_actions + 1)
        self.total_steps = 0
        self.last_snapshot = 0
        self.checkpoint_path = checkpoint_path
        if self.checkpoint_path is not None:
            state = restore_state(self.checkpoint_path)
            if state is not None:
                self.model.load_state_dict(state.model_params)
                self.trainer.optimizer.load_state_dict(
                    state.optimizer_params)
                self.replay_buffer = state.replay_buffer
                self.trainer.sampler = self.replay_buffer
                self.total_steps = state.total_steps
                print(f'Restored checkpoint at step {state.total_steps}.')
                self.last_snapshot = state.last_snapshot
        if wandb_is_enabled():
            wandb.watch(self.model)

    def model_fn(self, f_sample):
        obs_0, obs_k, actions = f_sample
        actions, baseline_actions = self.model.pred_actions(
            obs_0, obs_k, actions)
        return actions, baseline_actions

    def loss_fn(self, f_sample, model_res):
        actions, baseline_actions = model_res
        _, _, actions_target = f_sample
        # Targets are one-hot over n_actions + 2 classes (the actions plus
        # end and padding tokens); convert to class indices for NLLLoss.
        actions_target = actions_target.view(-1, self.n_actions + 2)
        actions_target = torch.argmax(actions_target, dim=1)
        actions = actions.reshape(-1, self.n_actions + 2)
        baseline_actions = baseline_actions.reshape(-1, self.n_actions + 2)
        actions_loss = self.ce_loss(actions, actions_target)
        baseline_actions_loss = self.ce_loss(baseline_actions,
                                             actions_target)
        loss = actions_loss + baseline_actions_loss
        return loss, {
            'actions_ce_loss': actions_loss,
            'baseline_actions_loss': baseline_actions_loss,
            'ce_loss': loss,
            'main_loss': loss
        }

    def format_sample(self, sample_batch):
        batch_size = sample_batch.actions.shape[0]
        actions = sample_batch.actions.float().to(device=self.device)
        obs_0 = sample_batch.obs.float().to(device=self.device).view(
            batch_size, -1)
        obs_k = sample_batch.obs_k.float().to(device=self.device).view(
            batch_size, -1)
        return obs_0, obs_k, actions

    def train(self):
        last_log = 0
        prog = tqdm(total=self.log_period)
        last_log_time = time()
        logs = {}
        logs['cum_steps'] = 0
        if self.total_steps == 0:
            self.eval(logs, snapshot=True)
        else:
            last_log = self.total_steps
        while self.total_steps < self.n_interactions:
            self.policy.update(self.total_steps)
            self.model.eval()
            trajs = self.env_sampler.collect_trajectories(
                n_interactions=None, n_trajs=self.n_trajs)
            self.replay_buffer.append_trajs(trajs)
            new_steps = sum([len(traj.obs) for traj in trajs])
            prog.update(new_steps)
            self.total_steps += new_steps
            if len(self.replay_buffer) > self.min_step_learn:
                logs = self.trainer.train(self.n_batches)
            else:
                logs = {}
            if self.total_steps - last_log > self.log_period:
                last_log = self.total_steps
                prog.close()
                time_since_log = time() - last_log_time
                step_per_sec = self.log_period / time_since_log
                logs['step_per_sec'] = step_per_sec
                logs['cum_steps'] = self.total_steps
                last_log_time = time()
                snapshot = (self.total_steps - self.last_snapshot
                            > self.snapshot_period)
                if snapshot:
                    self.last_snapshot = self.total_steps
                self.eval(logs, snapshot=snapshot)
                prog = tqdm(total=self.log_period)
        # Log after the training is finished too.
        time_since_log = time() - last_log_time
        step_per_sec = self.log_period / time_since_log
        logs['step_per_sec'] = step_per_sec
        logs['cum_steps'] = self.total_steps
        self.eval(logs, snapshot=True)

    def eval(self, logs, snapshot=True):
        print('Beginning logging...')
        eval_start_time = time()
        self.model.eval()
        if self.evals is not None:
            for eval_ in self.evals:
                eval_res = eval_.eval(self.model)
                prefix = eval_.prefix
                eval_logs = {
                    f'{prefix}/{key}': value
                    for key, value in eval_res.items()
                }
                logs.update(eval_logs)
        logs['replay_size'] = len(self.replay_buffer)
        if self.policy.name == 'eps_greedy':
            logs['agent_eps'] = self.policy.probs[1]
        eval_duration = time() - eval_start_time
        logs['eval_time'] = eval_duration
        if wandb_is_enabled():
            wandb.log(logs)
        pretty_print(logs)
        if self.checkpoint_path is not None:
            commit_state(checkpoint_path=self.checkpoint_path,
                         model=self.model,
                         optimizer=self.trainer.optimizer,
                         replay_buffer=self.replay_buffer,
                         total_steps=self.total_steps,
                         last_snapshot=self.last_snapshot)
        # save model
        if snapshot:
            torch.save(
                self.model.state_dict(),
                os.path.join(run_path(), f'model_{logs["cum_steps"]}.pt'))
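# Worked example of the replay-ratio bookkeeping in __init__ (the numbers
# are illustrative): with replay_ratio=8, horizon=100, batch_size=256,
# r_prime = 8 * 100 / 256 = 3.125 > 1, so each collected trajectory is
# followed by round(3.125) = 3 gradient batches, for an effective replay
# ratio of 3 * 256 / (1 * 100) = 7.68 reuses per data point.

def effective_replay_ratio(replay_ratio, horizon, batch_size):
    # Mirrors the arithmetic in BatchTrainGLAMOR.__init__ (a sketch for
    # illustration only; this helper is not part of the class).
    r_prime = replay_ratio * (float(horizon) / float(batch_size))
    if r_prime > 1:
        n_trajs, n_batches = 1, round(r_prime)
    else:
        n_trajs, n_batches = round(1. / r_prime), 1
    return float(n_batches * batch_size) / float(n_trajs * horizon)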
class LabelCompareEval(BaseEval):
    """Rolls out the policy in the environment and compares the labels of
    the terminal states with the task labels."""

    def __init__(self, env_cls, horizon, policy, tasks, n_trajs,
                 include_labels=None, n_bound_samples=500, eval_freq=30,
                 output_trajs=False):
        self.sampler = TrajectorySampler(env_cls=env_cls,
                                         policy=policy,
                                         horizon=horizon,
                                         tasks=tasks,
                                         lazy_labels=True)
        self.n_trajs = n_trajs
        self.eval_freq = eval_freq
        self._prefix = f'{policy.name}_{tasks.name}_label_diff'
        self.include_labels = include_labels
        self.output_trajs = output_trajs
        self.cutoffs = self._get_label_cutoffs(env_cls, horizon,
                                               n_bound_samples)

    def eval(self, model):
        print(f'Evaluating {self.prefix}')
        trajs = self.sampler.collect_trajectories(n_interactions=None,
                                                  n_trajs=self.n_trajs)
        label_diffs = []
        for traj in trajs:
            task = traj.task
            traj_labels = traj.infos[-1].labels
            task_labels = task.info.labels
            label_diffs.append(self.compare_labels(traj_labels,
                                                   task_labels))
        # Average each label difference over trajectories, then average
        # across labels.
        avg_label_diffs = {}
        avg_value = 0
        for key in label_diffs[0]:
            value = 0
            for di in label_diffs:
                value += di[key]
            value /= len(label_diffs)
            avg_label_diffs[key] = value
            avg_value += value
        avg_value /= len(label_diffs[0].keys())
        avg_label_diffs['avg_diff'] = avg_value
        avg_label_diffs['avg_pos_diff'] = self._get_euclidean_average(
            avg_label_diffs)
        labels_achieved = self._get_prop_achieved(label_diffs)
        for key in labels_achieved:
            avg_label_diffs[f'achieved_{key}'] = labels_achieved[key]
        total_achieved = self._get_total_achieved(
            label_diffs, include=self.include_labels)
        avg_label_diffs['total_achieved'] = total_achieved
        if self.output_trajs:
            avg_label_diffs['trajs'] = trajs
        return avg_label_diffs

    def _get_label_cutoffs(self, env_cls, horizon, n_bound_samples):
        """Gathers n_bound_samples trajectories with a random policy and
        records a 10%-of-range cutoff for each label dimension."""
        env = env_cls()
        goals = ListTerminalStateGoalDist(env_cls=env_cls,
                                          horizon=horizon,
                                          policy=RandomPolicy(
                                              action_space=env.action_space),
                                          n_goals=n_bound_samples,
                                          include_terminal=True,
                                          random_horizon=True)
        mins = {}
        maxs = {}
        for goal in goals:
            labels = goal.info.labels
            label_dict = labels._asdict()
            for key in label_dict:
                v = label_dict[key]
                if key in mins:
                    mins[key] = min(v, mins[key])
                    maxs[key] = max(v, maxs[key])
                else:
                    mins[key] = v
                    maxs[key] = v
        cutoffs = {}
        for key in mins:
            range_ = maxs[key] - mins[key]
            # Avoid a zero (or near-zero) range.
            if range_ < 0.01:
                range_ = 1
            cutoffs[key] = range_ * 0.1
        print(mins)
        print(maxs)
        print(cutoffs)
        return cutoffs

    def _get_prop_achieved(self, label_diffs):
        """Returns, for each label, the fraction of trajectories whose
        difference falls under that label's cutoff."""
        achieved = {}
        for key in label_diffs[0]:
            value = 0
            for di in label_diffs:
                if di[key] < self.cutoffs[key]:
                    value += 1
            value /= len(label_diffs)
            achieved[key] = value
        return achieved

    def _get_total_achieved(self, label_diffs, include=None):
        """Calculates the fraction of trajectories in which every included
        label is achieved at once (the intersection of achieved labels)."""
        achieved = 0
        if include is None:
            include = list(label_diffs[0].keys())
        for di in label_diffs:
            a = True
            for key in di:
                if key in include and di[key] >= self.cutoffs[key]:
                    a = False
                    break
            if a:
                achieved += 1
        achieved /= len(label_diffs)
        return achieved

    def _get_euclidean_average(self, labels):
        # Average the Euclidean norm of the (x, y) position differences,
        # counting each entity's position once.
        included_pos = set()
        total_dist = 0
        for key in labels:
            split_key = key.split('_')
            suffix = split_key[-1]
            entity_name = '_'.join(split_key[:-1])
            if (suffix == 'x' or suffix == 'y') \
                    and entity_name not in included_pos:
                x_pos, y_pos = 0, 0
                x_label = f'{entity_name}_x'
                y_label = f'{entity_name}_y'
                if x_label in labels:
                    x_pos = labels[x_label]
                if y_label in labels:
                    y_pos = labels[y_label]
                dist = sqrt(x_pos ** 2 + y_pos ** 2)
                total_dist += dist
                included_pos.add(entity_name)
        if len(included_pos) == 0:
            return 0
        return total_dist / len(included_pos)

    def _control_metric(self, labels):
        # TODO: unused stub.
        total_dist = 0

    def compare_labels(self, traj_labels, task_labels):
        """Returns a dict of the absolute differences between label pairs."""
        traj_labels = traj_labels._asdict()
        task_labels = task_labels._asdict()
        label_diffs = {}
        for key in traj_labels:
            label_diffs[key] = abs(
                float(traj_labels[key]) - float(task_labels[key]))
        return label_diffs
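# Worked example of the cutoff logic above (the numbers are made up): if a
# label's values ranged over [0.0, 2.0] across the random-policy bound
# samples, its cutoff is 0.1 * (2.0 - 0.0) = 0.2, so a terminal state
# "achieves" that label when it lands within 0.2 of the task label:
#
#   cutoffs = {'agent_x': 0.2}
#   diff = abs(1.35 - 1.25)               # compare_labels -> 0.1
#   achieved = diff < cutoffs['agent_x']  # True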