Beispiel #1
0
    def __init__(
        self,
        config: ExperimentConfig,
        output_dir: str,
        loaded_config_src_files: Optional[Dict[str, str]],
        seed: Optional[int] = None,
        mode: str = "train",
        deterministic_cudnn: bool = False,
        deterministic_agents: bool = False,
        mp_ctx: Optional[BaseContext] = None,
        multiprocessing_start_method: str = "default",
        extra_tag: str = "",
        disable_tensorboard: bool = False,
        disable_config_saving: bool = False,
    ):
        """Record the run settings, seed the RNGs and create the shared queues.

        Args:
            config: Experiment configuration, stored as ``self.config``.
            output_dir: Stored as ``self.output_dir``.
            loaded_config_src_files: Stored as ``self.loaded_config_src_files``.
            seed: Seed handed to ``set_seed``. When ``None`` a value is drawn
                with ``random.randint(0, 2**31 - 1)``.
            mode: Must be ``"train"`` or ``"test"`` (checked with an assert).
            deterministic_cudnn: When true, ``set_deterministic_cudnn()`` is
                called before seeding.
            deterministic_agents: Stored as ``self.deterministic_agents``.
            mp_ctx: Forwarded to ``self.init_context`` together with the
                resolved start method.
            multiprocessing_start_method: ``"default"`` resolves to
                ``"forkserver"`` when CUDA is available and to ``"spawn"``
                otherwise; any other value is forwarded unchanged.
            extra_tag: Stored as ``self.extra_tag``.
            disable_tensorboard: Stored as ``self.disable_tensorboard``.
            disable_config_saving: Stored as ``self.disable_config_saving``.
        """
        self.config = config
        self.output_dir = output_dir
        self.loaded_config_src_files = loaded_config_src_files

        if seed is None:
            seed = random.randint(0, 2**31 - 1)
        self.seed = seed
        self.deterministic_cudnn = deterministic_cudnn

        start_method = multiprocessing_start_method
        if start_method == "default":
            # Spawn seems to play nicer with cpus and debugging, so it is the
            # fallback whenever CUDA is unavailable.
            start_method = (
                "forkserver" if torch.cuda.is_available() else "spawn"
            )
        self.mp_ctx = self.init_context(mp_ctx, start_method)

        self.extra_tag = extra_tag
        self.mode = mode
        self.visualizer: Optional[VizSuite] = None
        self.deterministic_agents = deterministic_agents
        self.disable_tensorboard = disable_tensorboard
        self.disable_config_saving = disable_config_saving

        assert self.mode in (
            "train",
            "test",
        ), "Only 'train' and 'test' modes supported in runner"

        if self.deterministic_cudnn:
            set_deterministic_cudnn()

        set_seed(self.seed)

        # Both queues come from the same multiprocessing context as the
        # worker processes.
        results_queue = self.mp_ctx.Queue()
        checkpoints_queue = self.mp_ctx.Queue()
        self.queues = {
            "results": results_queue,
            "checkpoints": checkpoints_queue,
        }

        self.processes: Dict[
            str, List[Union[BaseProcess, mp.Process]]
        ] = defaultdict(list)

        self.current_checkpoint = None

        # Timestamp of construction, in local time.
        started_at = time.localtime(time.time())
        self.local_start_time_str = time.strftime(
            "%Y-%m-%d_%H-%M-%S", started_at
        )

        self._is_closed: bool = False
 def set_seed(self, seed: int):
     """Remember ``seed`` on the instance and, unless it is ``None``, pass
     it on to the module-level ``set_seed`` helper."""
     self.seed = seed
     if seed is None:
         # Nothing to seed; only the attribute is updated.
         return
     set_seed(seed)
 def set_seed(self, seed: int) -> None:
     """Seed via the module-level ``set_seed`` helper, build a seeded numpy
     generator with ``seeding.np_random`` and remember the seed."""
     set_seed(seed)
     # `seeding.np_random` returns a pair; only the first element is kept.
     seeded_generator, _ = seeding.np_random(seed)
     self.np_seeded_random_gen = seeded_generator
     self.seed = seed
