Example #1
def main(args):
    config_path = args.config_path
    if config_path is None:
        config_path = utils.select_run()
    if config_path is None:
        return
    print(config_path)
    cfg = utils.load_config(config_path)

    # Create env
    if args.real:
        real_robot_indices = list(map(int, args.real_robot_indices.split(',')))
        real_cube_indices = list(map(int, args.real_cube_indices.split(',')))
        env = utils.get_env_from_cfg(cfg,
                                     real=True,
                                     real_robot_indices=real_robot_indices,
                                     real_cube_indices=real_cube_indices)
    else:
        env = utils.get_env_from_cfg(cfg, show_gui=True)

    # Create policy
    policy = utils.get_policy_from_cfg(cfg, env.get_robot_group_types())

    # Run policy
    state = env.reset()
    try:
        while True:
            action = policy.step(state)
            state, _, done, _ = env.step(action)
            if done:
                state = env.reset()
    finally:
        env.close()
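
A minimal, hypothetical entry point for this example (not part of the source): the flag names and defaults below are assumed from the attributes that main() reads, i.e. config_path, real, real_robot_indices, and real_cube_indices.

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Flag names and defaults are assumptions inferred from main() above
    parser.add_argument('--config-path')                        # optional; falls back to utils.select_run()
    parser.add_argument('--real', action='store_true')          # drive physical robots instead of the GUI sim
    parser.add_argument('--real-robot-indices', default='0')    # comma-separated robot indices, e.g. '0,2'
    parser.add_argument('--real-cube-indices', default='0,1')   # comma-separated cube indices
    main(parser.parse_args())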
Example #2
def main(args):
    config_path = args.config_path
    if config_path is None:
        config_path = utils.select_run()
    if config_path is None:
        print('Please provide a config path')
        return
    cfg = utils.read_config(config_path)
    env = utils.get_env_from_cfg(cfg, use_gui=True)
    policy = utils.get_policy_from_cfg(cfg, env.get_action_space())
    state = env.reset()
    while True:
        action, _ = policy.step(state)
        state, _, done, _ = env.step(action)
        if done:
            state = env.reset()
Example #3
def run_eval(cfg, num_episodes=20):
    random_seed = 0

    # Create env
    env = utils.get_env_from_cfg(cfg,
                                 random_seed=random_seed,
                                 use_egl_renderer=False)

    # Create policy
    policy = utils.get_policy_from_cfg(cfg,
                                       env.get_robot_group_types(),
                                       random_seed=random_seed)

    # Run policy
    data = [[] for _ in range(num_episodes)]
    episode_count = 0
    state = env.reset()
    while True:
        action = policy.step(state)
        state, _, done, info = env.step(action)
        data[episode_count].append({
            'simulation_steps': info['simulation_steps'],
            'cubes': info['total_cubes'],
            'robot_collisions': info['total_robot_collisions'],
        })
        if done:
            episode_count += 1
            print('Completed {}/{} episodes'.format(episode_count,
                                                    num_episodes))
            if episode_count >= num_episodes:
                break
            state = env.reset()
    env.close()

    return data
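
run_eval() returns one list of per-step dicts per episode. Below is a small sketch (not from the source) of how that return value might be summarized, assuming the recorded values are running totals within each episode.

def summarize_eval(data):
    # Use the last recorded step of each episode, since 'cubes' and
    # 'robot_collisions' above are logged from cumulative totals
    finals = [episode[-1] for episode in data if episode]
    num_episodes = len(finals)
    mean_cubes = sum(step['cubes'] for step in finals) / num_episodes
    mean_collisions = sum(step['robot_collisions'] for step in finals) / num_episodes
    print('Episodes: {}'.format(num_episodes))
    print('Mean cubes per episode: {:.2f}'.format(mean_cubes))
    print('Mean robot collisions per episode: {:.2f}'.format(mean_collisions))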
Example #4
def _run_eval(cfg, num_episodes=20):
    env = utils.get_env_from_cfg(cfg, random_seed=0)
    policy = utils.get_policy_from_cfg(cfg,
                                       env.get_action_space(),
                                       random_seed=0)
    data = [[] for _ in range(num_episodes)]
    episode_count = 0
    state = env.reset()
    while True:
        action, _ = policy.step(state)
        state, _, done, info = env.step(action)
        data[episode_count].append({
            'distance': info['cumulative_distance'],
            'cubes': info['cumulative_cubes']
        })
        if done:
            state = env.reset()
            episode_count += 1
            print('Completed {}/{} episodes'.format(episode_count,
                                                    num_episodes))
            if episode_count >= num_episodes:
                break
    return data
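
A hypothetical driver for _run_eval(): utils.load_config (seen in Example #1), the config path, and the output filename are assumptions, as is the assumption that the logged info values are plain Python numbers that JSON can serialize.

import json

if __name__ == '__main__':
    cfg = utils.load_config('config.yml')  # assumed config path
    results = _run_eval(cfg, num_episodes=5)
    with open('eval_results.json', 'w') as f:
        json.dump(results, f, indent=2)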
Example #5
def main(cfg):
    # Set up logging and checkpointing
    log_dir = Path(cfg.log_dir)
    checkpoint_dir = Path(cfg.checkpoint_dir)
    print('log_dir: {}'.format(log_dir))
    print('checkpoint_dir: {}'.format(checkpoint_dir))

    # Create environment
    kwargs = {}
    if cfg.show_gui:
        import matplotlib  # pylint: disable=import-outside-toplevel
        matplotlib.use('agg')
    if cfg.use_predicted_intention:  # Enable ground truth intention map during training only
        kwargs['use_intention_map'] = True
        kwargs['intention_map_encoding'] = 'ramp'
    env = utils.get_env_from_cfg(cfg, **kwargs)

    robot_group_types = env.get_robot_group_types()
    num_robot_groups = len(robot_group_types)

    # Policy
    policy = utils.get_policy_from_cfg(cfg, robot_group_types, train=True)

    # Optimizers
    optimizers = []
    for i in range(num_robot_groups):
        optimizers.append(
            optim.SGD(policy.policy_nets[i].parameters(),
                      lr=cfg.learning_rate,
                      momentum=0.9,
                      weight_decay=cfg.weight_decay))
    if cfg.use_predicted_intention:
        optimizers_intention = []
        for i in range(num_robot_groups):
            optimizers_intention.append(
                optim.SGD(policy.intention_nets[i].parameters(),
                          lr=cfg.learning_rate,
                          momentum=0.9,
                          weight_decay=cfg.weight_decay))

    # Replay buffers
    replay_buffers = []
    for _ in range(num_robot_groups):
        replay_buffers.append(ReplayBuffer(cfg.replay_buffer_size))

    # Resume if applicable
    start_timestep = 0
    episode = 0
    if cfg.checkpoint_path is not None:
        checkpoint = torch.load(cfg.checkpoint_path)
        start_timestep = checkpoint['timestep']
        episode = checkpoint['episode']
        for i in range(num_robot_groups):
            optimizers[i].load_state_dict(checkpoint['optimizers'][i])
            replay_buffers[i] = checkpoint['replay_buffers'][i]
        if cfg.use_predicted_intention:
            for i in range(num_robot_groups):
                optimizers_intention[i].load_state_dict(
                    checkpoint['optimizers_intention'][i])
        print("=> loaded checkpoint '{}' (timestep {})".format(
            cfg.checkpoint_path, start_timestep))

    # Target nets
    target_nets = policy.build_policy_nets()
    for i in range(num_robot_groups):
        target_nets[i].load_state_dict(policy.policy_nets[i].state_dict())
        target_nets[i].eval()

    # Logging
    train_summary_writer = SummaryWriter(log_dir=str(log_dir / 'train'))
    visualization_summary_writer = SummaryWriter(log_dir=str(log_dir /
                                                             'visualization'))
    meters = Meters()

    state = env.reset()
    transition_tracker = TransitionTracker(state)
    learning_starts = np.round(cfg.learning_starts_frac *
                               cfg.total_timesteps).astype(np.uint32)
    total_timesteps_with_warm_up = learning_starts + cfg.total_timesteps
    for timestep in tqdm(range(start_timestep, total_timesteps_with_warm_up),
                         initial=start_timestep,
                         total=total_timesteps_with_warm_up,
                         file=sys.stdout):
        # Select an action for each robot
        exploration_eps = 1 - (1 - cfg.final_exploration) * min(
            1,
            max(0, timestep - learning_starts) /
            (cfg.exploration_frac * cfg.total_timesteps))
        if cfg.use_predicted_intention:
            use_ground_truth_intention = max(
                0, timestep - learning_starts
            ) / cfg.total_timesteps <= cfg.use_predicted_intention_frac
            action = policy.step(
                state,
                exploration_eps=exploration_eps,
                use_ground_truth_intention=use_ground_truth_intention)
        else:
            action = policy.step(state, exploration_eps=exploration_eps)
        transition_tracker.update_action(action)

        # Step the simulation
        state, reward, done, info = env.step(action)

        # Store in buffers
        transitions_per_buffer = transition_tracker.update_step_completed(
            reward, state, done)
        for i, transitions in enumerate(transitions_per_buffer):
            for transition in transitions:
                replay_buffers[i].push(*transition)

        # Reset if episode ended
        if done:
            state = env.reset()
            transition_tracker = TransitionTracker(state)
            episode += 1

        # Train networks
        if timestep >= learning_starts and (timestep + 1) % cfg.train_freq == 0:
            all_train_info = {}
            for i in range(num_robot_groups):
                batch = replay_buffers[i].sample(cfg.batch_size)
                train_info = train(cfg, policy.policy_nets[i], target_nets[i],
                                   optimizers[i], batch,
                                   policy.apply_transform,
                                   cfg.discount_factors[i])

                if cfg.use_predicted_intention:
                    train_info_intention = train_intention(
                        policy.intention_nets[i], optimizers_intention[i],
                        batch, policy.apply_transform)
                    train_info.update(train_info_intention)

                for name, val in train_info.items():
                    all_train_info['{}/robot_group_{:02}'.format(name,
                                                                 i + 1)] = val

        # Update target networks
        if (timestep + 1) % cfg.target_update_freq == 0:
            for i in range(num_robot_groups):
                target_nets[i].load_state_dict(
                    policy.policy_nets[i].state_dict())

        ################################################################################
        # Logging

        # Meters
        if timestep >= learning_starts and (timestep + 1) % cfg.train_freq == 0:
            for name, val in all_train_info.items():
                meters.update(name, val)

        if done:
            for name in meters.get_names():
                train_summary_writer.add_scalar(name, meters.avg(name),
                                                timestep + 1)
            meters.reset()

            train_summary_writer.add_scalar('steps', info['steps'],
                                            timestep + 1)
            train_summary_writer.add_scalar('total_cubes', info['total_cubes'],
                                            timestep + 1)
            train_summary_writer.add_scalar('episodes', episode, timestep + 1)

            for i in range(num_robot_groups):
                for name in [
                        'cumulative_cubes', 'cumulative_distance',
                        'cumulative_reward', 'cumulative_robot_collisions'
                ]:
                    train_summary_writer.add_scalar(
                        '{}/robot_group_{:02}'.format(name, i + 1),
                        np.mean(info[name][i]), timestep + 1)

            # Visualize Q-network outputs
            if timestep >= learning_starts:
                random_state = [[random.choice(replay_buffers[i].buffer).state]
                                for i in range(num_robot_groups)]
                _, info = policy.step(random_state, debug=True)
                for i in range(num_robot_groups):
                    visualization = utils.get_state_output_visualization(
                        random_state[i][0], info['output'][i][0]).transpose(
                            (2, 0, 1))
                    visualization_summary_writer.add_image(
                        'output/robot_group_{:02}'.format(i + 1),
                        visualization, timestep + 1)
                    if cfg.use_predicted_intention:
                        visualization_intention = utils.get_state_output_visualization(
                            random_state[i][0],
                            np.stack((random_state[i][0][:, :, -1],
                                      info['output_intention'][i][0]),
                                     axis=0)  # Ground truth and output
                        ).transpose((2, 0, 1))
                        visualization_summary_writer.add_image(
                            'output_intention/robot_group_{:02}'.format(i + 1),
                            visualization_intention, timestep + 1)

        ################################################################################
        # Checkpointing

        if ((timestep + 1) % cfg.checkpoint_freq == 0
                or timestep + 1 == total_timesteps_with_warm_up):
            if not checkpoint_dir.exists():
                checkpoint_dir.mkdir(parents=True, exist_ok=True)

            # Save policy
            policy_filename = 'policy_{:08d}.pth.tar'.format(timestep + 1)
            policy_path = checkpoint_dir / policy_filename
            policy_checkpoint = {
                'timestep': timestep + 1,
                'state_dicts': [
                    policy.policy_nets[i].state_dict()
                    for i in range(num_robot_groups)
                ],
            }
            if cfg.use_predicted_intention:
                policy_checkpoint['state_dicts_intention'] = [
                    policy.intention_nets[i].state_dict()
                    for i in range(num_robot_groups)
                ]
            torch.save(policy_checkpoint, str(policy_path))

            # Save checkpoint
            checkpoint_filename = 'checkpoint_{:08d}.pth.tar'.format(timestep + 1)
            checkpoint_path = checkpoint_dir / checkpoint_filename
            checkpoint = {
                'timestep': timestep + 1,
                'episode': episode,
                'optimizers':
                [optimizers[i].state_dict() for i in range(num_robot_groups)],
                'replay_buffers':
                [replay_buffers[i] for i in range(num_robot_groups)],
            }
            if cfg.use_predicted_intention:
                checkpoint['optimizers_intention'] = [
                    optimizers_intention[i].state_dict()
                    for i in range(num_robot_groups)
                ]
            torch.save(checkpoint, str(checkpoint_path))

            # Save updated config file
            cfg.policy_path = str(policy_path)
            cfg.checkpoint_path = str(checkpoint_path)
            utils.save_config(log_dir / 'config.yml', cfg)

            # Remove old checkpoint
            checkpoint_paths = list(
                checkpoint_dir.glob('checkpoint_*.pth.tar'))
            checkpoint_paths.remove(checkpoint_path)
            for old_checkpoint_path in checkpoint_paths:
                old_checkpoint_path.unlink()

    env.close()
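
The training loop above relies on a ReplayBuffer with push(*transition), sample(batch_size), and a .buffer attribute holding transitions that expose a .state field. A minimal sketch consistent with that usage follows; the actual class in the source may differ, and the transition field names other than 'state' are assumptions.

import random
from collections import namedtuple

# Field names other than 'state' are assumptions
Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state'))

class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def push(self, *args):
        # Store a transition, overwriting the oldest one once at capacity
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)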
Example #6
def main(args):
    # Connect to aruco server for pose estimates
    try:
        conn = Client(('localhost', 6000), authkey=b'secret password')
    except ConnectionRefusedError:
        print('Could not connect to aruco server for pose estimates')
        return

    # Create action executor for the physical robot
    action_executor = vector_action_executor.VectorActionExecutor(
        args.robot_index)

    # Create env
    config_path = args.config_path
    if config_path is None:
        config_path = utils.select_run()
    if config_path is None:
        print('Please provide a config path')
        return
    cfg = utils.read_config(config_path)
    kwargs = {'num_cubes': args.num_cubes}
    if args.debug:
        kwargs['use_gui'] = True
    cube_indices = list(range(args.num_cubes))
    env = utils.get_env_from_cfg(cfg,
                                 physical_env=True,
                                 robot_index=action_executor.robot_index,
                                 cube_indices=cube_indices,
                                 **kwargs)
    env.reset()

    # Create policy
    policy = utils.get_policy_from_cfg(cfg, env.get_action_space())

    # Debug visualization
    if args.debug:
        cv2.namedWindow('out', cv2.WINDOW_NORMAL)
        #cv2.resizeWindow('out', 960, 480)

    try:
        while True:
            # Get new pose estimates
            poses = None
            while conn.poll():  # ensure up-to-date data
                poses = conn.recv()
            if poses is None:
                continue

            # Update poses in the simulation
            env.update_poses(poses)

            # Get new action
            state = env.get_state()
            if action_executor.is_action_completed() and args.debug:
                action, info = policy.step(state, debug=True)
                # Visualize
                assert not cfg.use_steering_commands
                output = info['output'].cpu().numpy()
                cv2.imshow(
                    'out',
                    utils.get_state_and_output_visualization(
                        state, output)[:, :, ::-1])
                cv2.waitKey(1)
            else:
                action, _ = policy.step(state)

            # Run selected action through simulation
            try_action_result = env.try_action(action)

            if action_executor.is_action_completed():
                # Update action executor
                action_executor.update_try_action_result(try_action_result)

            # Run action executor
            action_executor.step(poses)

    finally:
        action_executor.disconnect()
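
This example connects to a pose server on localhost:6000 via multiprocessing.connection.Client. Below is a minimal sketch of the server side it expects, using the standard library's Listener; get_marker_poses() and the update rate are placeholders, and the actual pose format is not specified by the source.

from multiprocessing.connection import Listener
import time

def get_marker_poses():
    # Placeholder: a real server would return ArUco marker pose estimates
    # from the overhead camera here
    return {}

def serve_poses(address=('localhost', 6000), authkey=b'secret password'):
    listener = Listener(address, authkey=authkey)
    conn = listener.accept()
    try:
        while True:
            conn.send(get_marker_poses())
            time.sleep(0.1)  # assumed update rate
    finally:
        conn.close()
        listener.close()

if __name__ == '__main__':
    serve_poses()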