예제 #1
0
    def __init__(self, dataset_dir, max_demos, num_cameras, num_frames,
                 channels, action_space, steps_action, num_signals, num_skills,
                 image_augmentation, **unused_kwargs):
        """Assemble the frame, action and signal views of a demo dataset.

        Args:
            dataset_dir: dataset location handed to ``Frames``, ``Actions``
                and ``Signals``.
            max_demos: demo cap forwarded to all three views.
            num_cameras: stored as ``self._num_cameras``.
            num_frames: stored as ``self._num_frames``.
            channels: iterable of channel names. It is copied into a new list,
                and ``'mask'`` is appended when ``image_augmentation`` is truthy.
            action_space: forwarded to ``Actions`` and stored.
            steps_action: stored as ``self._steps_action``.
            num_signals: stored as ``self._num_signals``.
            num_skills: skill-based action querying is enabled when > 1.
            image_augmentation: forwarded to ``Frames``. A truthy value also
                adds the ``'mask'`` channel.
            **unused_kwargs: accepted and ignored.

        Raises:
            AssertionError: if the three views report different lengths.
            ValueError: if skills are enabled but
                ``actions.get_skill(0, skip_undefined=True)`` returns None.
        """
        # Build a fresh list so the caller's sequence is never mutated.
        channels = list(channels) + (['mask'] if image_augmentation else [])

        frames = Frames(path=dataset_dir, channels=channels,
                        max_demos=max_demos, augmentation=image_augmentation)
        frames.keys.set_query_limit('demo')

        actions = Actions(dataset_dir, action_space)
        state_signal_keys = [('state', 'joint_position'),
                             ('state', 'grip_velocity')]
        signals = Signals(dataset_dir, state_signal_keys)

        self._num_skills = num_skills
        self._get_skills = num_skills > 1
        if not self._get_skills:
            actions.keys.set_query_limit('demo')
        else:
            # The query limit switches to 'skill'. Skill labels are read from
            # the ('state', 'skill') signal.
            actions.keys.set_query_limit('skill')
            skill_labels = Signals(dataset_dir, [('state', 'skill')])
            actions.keys.set_skill_labels(skill_labels)
        for view in (actions, signals):
            view.keys.set_max_demos(max_demos)

        # The three views are indexed together, so their lengths must agree.
        assert len(frames) == len(actions) == len(signals), \
            'Frames length {} Actions length {} Signals length {}'.format(
                len(frames), len(actions), len(signals))

        self._frames, self._actions, self._signals = frames, actions, signals
        self._channels = channels
        self._num_frames = num_frames
        self._num_cameras = num_cameras
        self._num_signals = num_signals
        self._steps_action = steps_action
        self._action_space = action_space
        self._skills_indices = []

        if not self._get_skills:
            return
        # Balanced skill sampling needs the dataset to contain a skill.
        if actions.get_skill(0, skip_undefined=True) is None:
            raise ValueError(
                'Skill sampling is on while the dataset contains no skill.')
        self._init_skills_sampling()
예제 #2
0
 def _get_script_action(self, skill):
     """Return the next scripted action for ``skill`` as an action dict.

     Copied from ``ppo.envs.mime``. When ``skill`` differs from the
     previously scripted skill, a new chain is requested from
     ``self._env.unwrapped.scene.script_subtask(skill)``.

     The result starts as ``Actions.get_dict_null_action(self.action_keys)``
     and is updated with the next scripted step, filtered to
     ``self.action_keys``. ``self._need_master_action`` becomes True when the
     chain yields nothing, and False otherwise.
     """
     if self._prev_script != skill:
         self._prev_script = skill
         scene = self._env.unwrapped.scene
         self._prev_action_chain = scene.script_subtask(skill)
     # NOTE(review): a new chain is built from the stored scripts on every
     # call. This only advances through the script if those are stateful
     # iterators (e.g. a list of generators). Confirm against script_subtask.
     steps = itertools.chain(*self._prev_action_chain)
     action_applied = Actions.get_dict_null_action(self.action_keys)
     next_step = next(steps, None)
     self._need_master_action = next_step is None
     if not self._need_master_action:
         action_applied.update(
             Actions.filter_action(next_step, self.action_keys))
     return action_applied
