Example #1
    def __init__(self,
                 env_cls,
                 horizon,
                 policy,
                 tasks,
                 n_trajs,
                 time_to_go=True,
                 eval_freq=30):
        self.sampler = TrajectorySampler(env_cls=env_cls,
                                         policy=policy,
                                         horizon=horizon,
                                         tasks=tasks)
        self.n_trajs = n_trajs
        self.time_to_go = time_to_go
        self.eval_freq = eval_freq
        self._prefix = f'{policy.name}_{tasks.name}_policy_video'
Example #2
class PolicyVideoEval(BaseEval):
    """Generates a video of the policy running in the environment 
    and compares the labels of the terminal states."""
    def __init__(self,
                 env_cls,
                 horizon,
                 policy,
                 tasks,
                 n_trajs,
                 time_to_go=True,
                 eval_freq=30):
        self.sampler = TrajectorySampler(env_cls=env_cls,
                                         policy=policy,
                                         horizon=horizon,
                                         tasks=tasks)
        self.n_trajs = n_trajs
        self.time_to_go = time_to_go
        self.eval_freq = eval_freq
        self._prefix = f'{policy.name}_{tasks.name}_policy_video'

    def _add_time_to_go(self, frame, prog):
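        # Pad BAR_HEIGHT black rows onto the bottom of the frame and fill the
        # leftmost `prog` fraction of that bar with white, as a progress bar.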
        frame = np.pad(frame, [[0, BAR_HEIGHT], [0, 0]], 'constant')
        frame[-BAR_HEIGHT:, :int(frame.shape[1] * prog)] = 255
        return frame

    def eval(self, model):
        print(f'Evaluating {self.prefix}')
        trajs = self.sampler.collect_trajectories(n_interactions=None,
                                                  n_trajs=self.n_trajs)

        frames = []
        for traj in trajs:
            task = traj.task

            for i in range(len(traj.obs)):
                obs = traj.obs[i]
                frame = np.concatenate([obs[-1], task.obs[-1]], axis=1)
                if self.time_to_go:
                    frame = self._add_time_to_go(
                        frame, traj.policy_infos[i].time_to_go)
                frames.append(frame)

            # Freeze the video for 10 frames after a goal is achieved
            for _ in range(10):
                frames.append(frames[-1])

        ann_frames = []
        for i in range(len(frames)):
            frame_image = Image.fromarray(frames[i]).convert('RGBA')
            np_frame = np.array(frame_image)
            np_frame = np.moveaxis(np_frame, -1, 0)
            ann_frames.append(np_frame)

        ann_frames = np.array(ann_frames)

        video = wandb.Video(ann_frames, fps=10, format='mp4')

        logs = {'policy_video': video}

        return logs
Example #3
def generate_goals(env_cls, horizon, policy, n_goals, include_terminal, random_horizon):
    tasks = NoneTaskDist()
    sampler = TrajectorySampler(env_cls=env_cls,
                                policy=policy,
                                horizon=horizon,
                                tasks=tasks,
                                lazy_labels=True,
                                random_horizon=random_horizon)
    goals = []

    prog = tqdm(total=n_goals)
    while len(goals) < n_goals:
        traj = sampler.collect_trajectories(n_interactions=None,
                                            n_trajs=1)[0]
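        # Accept the final state as a goal if early-terminated episodes are allowed
        # (include_terminal) or the rollout ran for the full horizon.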
        if include_terminal or len(traj.obs) == horizon:
            goals.append(Goal(obs=traj.obs[-1], info=traj.infos[-1]))
            prog.update(1)

    prog.close()
    return goals
Example #4
def generate_goals(env_cls,
                   horizon,
                   policy,
                   n_goals,
                   n_samples,
                   include_fn=None,
                   random_horizon=False):
    tasks = NoneTaskDist()
    sampler = TrajectorySampler(env_cls=env_cls,
                                policy=policy,
                                horizon=horizon,
                                tasks=tasks,
                                lazy_labels=True,
                                random_horizon=random_horizon)

    # gather one example label to determine the feature dimensionality
    traj = sampler.collect_trajectories(n_interactions=None, n_trajs=1)[0]
    labels = traj.infos[-1].labels
    feats = label_features(labels._asdict(), include_fn)

    feat_matrix = np.zeros([n_goals, feats.shape[0]])

    goals = []

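    # Fill the goal set first; once it holds n_goals entries, propose_goal decides
    # (from the collected feature matrix) whether a newly sampled terminal state
    # should replace an existing goal, and at which index.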
    for _ in tqdm(range(n_samples)):
        traj = sampler.collect_trajectories(n_interactions=None, n_trajs=1)[0]
        if len(traj.obs) == horizon:
            goal = Goal(obs=traj.obs[-1], info=traj.infos[-1])
            if len(goals) < n_goals:
                goals.append(goal)
                feats = label_features(goal.info.labels._asdict(), include_fn)
                feat_matrix[len(goals) - 1] = feats
            else:
                accept, idx = propose_goal(feat_matrix, goal, include_fn)
                if accept:
                    goals[idx] = goal

    return goals
Example #5
    def __init__(self,
                 env_cls,
                 horizon,
                 policy,
                 tasks,
                 n_trajs,
                 include_labels=None,
                 n_bound_samples=500,
                 eval_freq=30,
                 output_trajs=False):
        self.sampler = TrajectorySampler(env_cls=env_cls,
                                         policy=policy,
                                         horizon=horizon,
                                         tasks=tasks,
                                         lazy_labels=True)
        self.n_trajs = n_trajs
        self.eval_freq = eval_freq
        self._prefix = f'{policy.name}_{tasks.name}_label_diff'

        self.include_labels = include_labels
        self.output_trajs = output_trajs

        self.cutoffs = self._get_label_cutoffs(
            env_cls, horizon, n_bound_samples)
Example #6
    def __init__(
        self,
        env_cls,
        horizon,
        k_dist,
        n_interactions,
        replay_ratio,  # How many times to use a data point before getting a new one
        policy,
        task_dist,
        model,
        evals,
        batch_size,
        log_period=int(1e5),  # log every 1e5 interactions
        snapshot_period=int(1e6),
        buffer_size=int(1e6),
        min_step_learn=int(1e4),
        lr=1e-4,
        clip_param=1,
        device='cuda',
        frame_buffer=True,
        checkpoint_path=None,
    ):

        if wandb_is_enabled():
            wandb.save('*.pt')

        self.n_interactions = n_interactions

        r_prime = replay_ratio * (float(horizon) / float(batch_size))

        if r_prime > 1:
            self.n_trajs = 1
            self.n_batches = round(r_prime)
        else:
            self.n_trajs = round(1. / r_prime)
            self.n_batches = 1

        effective_r = float(self.n_batches * batch_size) / \
            float(self.n_trajs * horizon)

        print(f'Doing {self.n_batches} batches per {self.n_trajs} trajs.')
        print(f'Effective replay ratio: {effective_r}')
        self.model = model
        self.log_period = log_period
        self.snapshot_period = snapshot_period
        self.min_step_learn = min_step_learn
        self.evals = evals
        self.device = device

        self.action_space = env_cls().action_space
        self.n_actions = self.action_space.n

        optimizer = optim.AdamW(self.model.parameters(), lr=lr)

        if frame_buffer:
            self.replay_buffer = FrameBuffer(buffer_size=buffer_size,
                                             env_cls=env_cls,
                                             k_dist=k_dist,
                                             end_token=self.n_actions,
                                             action_seq=True)
        else:
            self.replay_buffer = ReplayBuffer(buffer_size=buffer_size,
                                              env_cls=env_cls,
                                              k_dist=k_dist,
                                              end_token=self.n_actions,
                                              action_seq=True)

        self.trainer = BatchSupervised(model=self.model,
                                       model_fn=self.model_fn,
                                       loss_fn=self.loss_fn,
                                       format_sample=self.format_sample,
                                       optimizer=optimizer,
                                       sampler=self.replay_buffer,
                                       batch_size=batch_size,
                                       clip_param=clip_param)

        self.policy = policy
        self.env_sampler = TrajectorySampler(env_cls=env_cls,
                                             policy=self.policy,
                                             horizon=horizon,
                                             tasks=task_dist,
                                             lazy_labels=False)

        self.ce_loss = torch.nn.NLLLoss(reduction='mean',
                                        ignore_index=self.n_actions + 1)

        self.total_steps = 0
        self.last_snapshot = 0

        self.checkpoint_path = checkpoint_path

        if self.checkpoint_path is not None:
            state = restore_state(self.checkpoint_path)
            if state is not None:
                self.model.load_state_dict(state.model_params)
                self.trainer.optimizer.load_state_dict(state.optimizer_params)
                self.replay_buffer = state.replay_buffer
                self.trainer.sampler = self.replay_buffer
                self.total_steps = state.total_steps
                print(f'Restored checkpoint at {state.total_steps} steps')
                self.last_snapshot = state.last_snapshot

        if wandb_is_enabled():
            wandb.watch(self.model)
Example #7
class BatchTrainGLAMOR:
    """Trains a GLAMOR model.

    - env_cls: function that returns new environment instances
    - horizon: max steps in a trajectory before resetting
    - k_dist: distribution of k (length of trajectory segments on which to train)
    - n_interactions: total interactions during training
    - replay_ratio: how many times to use a data point before sampling a new one from the env
    - policy: policy to use when gathering new samples
    - task_dist: training task (goal) distribution
    - model: model object
    - evals: a list of evaluators which can test the model and return things to log
    - batch_size: sgd batch size
    - log_period: log every log_period interactions
    - snapshot_period: save a snapshot of the model every snapshot_period interactions
    - buffer_size: replay buffer size
    - min_step_learn: take this many steps before updating the model in the beginning
    - lr: sgd learning rate
    - clip_param: gradient clipping param
    - device: cuda/cpu
    - frame_buffer: if True, store transitions in a FrameBuffer; otherwise use a plain ReplayBuffer
    - checkpoint_path: path where the checkpoint is stored, if any
    """
    def __init__(
        self,
        env_cls,
        horizon,
        k_dist,
        n_interactions,
        replay_ratio,  # How many times to use a data point before getting a new one
        policy,
        task_dist,
        model,
        evals,
        batch_size,
        log_period=int(1e5),  # log every 1e5 interactions
        snapshot_period=int(1e6),
        buffer_size=int(1e6),
        min_step_learn=int(1e4),
        lr=1e-4,
        clip_param=1,
        device='cuda',
        frame_buffer=True,
        checkpoint_path=None,
    ):

        if wandb_is_enabled():
            wandb.save('*.pt')

        self.n_interactions = n_interactions

        r_prime = replay_ratio * (float(horizon) / float(batch_size))
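        # r_prime is the number of gradient batches needed per collected trajectory
        # to hit the target replay ratio: each trajectory adds roughly `horizon`
        # samples and each batch consumes `batch_size` of them. For example
        # (hypothetical numbers), replay_ratio=8, horizon=100, batch_size=256 gives
        # r_prime = 8 * 100 / 256 ≈ 3.1, i.e. about three batches per trajectory.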

        if r_prime > 1:
            self.n_trajs = 1
            self.n_batches = round(r_prime)
        else:
            self.n_trajs = round(1. / r_prime)
            self.n_batches = 1

        effective_r = float(self.n_batches * batch_size) / \
            float(self.n_trajs * horizon)

        print(f'Doing {self.n_batches} batches per {self.n_trajs} trajs.')
        print(f'Effective replay ratio: {effective_r}')
        self.model = model
        self.log_period = log_period
        self.snapshot_period = snapshot_period
        self.min_step_learn = min_step_learn
        self.evals = evals
        self.device = device

        self.action_space = env_cls().action_space
        self.n_actions = self.action_space.n

        optimizer = optim.AdamW(self.model.parameters(), lr=lr)

        if frame_buffer:
            self.replay_buffer = FrameBuffer(buffer_size=buffer_size,
                                             env_cls=env_cls,
                                             k_dist=k_dist,
                                             end_token=self.n_actions,
                                             action_seq=True)
        else:
            self.replay_buffer = ReplayBuffer(buffer_size=buffer_size,
                                              env_cls=env_cls,
                                              k_dist=k_dist,
                                              end_token=self.n_actions,
                                              action_seq=True)

        self.trainer = BatchSupervised(model=self.model,
                                       model_fn=self.model_fn,
                                       loss_fn=self.loss_fn,
                                       format_sample=self.format_sample,
                                       optimizer=optimizer,
                                       sampler=self.replay_buffer,
                                       batch_size=batch_size,
                                       clip_param=clip_param)

        self.policy = policy
        self.env_sampler = TrajectorySampler(env_cls=env_cls,
                                             policy=self.policy,
                                             horizon=horizon,
                                             tasks=task_dist,
                                             lazy_labels=False)

        self.ce_loss = torch.nn.NLLLoss(reduction='mean',
                                        ignore_index=self.n_actions + 1)

        self.total_steps = 0
        self.last_snapshot = 0

        self.checkpoint_path = checkpoint_path

        if self.checkpoint_path is not None:
            state = restore_state(self.checkpoint_path)
            if state is not None:
                self.model.load_state_dict(state.model_params)
                self.trainer.optimizer.load_state_dict(state.optimizer_params)
                self.replay_buffer = state.replay_buffer
                self.trainer.sampler = self.replay_buffer
                self.total_steps = state.total_steps
                print(f'Restored checkpoint at {state.total_steps} steps')
                self.last_snapshot = state.last_snapshot

        if wandb_is_enabled():
            wandb.watch(self.model)

    def model_fn(self, f_sample):
        obs_0, obs_k, actions = f_sample
        actions, baseline_actions = self.model.pred_actions(
            obs_0, obs_k, actions)
        return actions, baseline_actions

    def loss_fn(self, f_sample, model_res):
        actions, baseline_actions = model_res
        _, _, actions_target = f_sample
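        # Targets are stored one-hot over n_actions + 2 classes: the real actions,
        # the end token (index n_actions), and the padding index (n_actions + 1)
        # which ce_loss skips via ignore_index.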
        actions_target = actions_target.view(-1, self.n_actions + 2)
        actions_target = torch.argmax(actions_target, dim=1)
        actions = actions.reshape(-1, self.n_actions + 2)
        baseline_actions = baseline_actions.reshape(-1, self.n_actions + 2)

        actions_loss = self.ce_loss(actions, actions_target)
        baseline_actions_loss = self.ce_loss(baseline_actions, actions_target)

        loss = actions_loss + baseline_actions_loss
        return loss, {
            'actions_ce_loss': actions_loss,
            'baseline_actions_loss': baseline_actions_loss,
            'ce_loss': loss,
            'main_loss': loss
        }

    def format_sample(self, sample_batch):
        batch_size = sample_batch.actions.shape[0]
        actions = sample_batch.actions.float().to(device=self.device)
        obs_0 = sample_batch.obs.float().to(device=self.device).view(
            batch_size, -1)
        obs_k = sample_batch.obs_k.float().to(device=self.device).view(
            batch_size, -1)

        return obs_0, obs_k, actions

    def train(self):

        last_log = 0

        prog = tqdm(total=self.log_period)
        last_log_time = time()

        logs = {}
        logs['cum_steps'] = 0
        if self.total_steps == 0:
            self.eval(logs, snapshot=True)
        else:
            last_log = self.total_steps

        while self.total_steps < self.n_interactions:
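            # One iteration: collect n_trajs trajectories with the current policy,
            # append them to the replay buffer, then (once the buffer holds more
            # than min_step_learn samples) run n_batches supervised updates.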
            self.policy.update(self.total_steps)
            self.model.eval()
            trajs = self.env_sampler.collect_trajectories(n_interactions=None,
                                                          n_trajs=self.n_trajs)
            self.replay_buffer.append_trajs(trajs)
            new_steps = sum([len(traj.obs) for traj in trajs])
            prog.update(new_steps)
            self.total_steps += new_steps
            if len(self.replay_buffer) > self.min_step_learn:
                logs = self.trainer.train(self.n_batches)
            else:
                logs = {}
            if self.total_steps - last_log > self.log_period:
                last_log = self.total_steps
                prog.close()
                time_since_log = time() - last_log_time
                step_per_sec = self.log_period / time_since_log
                logs['step_per_sec'] = step_per_sec
                logs['cum_steps'] = self.total_steps
                last_log_time = time()

                snapshot = self.total_steps - self.last_snapshot > self.snapshot_period
                if snapshot:
                    self.last_snapshot = self.total_steps

                self.eval(logs, snapshot=snapshot)
                prog = tqdm(total=self.log_period)

        # Log after the training is finished too.
        time_since_log = time() - last_log_time
        step_per_sec = self.log_period / time_since_log
        logs['step_per_sec'] = step_per_sec
        logs['cum_steps'] = self.total_steps
        self.eval(logs, snapshot=True)

    def eval(self, logs, snapshot=True):
        print('Beginning logging...')
        eval_start_time = time()

        self.model.eval()
        if self.evals is not None:
            for eval_ in self.evals:
                eval_result = eval_.eval(self.model)
                prefix = eval_.prefix
                eval_logs = {
                    f'{prefix}/{key}': value
                    for key, value in eval_result.items()
                }
                logs.update(eval_logs)
        logs['replay_size'] = len(self.replay_buffer)
        if self.policy.name == 'eps_greedy':
            logs['agent_eps'] = self.policy.probs[1]
        eval_duration = time() - eval_start_time
        logs['eval_time'] = eval_duration

        if wandb_is_enabled():
            wandb.log(logs)
        pretty_print(logs)

        if self.checkpoint_path is not None:
            commit_state(checkpoint_path=self.checkpoint_path,
                         model=self.model,
                         optimizer=self.trainer.optimizer,
                         replay_buffer=self.replay_buffer,
                         total_steps=self.total_steps,
                         last_snapshot=self.last_snapshot)

        # save model
        if snapshot:
            torch.save(
                self.model.state_dict(),
                os.path.join(run_path(), f'model_{logs["cum_steps"]}.pt'))
Example #8
class LabelCompareEval(BaseEval):
    """Generates a video of the policy running in the environment 
    and compares the labels of the terminal states."""

    def __init__(self,
                 env_cls,
                 horizon,
                 policy,
                 tasks,
                 n_trajs,
                 include_labels=None,
                 n_bound_samples=500,
                 eval_freq=30,
                 output_trajs=False):
        self.sampler = TrajectorySampler(env_cls=env_cls,
                                         policy=policy,
                                         horizon=horizon,
                                         tasks=tasks,
                                         lazy_labels=True)
        self.n_trajs = n_trajs
        self.eval_freq = eval_freq
        self._prefix = f'{policy.name}_{tasks.name}_label_diff'

        self.include_labels = include_labels
        self.output_trajs = output_trajs

        self.cutoffs = self._get_label_cutoffs(
            env_cls, horizon, n_bound_samples)

    def eval(self, model):
        print(f'Evaluating {self.prefix}')
        trajs = self.sampler.collect_trajectories(n_interactions=None,
                                                  n_trajs=self.n_trajs)

        label_diffs = []
        for traj in trajs:
            task = traj.task

            traj_labels = traj.infos[-1].labels
            task_labels = task.info.labels
            label_diffs.append(self.compare_labels(traj_labels, task_labels))

        avg_label_diffs = {}
        avg_value = 0
        for key in label_diffs[0]:
            value = 0
            for di in label_diffs:
                value += di[key]
            value /= len(label_diffs)
            avg_label_diffs[key] = value
            avg_value += value

        avg_value /= len(label_diffs[0].keys())
        avg_label_diffs['avg_diff'] = avg_value

        avg_label_diffs['avg_pos_diff'] = self._get_euclidean_average(
            avg_label_diffs)

        labels_achieved = self._get_prop_achieved(label_diffs)
        for key in labels_achieved:
            avg_label_diffs[f'achieved_{key}'] = labels_achieved[key]

        total_achieved = self._get_total_achieved(
            label_diffs, include=self.include_labels)
        avg_label_diffs['total_achieved'] = total_achieved

        if self.output_trajs:
            avg_label_diffs['trajs'] = trajs

        return avg_label_diffs

    def _get_label_cutoffs(self, env_cls, horizon, n_bound_samples):
        """Gathers N trajectories with a random policy. Records 10% distances
        for each dimension."""
        env = env_cls()
        goals = ListTerminalStateGoalDist(env_cls=env_cls,
                                          horizon=horizon,
                                          policy=RandomPolicy(
                                              action_space=env.action_space),
                                          n_goals=n_bound_samples,
                                          include_terminal=True,
                                          random_horizon=True)

        mins = {}
        maxs = {}

        for goal in goals:
            labels = goal.info.labels
            label_dict = labels._asdict()
            for key in label_dict:
                v = label_dict[key]
                if key in mins:
                    mins[key] = min(v, mins[key])
                    maxs[key] = max(v, maxs[key])
                else:
                    mins[key] = v
                    maxs[key] = v

        cutoffs = {}
        for key in mins:
            # guard against a near-constant label range (cutoff would collapse to ~0)
            range_ = maxs[key] - mins[key]
            if range_ < 0.01:
                range_ = 1
            cutoffs[key] = range_ * 0.1

        print(f'label mins: {mins}')
        print(f'label maxs: {maxs}')
        print(f'label cutoffs: {cutoffs}')

        return cutoffs

    def _get_prop_achieved(self, label_diffs):
        achieved = {}
        for key in label_diffs[0]:
            value = 0
            for di in label_diffs:
                if di[key] < self.cutoffs[key]:
                    value += 1
            value /= len(label_diffs)
            achieved[key] = value

        return achieved

    def _get_total_achieved(self, label_diffs, include=None):
        """Calculates the intersection of achieved labels"""
        achieved = 0
        if include is None:
            include = []
            for key in label_diffs[0]:
                include.append(key)
        for di in label_diffs:
            a = True
            for key in di:
                if key in include and di[key] >= self.cutoffs[key]:
                    a = False
                    break
            if a:
                achieved += 1
        achieved /= len(label_diffs)
        return achieved

    def _get_euclidean_average(self, labels):
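        # Labels named `<entity>_x` / `<entity>_y` are treated as a 2-D position
        # difference; return the mean Euclidean norm over all such entities.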
        included_pos = set()
        total_dist = 0
        for key in labels:
            split_key = key.split('_')
            suffix = split_key[-1]
            entity_name = '_'.join(split_key[:-1])
            if suffix in ('x', 'y') and entity_name not in included_pos:
                x_pos, y_pos = 0, 0
                x_label = f'{entity_name}_x'
                y_label = f'{entity_name}_y'

                if x_label in labels:
                    x_pos = labels[x_label]
                if y_label in labels:
                    y_pos = labels[y_label]

                dist = sqrt(x_pos ** 2 + y_pos ** 2)
                total_dist += dist
                included_pos.add(entity_name)
        if len(included_pos) == 0:
            return 0
        return total_dist / len(included_pos)

    def _control_metric(self, labels):
        total_dist = 0

    def compare_labels(self, traj_labels, task_labels):
        """Returns a dict of the absolute values between label pairs."""
        traj_labels = traj_labels._asdict()
        task_labels = task_labels._asdict()

        label_diffs = {}

        for key in traj_labels:
            label_diffs[key] = abs(
                float(traj_labels[key]) - float(task_labels[key]))

        return label_diffs
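
# Putting the examples together: a hypothetical wiring sketch. The names env_cls,
# horizon, k_dist, policy, greedy_policy, task_dist, eval_tasks and model are
# placeholders assumed to be provided by the surrounding repo.
#
#   evals = [PolicyVideoEval(env_cls, horizon, greedy_policy, eval_tasks, n_trajs=5),
#            LabelCompareEval(env_cls, horizon, greedy_policy, eval_tasks, n_trajs=20)]
#   trainer = BatchTrainGLAMOR(env_cls=env_cls, horizon=horizon, k_dist=k_dist,
#                              n_interactions=int(1e7), replay_ratio=8,
#                              policy=policy, task_dist=task_dist, model=model,
#                              evals=evals, batch_size=256)
#   trainer.train()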