Beispiel #4
0
    def start_train(
        self,
        checkpoint: Optional[str] = None,
        restart_pipeline: bool = False,
        max_sampler_processes_per_worker: Optional[int] = None,
    ):
        """Start the training processes and, if enabled, one validation process.

        One ``train_loop`` process is started per device returned by
        ``self.worker_devices("train")``. When ``self.running_validation`` is
        truthy a single ``valid_loop`` process is started as well.

        Args:
            checkpoint: Forwarded to every ``train_loop`` process.
            restart_pipeline: Forwarded to every ``train_loop`` process.
            max_sampler_processes_per_worker: Forwarded to the train and
                validation processes.

        Returns:
            ``self.local_start_time_str``.
        """
        if not self.disable_config_saving:
            self.save_project_state()

        devices = self.worker_devices("train")
        num_workers = len(devices)

        # Be extra careful to ensure that all models start with the same
        # initializations: re-seed, build one model, and hand its weights
        # to every worker.
        set_seed(self.seed)
        machine_params = MachineParams.instance_from(
            self.config.machine_params(self.mode)
        )
        reference_model = self.config.create_model(
            sensor_preprocessor_graph=machine_params.sensor_preprocessor_graph
        )
        initial_model_state_dict = reference_model.state_dict()

        # A port is only looked up when there is more than one worker.
        distributed_port = find_free_port() if num_workers > 1 else 0

        for worker_id in range(num_workers):
            train_kwargs = {
                "id": worker_id,
                "checkpoint": checkpoint,
                "restart_pipeline": restart_pipeline,
                "experiment_name": self.experiment_name,
                "config": self.config,
                "results_queue": self.queues["results"],
                "checkpoints_queue": (
                    self.queues["checkpoints"]
                    if self.running_validation
                    else None
                ),
                "checkpoints_dir": self.checkpoint_dir(),
                "seed": self.seed,
                "deterministic_cudnn": self.deterministic_cudnn,
                "mp_ctx": self.mp_ctx,
                "num_workers": num_workers,
                "device": devices[worker_id],
                "distributed_port": distributed_port,
                "max_sampler_processes_per_worker": max_sampler_processes_per_worker,
                "initial_model_state_dict": initial_model_state_dict,
            }
            train: BaseProcess = self.mp_ctx.Process(
                target=self.train_loop,
                kwargs=train_kwargs,
            )
            train.start()
            self.processes["train"].append(train)

        num_train_started = len(self.processes["train"])
        get_logger().info(f"Started {num_train_started} train processes")

        # Validation
        if not self.running_validation:
            get_logger().info(
                "No processes allocated to validation, no validation will be run."
            )
        else:
            device = self.worker_devices("valid")[0]
            self.init_visualizer("valid")
            valid_kwargs = {
                "config": self.config,
                "results_queue": self.queues["results"],
                "checkpoints_queue": self.queues["checkpoints"],
                # TODO allow same order for randomly sampled tasks? Is this any useful anyway?
                "seed": 12345,
                "deterministic_cudnn": self.deterministic_cudnn,
                "deterministic_agents": self.deterministic_agents,
                "mp_ctx": self.mp_ctx,
                "device": device,
                "max_sampler_processes_per_worker": max_sampler_processes_per_worker,
            }
            valid: BaseProcess = self.mp_ctx.Process(
                target=self.valid_loop,
                args=(0,),
                kwargs=valid_kwargs,
            )
            valid.start()
            self.processes["valid"].append(valid)

            num_valid_started = len(self.processes["valid"])
            get_logger().info(f"Started {num_valid_started} valid processes")

        self.log(self.local_start_time_str, num_workers)

        return self.local_start_time_str