예제 #3
0
파일: regression.py 프로젝트: wx-b/rlbc
    def __init__(self,
                 archi,
                 mode,
                 action_space,
                 path=None,
                 input_type=('depth', ),
                 signal_keys=None,
                 signal_lengths=None,
                 statistics=None,
                 device='cuda',
                 **kwargs):
        """Build a ResNet regression network.

        Args:
            archi: architecture name forwarded to ``make_resnet``.
            mode: mode forwarded to ``make_resnet``.
            action_space: converted into ``self.action_keys`` and
                ``self.dim_action`` by ``Actions.action_space_to_keys``. The
                action dimension is also the network output size.
            path: accepted but not used here.
            input_type: channel type(s) used to look up the input size in
                ``CHANNELS``. A bare string is treated as one type.
            signal_keys: optional sequence of key tuples. Only the last element
                of each tuple is kept. ``None`` means no signals.
            signal_lengths: stored unchanged.
            statistics: stored under both ``'train'`` and ``'gt'``.
            device: device the module is moved to.
            **kwargs: accepted and ignored.
        """
        super(Regression, self).__init__()

        # attributes of MetaNetwork
        self.archi = archi
        self.input_type = input_type
        self.action_keys, self.dim_action = Actions.action_space_to_keys(
            action_space)
        # tuple('depth') would split a bare string into characters, so wrap a
        # single string into a one-element tuple for the lookup.
        if isinstance(input_type, str):
            channel_key = (input_type, )
        else:
            channel_key = tuple(input_type)
        self.input_dim = CHANNELS[channel_key]
        self.output_dim = self.dim_action
        self.statistics = {'train': statistics, 'gt': statistics}

        # signal_keys defaults to None; treat that as "no signals" instead of
        # failing on iteration.
        if signal_keys is None:
            signal_keys = ()
        self.signal_keys = [signal_key[-1] for signal_key in signal_keys]
        self.signal_lengths = signal_lengths

        self.args = self._get_args()
        self.net = make_resnet(
            archi, mode, self.input_dim, output_dim=self.output_dim)
        self.to(device)
예제 #4
0
파일: single.py 프로젝트: wx-b/rlbc
 def get_script_action(self, skill):
     """Return the next expert-script action for ``skill`` as an action dict.

     When ``skill.item()`` differs from ``self.prev_script``, a new chain is
     requested from ``self.env.unwrapped.scene.script_subtask(skill)``. The
     result starts as ``Actions.get_dict_null_action(self.action_keys)`` and is
     updated with the next scripted step, filtered to ``self.action_keys``.

     ``self.need_master_action`` is True when the script yields nothing.
     However, if ``self.skills_timescales`` is set, it is instead True once
     ``self.step_counter_after_new_action`` reaches that skill's timescale.
     """
     skill_id = skill.item()
     if self.prev_script != skill_id:
         self.prev_script = skill_id
         scene = self.env.unwrapped.scene
         self.prev_action_chain = scene.script_subtask(skill)
     steps = itertools.chain(*self.prev_action_chain)
     action_applied = Actions.get_dict_null_action(self.action_keys)
     next_step = next(steps, None)
     script_done = next_step is None
     if script_done and self.render:
         print('env {:02d} needs a new master action (ts = {})'.format(
             self.env_idx, self.step_counter))
     self.need_master_action = script_done
     if not script_done:
         action_applied.update(
             Actions.filter_action(next_step, self.action_keys))
     if self.skills_timescales is not None:
         # A fixed per-skill timescale overrides the script-exhaustion signal.
         timescale = self.skills_timescales[skill]
         self.need_master_action = (
             self.step_counter_after_new_action >= timescale)
     return action_applied
예제 #5
0
    def __init__(self,
                 path,
                 epoch,
                 max_steps,
                 device='cpu',
                 env=None,
                 real_robot_mode=False,
                 timescales_list=None,
                 **kwargs):
        """Load a trained RL agent and prepare its runtime state.

        Args:
            path, epoch, max_steps, device: forwarded to the parent
                constructor. ``device`` also places the skill-history tensor.
            env: stored as ``self._env``.
            real_robot_mode: when truthy, augmentation is set to ``''``.
                Otherwise the trained ``augmentation`` setting is used.
            timescales_list: optional explicit list of per-skill timescales.
                When omitted, timescales come from ``rl_args.timescale``,
                either used as-is if it is a list, or repeated
                ``rl_args.num_skills`` times if it is an int.
            **kwargs: accepted and ignored.
        """
        super(RLAgent, self).__init__(path, epoch, max_steps, device)
        self.model, self.rl_args, self.obs_running_stats = self._load_model()
        self.model.args = self.rl_args.bc_args
        self.real_robot_mode = real_robot_mode
        self.set_augmentation(
            '' if self.real_robot_mode else self.rl_args.augmentation)
        self._max_steps = max_steps

        # per-skill timescales
        if timescales_list is not None:
            assert isinstance(timescales_list, list)
            self._skills_timescales = timescales_list
        elif isinstance(self.rl_args.timescale, list):
            self._skills_timescales = self.rl_args.timescale
        else:
            assert isinstance(self.rl_args.timescale, int)
            self._skills_timescales = [
                self.rl_args.timescale for _ in range(self.rl_args.num_skills)
            ]

        # skill history ("action memory")
        self._action_memory = self.rl_args.action_memory
        self._last_skills = None
        if self._action_memory > 0:
            # Filled with -1, presumably meaning "no skill yet"; confirm.
            history = -torch.ones(self._action_memory).float()
            self._last_skills = {0: history.to(torch.device(device))}

        # full state specific stuff
        self._env = env
        bc_action_space = self.rl_args.bc_args['action_space']
        self.action_keys = Actions.action_space_to_keys(bc_action_space)[0]

        self.reset()
예제 #6
0
파일: model.py 프로젝트: wx-b/rlbc
 def get_worker_action(self, master_action, obs_dict):
     """Run the worker (skill) network and return one action dict per env.

     Args:
         master_action: tensor indexed by env index. ``master_action[env_idx]``
             is used as that env's skill (shape assumed; confirm upstream).
         obs_dict: per-env observations converted with
             ``misc.dict_to_tensor``.

     Returns:
         ``(action_dicts_dict, env_idxs)``. ``action_dicts_dict`` maps each
         env index to the dict produced by ``Actions.tensor_to_dict``, plus a
         ``'skill'`` entry holding ``master_action[env_idx]`` as a numpy array.
         ``env_idxs`` is as returned by ``misc.tensor_to_dict``.
     """
     obs_tensor, env_idxs = misc.dict_to_tensor(obs_dict)
     # Select the master actions of the envs in obs_dict, in env_idxs order.
     # Presumably this matches the row order of obs_tensor.
     master_action_filtered = torch.stack(
         [master_action[env_idx] for env_idx in env_idxs])
     action_tensor = self.base(
         obs_tensor, None, master_action=master_action_filtered)
     action_tensors_dict, env_idxs = misc.tensor_to_dict(
         action_tensor, env_idxs)
     action_dicts_dict = {}
     for env_idx, env_action_tensor in action_tensors_dict.items():
         action_dict = Actions.tensor_to_dict(
             env_action_tensor.cpu().numpy(), self.action_keys,
             self.statistics)
         action_dict['skill'] = master_action[env_idx].cpu().numpy()
         action_dicts_dict[env_idx] = action_dict
     return action_dicts_dict, env_idxs