Beispiel #5
0
    def start_train(
        self,
        checkpoint: Optional[str] = None,
        restart_pipeline: bool = False,
        max_sampler_processes_per_worker: Optional[int] = None,
    ):
        """Start the training processes and, if enabled, one validation process.

        One ``train_loop`` process is started per device returned by
        ``self.worker_devices(TRAIN_MODE_STR)``. Each worker normally receives
        the initial model ``state_dict``. If starting a process fails with
        ``ValueError("too many fds")``, the state dict is replaced by its md5
        hash (``md5_hash_of_state_dict``) for that worker and all later ones.

        Args:
            checkpoint: Forwarded to every ``train_loop`` process.
            restart_pipeline: Forwarded to every ``train_loop`` process.
            max_sampler_processes_per_worker: Forwarded to the train and
                validation processes.

        Returns:
            ``self.local_start_time_str``.

        Raises:
            ValueError: Re-raised unchanged when starting a train process
                fails for any reason other than ``"too many fds"``.
        """
        self._initialize_start_train_or_start_test()

        if not self.disable_config_saving:
            self.save_project_state()

        devices = self.worker_devices(TRAIN_MODE_STR)
        num_workers = len(devices)

        # Be extra careful to ensure that all models start
        # with the same initializations.
        set_seed(self.seed)
        initial_model_state_dict = self.config.create_model(
            sensor_preprocessor_graph=MachineParams.instance_from(
                self.config.machine_params(
                    self.mode)).sensor_preprocessor_graph).state_dict()

        distributed_port = 0
        if num_workers > 1:
            distributed_port = find_free_port()

        # Stays None until passing the full state dict has failed once; from
        # then on the hash is sent instead.
        model_hash = None
        for trainer_it in range(num_workers):
            training_kwargs = dict(
                id=trainer_it,
                checkpoint=checkpoint,
                restart_pipeline=restart_pipeline,
                experiment_name=self.experiment_name,
                config=self.config,
                results_queue=self.queues["results"],
                checkpoints_queue=self.queues["checkpoints"]
                if self.running_validation else None,
                checkpoints_dir=self.checkpoint_dir(),
                seed=self.seed,
                deterministic_cudnn=self.deterministic_cudnn,
                mp_ctx=self.mp_ctx,
                num_workers=num_workers,
                device=devices[trainer_it],
                distributed_port=distributed_port,
                max_sampler_processes_per_worker=
                max_sampler_processes_per_worker,
                initial_model_state_dict=initial_model_state_dict
                if model_hash is None else model_hash,
            )
            train: BaseProcess = self.mp_ctx.Process(
                target=self.train_loop,
                kwargs=training_kwargs,
            )
            try:
                train.start()
            except ValueError as e:
                # If the `initial_model_state_dict` is too large we sometimes
                # run into errors passing it with multiprocessing. In such cases
                # we instead hash the state_dict and confirm, in each engine
                # worker, that this hash equals the model the engine worker
                # instantiates.
                #
                # `e.args` may be empty, so check it before indexing; otherwise
                # an IndexError would hide the original ValueError.
                if not e.args or e.args[0] != "too many fds":
                    raise

                model_hash = md5_hash_of_state_dict(initial_model_state_dict)
                training_kwargs["initial_model_state_dict"] = model_hash
                train = self.mp_ctx.Process(
                    target=self.train_loop,
                    kwargs=training_kwargs,
                )
                train.start()

            self.processes[TRAIN_MODE_STR].append(train)

        get_logger().info("Started {} train processes".format(
            len(self.processes[TRAIN_MODE_STR])))

        # Validation
        if self.running_validation:
            device = self.worker_devices("valid")[0]
            self.init_visualizer("valid")
            valid: BaseProcess = self.mp_ctx.Process(
                target=self.valid_loop,
                args=(0, ),
                kwargs=dict(
                    config=self.config,
                    results_queue=self.queues["results"],
                    checkpoints_queue=self.queues["checkpoints"],
                    seed=
                    12345,  # TODO allow same order for randomly sampled tasks? Is this any useful anyway?
                    deterministic_cudnn=self.deterministic_cudnn,
                    deterministic_agents=self.deterministic_agents,
                    mp_ctx=self.mp_ctx,
                    device=device,
                    max_sampler_processes_per_worker=
                    max_sampler_processes_per_worker,
                ),
            )
            valid.start()
            self.processes["valid"].append(valid)

            get_logger().info("Started {} valid processes".format(
                len(self.processes["valid"])))
        else:
            get_logger().info(
                "No processes allocated to validation, no validation will be run."
            )

        self.log_and_close(self.local_start_time_str, num_workers)

        return self.local_start_time_str
    def test_binned_and_semantic_mapping(self, tmpdir):
        """Replay seeded walkthrough episodes and compare every observation
        against downloaded reference observations.

        Returns early (after cleanup) when
        ``setup_path_for_use_with_rearrangement_project`` reports failure.
        The reference pickle is downloaded into ``tmpdir``.
        """
        try:
            if not self.setup_path_for_use_with_rearrangement_project():
                return

            from baseline_configs.rearrange_base import RearrangeBaseExperimentConfig
            from baseline_configs.walkthrough.walkthrough_rgb_base import (
                WalkthroughBaseExperimentConfig,
            )
            from rearrange.constants import (
                FOV,
                PICKUPABLE_OBJECTS,
                OPENABLE_OBJECTS,
            )
            from datagen.datagen_utils import get_scenes

            ordered_object_types = sorted(PICKUPABLE_OBJECTS + OPENABLE_OBJECTS)

            # Map settings shared by the binned point-cloud sensor and the
            # semantic map sensor.
            bounds_sensor = ReachableBoundsTHORSensor(margin=1.0)
            shared_map_kwargs = {
                "map_range_sensor": bounds_sensor,
                "vision_range_in_cm": 40 * 5,
                "map_size_in_cm": 1050,
                "resolution_in_cm": 5,
            }

            relative_position_sensor = RelativePositionChangeTHORSensor()
            depth_sensor = DepthSensorThor(
                height=224,
                width=224,
                use_normalization=False,
                uuid="depth",
            )
            binned_map_sensor = BinnedPointCloudMapTHORSensor(
                fov=FOV,
                ego_only=False,
                **shared_map_kwargs,
            )
            semantic_map_sensor = SemanticMapTHORSensor(
                fov=FOV,
                ego_only=False,
                ordered_object_types=ordered_object_types,
                **shared_map_kwargs,
            )
            map_sensors = [
                relative_position_sensor,
                bounds_sensor,
                depth_sensor,
                binned_map_sensor,
                semantic_map_sensor,
            ]
            all_sensors = [
                *WalkthroughBaseExperimentConfig.SENSORS,
                *map_sensors,
            ]

            # Use the first open X display if one can be found, else None.
            try:
                open_x_displays = get_open_x_displays()
            except (AssertionError, IOError):
                open_x_displays = []
            x_display = open_x_displays[0] if len(open_x_displays) > 0 else None

            task_sampler = WalkthroughBaseExperimentConfig.make_sampler_fn(
                stage="train",
                sensors=all_sensors,
                scene_to_allowed_rearrange_inds={
                    scene: [0] for scene in get_scenes("train")
                },
                force_cache_reset=True,
                allowed_scenes=None,
                seed=1,
                x_display=x_display,
                thor_controller_kwargs={
                    **RearrangeBaseExperimentConfig.THOR_CONTROLLER_KWARGS,
                    # "server_class": ai2thor.wsgi_server.WsgiServer,  # Only for debugging
                },
            )

            # Fetch the reference observations.
            targets_path = os.path.join(
                tmpdir, "rearrange_mapping_examples.pkl.gz"
            )
            urllib.request.urlretrieve(
                "https://ai2-prior-allenact-public-test.s3-us-west-2.amazonaws.com/ai2thor_mapping/rearrange_mapping_examples.pkl.gz",
                targets_path,
            )
            goal_obs_dict = compress_pickle.load(targets_path)

            def assert_close_to_goal(obs, goal_obs, key_list: List):
                """Walk `goal_obs` and assert the matching arrays in `obs`
                have a mean absolute difference below 1e-4 (NaNs in the goal
                are ignored)."""
                if isinstance(obs, Dict):
                    for key in goal_obs:
                        assert_close_to_goal(
                            obs=obs[key],
                            goal_obs=goal_obs[key],
                            key_list=key_list + [key],
                        )
                    return

                if isinstance(obs, (List, Tuple)):
                    for idx in range(len(goal_obs)):
                        assert_close_to_goal(
                            obs=obs[idx],
                            goal_obs=goal_obs[idx],
                            key_list=key_list + [idx],
                        )
                    return

                # Should be a numpy array at this point
                assert isinstance(obs, np.ndarray) and isinstance(
                    goal_obs, np.ndarray
                ), f"After {key_list}, not numpy arrays, obs={obs}, goal_obs={goal_obs}"

                # Work on float copies so the inputs are not modified.
                obs = 1.0 * obs
                goal_obs = 1.0 * goal_obs

                nan_in_goal = np.isnan(goal_obs)
                obs[nan_in_goal] = 0.0
                goal_obs[nan_in_goal] = 0.0

                mean_abs_diff = np.abs(obs - goal_obs).mean()
                assert (
                    mean_abs_diff < 1e-4
                ), f"Difference of {mean_abs_diff} at {key_list}."

            # "done" is listed once and every movement action three times.
            action_pool = 3 * [
                "move_ahead",
                "rotate_right",
                "rotate_left",
                "look_up",
                "look_down",
            ] + ["done"]

            observations_dict = defaultdict(list)
            # The reference data is indexed by the same episode number `i`.
            for i in range(5):
                set_seed(i)
                task = task_sampler.next_task()

                obs_list = observations_dict[i]
                obs_list.append(task.get_observations())
                k = 0
                assert_close_to_goal(
                    obs=obs_list[0],
                    goal_obs=goal_obs_dict[i][0],
                    key_list=[i, k],
                )
                while not task.is_done():
                    action_names = task.action_names()
                    chosen_action = random.choice(action_pool)
                    obs = task.step(
                        action=action_names.index(chosen_action)
                    ).observation
                    k += 1
                    obs_list.append(obs)
                    assert_close_to_goal(
                        obs=obs,
                        goal_obs=goal_obs_dict[i][task.num_steps_taken()],
                        key_list=[i, k],
                    )

                    # Free space metric map in RGB using pointclouds coming from depth images. This
                    # is built iteratively after every step.
                    # R - is used to encode points at a height < 0.02m (i.e. the floor)
                    # G - is used to encode points at a height between 0.02m and 2m, i.e. objects the agent would run into
                    # B - is used to encode points higher than 2m, i.e. ceiling

                    # Uncomment if you wish to visualize the observations:
                    # import matplotlib.pyplot as plt
                    # plt.imshow(
                    #     np.flip(255 * (obs["binned_pc_map"]["map"] > 0), 0)
                    # )  # np.flip because we expect "up" to be -row
                    # plt.title("Free space map")
                    # plt.show()
                    # plt.close()

                    # See also `obs["binned_pc_map"]["egocentric_update"]` to see
                    # the metric map from the point of view of the agent before it is
                    # rotated into the world-space coordinates and merged with past observations.

                    # Semantic map in RGB which is iteratively revealed using depth maps to figure out what
                    # parts of the scene the agent has seen so far.
                    # This map has shape 210x210x72 with the 72 channels corresponding to the 72
                    # object types in `ordered_object_types`
                    semantic_map = obs["semantic_map"]["map"]

                    # NOTE(review): the two colored maps below are only used by
                    # the commented-out visualization. The calls are kept
                    # because they may consume random state that the next
                    # `random.choice` depends on; confirm before removing.

                    # We can't display all 72 channels in an RGB image so instead we randomly assign
                    # each object a color and then just allow them to overlap each other
                    colored_semantic_map = SemanticMapBuilder.randomly_color_semantic_map(
                        semantic_map
                    )

                    # Here's the full semantic map with nothing masked out because the agent
                    # hasn't seen it yet
                    ground_truth_map = (
                        semantic_map_sensor.semantic_map_builder.ground_truth_semantic_map
                    )
                    colored_semantic_map_no_fog = SemanticMapBuilder.randomly_color_semantic_map(
                        ground_truth_map
                    )

                    # Uncomment if you wish to visualize the observations:
                    # import matplotlib.pyplot as plt
                    # plt.imshow(
                    #     np.flip(  # np.flip because we expect "up" to be -row
                    #         np.concatenate(
                    #             (
                    #                 colored_semantic_map,
                    #                 255 + 0 * colored_semantic_map[:, :10, :],
                    #                 colored_semantic_map_no_fog,
                    #             ),
                    #             axis=1,
                    #         ),
                    #         0,
                    #     )
                    # )
                    # plt.title("Semantic map with and without exploration fog")
                    # plt.show()
                    # plt.close()

                    # See also
                    # * `obs["semantic_map"]["egocentric_update"]`
                    # * `obs["semantic_map"]["explored_mask"]`
                    # * `obs["semantic_map"]["egocentric_mask"]`

            # To save observations for comparison against future runs, uncomment the below.
            # os.makedirs("tmp_out", exist_ok=True)
            # compress_pickle.dump(
            #     {**observations_dict}, "tmp_out/rearrange_mapping_examples.pkl.gz"
            # )
        finally:
            # `task_sampler` is unbound if the test bailed out or failed
            # before the sampler was created; that NameError is ignored.
            try:
                task_sampler.close()
            except NameError:
                pass
    def test_pretrained_rearrange_walkthrough_mapping_agent(self, tmpdir):
        """Run a downloaded pretrained walkthrough mapping agent for a few
        steps per episode and assert its map losses stay below fixed bounds.

        Returns early (after cleanup) when
        ``setup_path_for_use_with_rearrangement_project`` reports failure.
        The checkpoint is downloaded into ``tmpdir`` unless already present.
        """
        try:
            if not self.setup_path_for_use_with_rearrangement_project():
                return

            from baseline_configs.rearrange_base import RearrangeBaseExperimentConfig
            from baseline_configs.walkthrough.walkthrough_rgb_mapping_ppo import (
                WalkthroughRGBMappingPPOExperimentConfig,
            )
            from rearrange.constants import (
                FOV,
                PICKUPABLE_OBJECTS,
                OPENABLE_OBJECTS,
            )
            from datagen.datagen_utils import get_scenes

            # Use the first open X display if one can be found, else None.
            try:
                open_x_displays = get_open_x_displays()
            except (AssertionError, IOError):
                open_x_displays = []
            x_display = open_x_displays[0] if len(open_x_displays) > 0 else None

            task_sampler = WalkthroughRGBMappingPPOExperimentConfig.make_sampler_fn(
                stage="train",
                scene_to_allowed_rearrange_inds={
                    scene: [0] for scene in get_scenes("train")
                },
                force_cache_reset=True,
                allowed_scenes=None,
                seed=2,
                x_display=x_display,
            )

            pipeline = WalkthroughRGBMappingPPOExperimentConfig.training_pipeline()
            named_losses = pipeline.named_losses

            ckpt_path = os.path.join(
                tmpdir, "pretrained_walkthrough_mapping_agent_75mil.pt"
            )
            if not os.path.exists(ckpt_path):
                urllib.request.urlretrieve(
                    "https://prior-model-weights.s3.us-east-2.amazonaws.com/embodied-ai/rearrangement/walkthrough/pretrained_walkthrough_mapping_agent_75mil.pt",
                    ckpt_path,
                )

            state_dict = torch.load(ckpt_path, map_location="cpu")

            walkthrough_model = (
                WalkthroughRGBMappingPPOExperimentConfig.create_model()
            )
            walkthrough_model.load_state_dict(state_dict["model_state_dict"])

            rollout_storage = RolloutStorage(
                num_steps=1,
                num_samplers=1,
                actor_critic=walkthrough_model,
                only_store_first_and_last_in_memory=True,
            )
            memory = rollout_storage.pick_memory_step(0)
            masks = rollout_storage.masks[:1]

            def add_step_dim(value):
                """Recursively unsqueeze dim 0 of every tensor in `value`."""
                if isinstance(value, torch.Tensor):
                    return value.unsqueeze(0)
                if isinstance(value, Dict):
                    return {
                        key: add_step_dim(item) for key, item in value.items()
                    }
                raise NotImplementedError

            def evaluate_loss(loss_name, observations, actor_critic_output):
                """Return the named loss's first output as a Python number."""
                loss_outputs = named_losses[loss_name].loss(
                    step_count=0,  # Not used in this loss
                    batch={"observations": observations},
                    actor_critic_output=actor_critic_output,
                )
                return loss_outputs[0].item()

            binned_map_losses = []
            semantic_map_losses = []
            for i in range(5):
                # Masks are zeroed at the start of each episode and filled
                # with ones after its first step.
                masks = 0 * masks

                set_seed(i + 1)
                task = task_sampler.next_task()

                batch = add_step_dim(
                    batch_observations([task.get_observations()])
                )

                while not task.is_done():
                    ac_out, memory = cast(
                        Tuple[ActorCriticOutput, Memory],
                        walkthrough_model.forward(
                            observations=batch,
                            memory=memory,
                            prev_actions=None,
                            masks=masks,
                        ),
                    )

                    binned_map_losses.append(
                        evaluate_loss("binned_map_loss", batch, ac_out)
                    )
                    assert (
                        binned_map_losses[-1] < 0.16
                    ), f"Binned map loss to large at ({i}, {task.num_steps_taken()})"

                    semantic_map_losses.append(
                        evaluate_loss("semantic_map_loss", batch, ac_out)
                    )
                    assert (
                        semantic_map_losses[-1] < 0.004
                    ), f"Semantic map loss to large at ({i}, {task.num_steps_taken()})"

                    masks.fill_(1.0)
                    sampled_action = ac_out.distributions.sample().item()
                    obs = task.step(action=sampled_action).observation
                    batch = add_step_dim(batch_observations([obs]))

                    # Stop each episode after at most ten steps.
                    if task.num_steps_taken() >= 10:
                        break
        finally:
            # `task_sampler` is unbound if the test bailed out or failed
            # before the sampler was created; that NameError is ignored.
            try:
                task_sampler.close()
            except NameError:
                pass