예제 #7
0
파일: base.py 프로젝트: wx-b/rlbc
    def __init__(self,
                 archi,
                 mode,
                 num_frames,
                 action_space,
                 steps_action,
                 lam_grip=0.1,
                 input_type=('depth', ),
                 env_name='',
                 image_augmentation='',
                 statistics=None,
                 device='cuda',
                 network_extra_args=None,
                 **unused_kwargs):
        """Build the policy network and record its configuration.

        Args:
            archi: architecture name forwarded to ``make_resnet``.
            mode: mode forwarded to ``make_resnet``.
            num_frames: stacked frames. The input size is
                ``CHANNELS[input_type] * num_frames``.
            action_space: converted into ``self.action_keys`` and
                ``self.dim_action`` by ``Actions.action_space_to_keys``.
            steps_action: sequence of predicted steps. Only its length is
                stored.
            lam_grip: stored as ``self.lam_grip``.
            input_type: channel type(s). A bare string is treated as one type.
                The types are sorted into a tuple that must be a key of
                ``CHANNELS``.
            env_name, image_augmentation: stored unchanged.
            statistics: stored, defaulting to ``{}`` when None.
            device: device the module is moved to.
            network_extra_args: extra keyword arguments for ``make_resnet``.
            **unused_kwargs: accepted and ignored.

        Raises:
            AssertionError: if the normalized ``input_type`` is not in
                ``CHANNELS``.
        """
        super(MetaPolicy, self).__init__()

        self.num_frames = num_frames
        self.action_space = action_space
        self.steps_action = len(steps_action)
        self.action_keys, self.dim_action = Actions.action_space_to_keys(
            action_space)
        # Each step predicts the action plus one extra value.
        self.dim_prediction = (self.dim_action + 1) * len(steps_action)
        self.dim_gt = self.dim_action * len(steps_action)
        self.lam_grip = lam_grip
        # sorted('depth') would split a bare string into characters, so wrap
        # it first. Sorting makes the CHANNELS key order-independent.
        if isinstance(input_type, str):
            input_type = (input_type, )
        input_type = tuple(sorted(input_type))
        assert input_type in CHANNELS, \
            'Unknown input type {}'.format(input_type)
        self.input_dim = CHANNELS[input_type] * num_frames
        self.output_dim = (self.dim_action + 1) * self.steps_action

        # attributes of MetaNetwork
        self.archi = archi
        self.mode = mode
        self.input_type = input_type
        self.env_name = env_name
        self.image_augmentation = image_augmentation
        self.statistics = statistics if statistics is not None else {}
        if network_extra_args is None:
            network_extra_args = {}

        self.net = make_resnet(archi,
                               mode,
                               self.input_dim,
                               output_dim=self.output_dim,
                               **network_extra_args)
        self.to(device)
        self.args = self._get_args()
예제 #8
0
파일: load.py 프로젝트: wx-b/rlbc
def bc_model(args, device):
    """Resolve behavior-cloning (BC) settings and optionally load a checkpoint.

    When ``args.bc_model_name`` is falsy, no checkpoint is loaded. Instead,
    ``args.bc_args`` is set to a default dict: an image-based one when
    ``'Cam'`` appears in ``args.env_name``, a full-state one otherwise.

    When a name is given, the file
    ``$RLBC_MODELS/<bc_model_name>/model_<bc_model_epoch>.pth`` is loaded with
    ``torch.load``. Its ``'args'`` entry becomes ``args.bc_args``.

    Args:
        args: namespace; ``args.bc_args`` is set in place.
        device: object with a ``.type`` attribute. Only consulted when loading
            a checkpoint.

    Returns:
        ``(args, model, statistics)``. ``model`` and ``statistics`` are the
        checkpoint's ``'model'`` and ``'statistics'`` entries, or ``None``
        when no checkpoint was loaded.

    Raises:
        AssertionError: if a checkpoint name is given without an epoch, or if
            the full-state path has no ``args.mime_action_space``.
    """
    if not args.bc_model_name:
        if 'Cam' in args.env_name:
            default_bc_args = {
                'archi': 'resnet_18',
                'mode': 'features',
                'input_dim': 3,
                'num_frames': 3,
                'steps_action': 4,
                'action_space': 'tool_lin',
                'dim_action': 4,
                'features_dim': 512,
                'env_name': args.env_name,
                'input_type': 'depth',
            }
            print('did not load a BC checkpoint, using default BC args: {}'
                  .format(default_bc_args))
        else:
            assert args.mime_action_space is not None
            action_dim = Actions.action_space_to_keys(
                args.mime_action_space)[1]
            default_bc_args = {
                'action_space': args.mime_action_space,
                'dim_action': action_dim,
                'num_frames': 1,
                'env_name': args.env_name,
                'input_type': 'full_state',
            }
            print('Using a full state env with BC args: {}'
                  .format(default_bc_args))
        args.bc_args = default_bc_args
        return args, None, None

    assert args.bc_model_epoch is not None, 'bc model epoch is not specified'
    checkpoint_name = 'model_{}.pth'.format(args.bc_model_epoch)
    bc_model_path = os.path.join(os.environ['RLBC_MODELS'],
                                 args.bc_model_name, checkpoint_name)
    load_kwargs = {}
    if device.type == 'cpu':
        # Keep every deserialized storage on the CPU.
        load_kwargs['map_location'] = lambda storage, loc: storage
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints.
    loaded_dict = torch.load(bc_model_path, **load_kwargs)
    args.bc_args = loaded_dict['args']
    print('loaded the BC checkpoint from {}'.format(bc_model_path))
    return args, loaded_dict['model'], loaded_dict['statistics']
예제 #9
0
파일: single.py 프로젝트: wx-b/rlbc
    def parse_args(self, args):
        """Copy environment settings from the ``args`` mapping onto ``self``.

        Reads ``env_idx``, ``env_name``, ``max_length``, ``render``,
        ``bc_args['action_space']``, ``input_type``, ``augmentation`` and
        ``use_expert_scripts``. It also reads ``timescale``/``num_skills``
        when expert scripts are off, and optionally ``gifdir``.

        Raises:
            NotImplementedError: if ``args['input_type']`` is neither
                ``'depth'`` nor ``'rgbd'``.
        """
        # parse the args
        self.env_idx = args['env_idx']
        self.env_name = args['env_name']
        self.max_length = args['max_length']
        # Only the env with index 0 may render.
        self.render = args['render'] and self.env_idx == 0
        self.action_keys = Actions.action_space_to_keys(
            args['bc_args']['action_space'])[0]
        if args['input_type'] == 'depth':
            self.channels = ('depth', )
        elif args['input_type'] == 'rgbd':
            self.channels = ('depth', 'rgb')
        else:
            raise NotImplementedError('Unknown input type = {}'.format(
                args['input_type']))
        self.augmentation = None
        self.augmentation_str = args['augmentation']
        self.use_expert_scripts = args['use_expert_scripts']
        if self.use_expert_scripts:
            self.skills_timescales = None
        elif isinstance(args['timescale'], list):
            # timescales for skills (rlbc setup only)
            self.skills_timescales = args['timescale']
        else:
            assert isinstance(args['timescale'], int)
            # The same timescale is used for every skill.
            self.skills_timescales = [args['timescale']] * args['num_skills']

        # gifs writing
        self.gifdir = None
        if 'gifdir' in args:
            self.gif_counter = 0
            # os.path.join always returns a non-empty string, so the "enabled"
            # check must look at the raw argument. A falsy gifdir (None or '')
            # disables gif writing.
            if args['gifdir']:
                self.gifdir = os.path.join(args['gifdir'],
                                           '{:02d}'.format(self.env_idx))
                self.obs_history = {}
예제 #10
0
파일: model.py 프로젝트: wx-b/rlbc
    def __init__(self, obs_shape, action_space, bc_model, bc_statistics, **base_kwargs):
        """Build the master policy: a feature base plus an action distribution.

        Args:
            obs_shape: observation shape. Rank 3 selects ``ResnetBase`` built
                on ``bc_model``; rank 1 selects ``MLPBase``.
            action_space: an object whose class is named ``"Discrete"`` (uses
                ``.n``) or ``"Box"`` (uses ``.shape[0]``).
            bc_model: forwarded to ``ResnetBase`` for rank-3 observations.
            bc_statistics: stored as ``self.statistics``.
            **base_kwargs: forwarded to the base. Must contain
                ``bc_args['action_space']``.

        Raises:
            NotImplementedError: for any other observation rank or action
                space class.
        """
        super(MasterPolicy, self).__init__()

        bc_action_space = base_kwargs['bc_args']['action_space']
        self.action_keys = Actions.action_space_to_keys(bc_action_space)[0]
        self.statistics = bc_statistics

        obs_rank = len(obs_shape)
        if obs_rank == 3:
            self.base = ResnetBase(bc_model, **base_kwargs)
            # Keep the ResNet in eval mode so skills behave as in BC training.
            self.base.resnet.eval()
        elif obs_rank == 1:
            self.base = MLPBase(obs_shape[0], **base_kwargs)
        else:
            raise NotImplementedError

        space_kind = action_space.__class__.__name__
        if space_kind == "Discrete":
            self.dist = Categorical(self.base.output_size, action_space.n)
        elif space_kind == "Box":
            self.dist = DiagGaussian(self.base.output_size,
                                     action_space.shape[0])
        else:
            raise NotImplementedError
예제 #11
0
파일: base.py 프로젝트: wx-b/rlbc
 def get_dict_action(self, obs, signals=None, skill=None):
     """Predict an action and return it in the MIME dictionary format.

     Delegates to ``self.get_action(obs, signals, skill)`` and converts the
     resulting tensor with ``Actions.tensor_to_dict`` using
     ``self.action_keys`` and ``self.statistics``.
     """
     predicted = self.get_action(obs, signals, skill)
     return Actions.tensor_to_dict(
         predicted, self.action_keys, self.statistics)