Example #1
    def save(self, iteration, num_to_keep=1):
        if self.args.save:
            pt_util.save(
                self, os.path.join(self.args.checkpoint_dir, constants.TIME_STR),
                num_to_keep, iteration)
            if self.saves > 0 and self.saves % self.args.long_save_frequency == 0:
                pt_util.save(self, self.args.long_save_checkpoint_dir, -1,
                             iteration)
            self.saves += 1
Example #2
def train_model(model, device, train_loader, optimizer, total_num_steps,
                logger, net_output_info, checkpoint_dir):
    try:
        model.train()
        if args.tensorboard:
            logger.network_conv_summary(model, total_num_steps)
        data_iter = iter(train_loader)
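        # Train on a fixed number of batches per epoch instead of exhausting the loader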
        for batch_idx in tqdm.tqdm(range(NUM_BATCHES_PER_EPOCH)):
            data = next(data_iter)
            labels = {key: val.to(device) for key, val in data.items()}
            labels["surface_normals"] = pt_util.depth_to_surface_normals(
                labels["depth"])
            data = labels["rgb"].detach()
            _, output, class_pred = model.forward(data, True)
            loss, loss_val, visual_loss_dict = optimizers.get_visual_loss(
                output, labels, net_output_info)
            object_loss = 0
            if "class_label" in labels:
                object_loss = optimizers.get_object_existence_loss(
                    class_pred, labels["class_label"])
                loss = loss + object_loss
                object_loss = object_loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if DEBUG:
                draw_outputs(output, labels, "train")
            total_num_steps += 1
            if not args.no_weight_update and batch_idx % args.log_interval == 0:
                if args.tensorboard:
                    log_dict = {"loss/visual/0_total": loss.item()}
                    if "class_label" in labels:
                        log_dict["loss/visual/object_loss"] = object_loss
                    for key, val in visual_loss_dict.items():
                        log_dict["loss/visual/" + key] = val
                    logger.dict_log(log_dict, step=total_num_steps)
            if args.tensorboard:
                if batch_idx % 100 == 0:
                    logger.network_variable_summary(model, total_num_steps)

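        # End-of-epoch checkpoint; only the five most recent saves are kept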
        if args.save_checkpoints and not args.no_weight_update:
            pt_util.save(model,
                         checkpoint_dir,
                         num_to_keep=5,
                         iteration=total_num_steps)
        return total_num_steps
    except Exception as e:
        import traceback

        traceback.print_exc()
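        # Save the current weights before re-raising so the crash state is recoverable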
        if args.save_checkpoints and not args.no_weight_update:
            pt_util.save(model,
                         checkpoint_dir,
                         num_to_keep=-1,
                         iteration=total_num_steps)
        raise e
Example #3
    def save_checkpoint(self, num_to_keep=-1, iteration=0):
        if self.shell_args.save_checkpoints and not self.shell_args.no_weight_update:
            pt_util.save(self.agent,
                         self.checkpoint_dir,
                         num_to_keep=num_to_keep,
                         iteration=iteration)
Example #4
def main_worker(model, gpu, args, train_logger, test_logger, checkpoint_dir):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            best_acc1 = checkpoint["best_acc1"]
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint["epoch"]))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, "train")
    valdir = os.path.join(args.data, "val")

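    # Checkpoint the initial weights at iteration 0 so the starting point is recoverable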
    pt_util.save(model, checkpoint_dir, num_to_keep=5, iteration=0)

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose(
            [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Lambda(lambda x: x * 255),
            ]
        ),
    )

    train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler,
    )

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(
            valdir,
            transforms.Compose(
                [
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    transforms.Lambda(lambda x: x * 255),
                ]
            ),
        ),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args, train_logger)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args, test_logger, (epoch + 1) * len(train_loader.dataset))

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if is_best and args.save_checkpoints:
            pt_util.save(model, checkpoint_dir, num_to_keep=5, iteration=(epoch + 1) * len(train_loader.dataset))
Example #5
    def evaluate_model(self):
        self.envs.unwrapped.call(["switch_dataset"] *
                                 self.shell_args.num_processes,
                                 [("val", )] * self.shell_args.num_processes)

        if not os.path.exists(self.eval_dir):
            os.makedirs(self.eval_dir)
        try:
            eval_net_file_name = sorted(
                glob.glob(
                    os.path.join(self.shell_args.log_prefix,
                                 self.shell_args.checkpoint_dirname, "*") +
                    "/*.pt"),
                key=os.path.getmtime,
            )[-1]
            eval_net_file_name = (
                self.shell_args.log_prefix.replace(os.sep, "_") + "_" +
                "_".join(eval_net_file_name.split(os.sep)[-2:])[:-3])
        except IndexError:
            print("Warning, no weights found")
            eval_net_file_name = "random_weights"
        eval_output_file = open(
            os.path.join(self.eval_dir, eval_net_file_name + ".csv"), "w")
        print("Writing results to", eval_output_file.name)

        # Save the evaled net for posterity
        if self.shell_args.save_checkpoints:
            save_model = self.agent
            pt_util.save(
                save_model,
                os.path.join(self.shell_args.log_prefix,
                             self.shell_args.checkpoint_dirname,
                             "eval_weights"),
                num_to_keep=-1,
                iteration=self.log_iter,
            )
            print("Wrote model to file for safe keeping")

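        # Reset all environments and assemble the initial observation batch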
        obs = self.envs.reset()
        if self.compute_surface_normals:
            obs["surface_normals"] = pt_util.depth_to_surface_normals(
                obs["depth"].to(self.device))
        obs["prev_action_one_hot"] = obs[
            "prev_action_one_hot"][:, ACTION_SPACE].to(torch.float32)
        recurrent_hidden_states = torch.zeros(
            self.shell_args.num_processes,
            self.agent.recurrent_hidden_state_size,
            dtype=torch.float32,
            device=self.device,
        )
        masks = torch.ones(self.shell_args.num_processes,
                           1,
                           dtype=torch.float32,
                           device=self.device)

        episode_rewards = deque(maxlen=10)
        current_episode_rewards = np.zeros(self.shell_args.num_processes)
        episode_lengths = deque(maxlen=10)
        current_episode_lengths = np.zeros(self.shell_args.num_processes)

        total_num_steps = self.log_iter
        fps_timer = [time.time(), total_num_steps]
        timers = np.zeros(3)

        num_episodes = 0

        print("Config\n", self.configs[0])

        # Initialize every time eval is run rather than just at the start
        dataset_sizes = np.array(
            [len(dataset.episodes) for dataset in self.eval_datasets])

        eval_stats = dict(
            episode_ids=[None for _ in range(self.shell_args.num_processes)],
            num_episodes=np.zeros(self.shell_args.num_processes,
                                  dtype=np.int32),
            num_steps=np.zeros(self.shell_args.num_processes, dtype=np.int32),
            reward=np.zeros(self.shell_args.num_processes, dtype=np.float32),
            spl=np.zeros(self.shell_args.num_processes, dtype=np.float32),
            visited_states=np.zeros(self.shell_args.num_processes,
                                    dtype=np.int32),
            success=np.zeros(self.shell_args.num_processes, dtype=np.int32),
            end_geodesic_distance=np.zeros(self.shell_args.num_processes,
                                           dtype=np.float32),
            start_geodesic_distance=np.zeros(self.shell_args.num_processes,
                                             dtype=np.float32),
            delta_geodesic_distance=np.zeros(self.shell_args.num_processes,
                                             dtype=np.float32),
            distance_from_start=np.zeros(self.shell_args.num_processes,
                                         dtype=np.float32),
        )
        eval_stats_means = dict(
            num_episodes=0,
            num_steps=0,
            reward=0,
            spl=0,
            visited_states=0,
            success=0,
            end_geodesic_distance=0,
            start_geodesic_distance=0,
            delta_geodesic_distance=0,
            distance_from_start=0,
        )
        eval_output_file.write("name,%s,iter,%d\n\n" %
                               (eval_net_file_name, self.log_iter))
        if self.shell_args.task == "pointnav":
            eval_output_file.write((
                "episode_id,num_steps,reward,spl,success,start_geodesic_distance,"
                "end_geodesic_distance,delta_geodesic_distance\n"))
        elif self.shell_args.task == "exploration":
            eval_output_file.write("episode_id,reward,visited_states\n")
        elif self.shell_args.task == "flee":
            eval_output_file.write("episode_id,reward,distance_from_start\n")
        distances = pt_util.to_numpy(obs["goal_geodesic_distance"])
        eval_stats["start_geodesic_distance"][:] = distances
        progress_bar = tqdm.tqdm(total=self.num_eval_episodes_total)
        all_done = False
        iter_count = 0
        video_frames = []
        previous_visual_features = None
        egomotion_pred = None
        prev_action = None
        prev_action_probs = None
        if hasattr(self.agent.base, "enable_decoder"):
            if self.shell_args.record_video:
                self.agent.base.enable_decoder()
            else:
                self.agent.base.disable_decoder()
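        # Step every environment until all eval datasets are exhausted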
        while not all_done:
            with torch.no_grad():
                start_t = time.time()
                value, action, action_log_prob, recurrent_hidden_states = self.agent.act(
                    {
                        "images": obs["rgb"].to(self.device),
                        "target_vector": obs["pointgoal"].to(self.device),
                        "prev_action_one_hot": obs["prev_action_one_hot"].to(self.device),
                    },
                    recurrent_hidden_states,
                    masks,
                )
                action_cpu = pt_util.to_numpy(action.squeeze(1))
                translated_action_space = ACTION_SPACE[action_cpu]

                timers[1] += time.time() - start_t

                if self.shell_args.record_video:
                    if self.shell_args.use_motion_loss:
                        if previous_visual_features is not None:
                            egomotion_pred = self.agent.base.predict_egomotion(
                                self.agent.base.visual_features,
                                previous_visual_features)
                        previous_visual_features = self.agent.base.visual_features.detach(
                        )

                    # Copy so we don't mess with obs itself
                    draw_obs = OrderedDict()
                    for key, val in obs.items():
                        draw_obs[key] = pt_util.to_numpy(val).copy()
                    best_next_action = draw_obs.pop("best_next_action", None)

                    if prev_action is not None:
                        draw_obs["action_taken"] = pt_util.to_numpy(
                            self.agent.last_dist.probs).copy()
                        draw_obs["action_taken"][:] = 0
                        draw_obs["action_taken"][
                            np.arange(self.shell_args.num_processes),
                            prev_action] = 1
                        draw_obs["action_taken_name"] = SIM_ACTION_TO_NAME[
                            draw_obs['prev_action'].item()]
                        draw_obs["action_prob"] = pt_util.to_numpy(
                            prev_action_probs).copy()
                    else:
                        draw_obs["action_taken"] = None
                        draw_obs["action_taken_name"] = SIM_ACTION_TO_NAME[
                            SimulatorActions.STOP]
                        draw_obs["action_prob"] = None
                    prev_action = action_cpu
                    prev_action_probs = self.agent.last_dist.probs.detach()
                    if hasattr(
                            self.agent.base, "decoder_outputs"
                    ) and self.agent.base.decoder_outputs is not None:
                        min_channel = 0
                        for key, num_channels in self.agent.base.decoder_output_info:
                            outputs = self.agent.base.decoder_outputs[
                                :, min_channel:min_channel + num_channels, ...]
                            draw_obs["output_" + key] = pt_util.to_numpy(outputs).copy()
                            min_channel += num_channels
                    draw_obs["rewards"] = eval_stats["reward"]
                    draw_obs["step"] = current_episode_lengths.copy()
                    draw_obs["method"] = self.shell_args.method_name
                    if best_next_action is not None:
                        draw_obs["best_next_action"] = best_next_action
                    if self.shell_args.use_motion_loss:
                        if egomotion_pred is not None:
                            draw_obs["egomotion_pred"] = pt_util.to_numpy(
                                F.softmax(egomotion_pred, dim=1)).copy()
                        else:
                            draw_obs["egomotion_pred"] = None
                    images, titles, normalize = draw_outputs.obs_to_images(
                        draw_obs)
                    im_inds = [0, 2, 3, 1, 6, 7, 8, 5]
                    height, width = images[0].shape[:2]
                    subplot_image = drawing.subplot(
                        images,
                        2,
                        4,
                        titles=titles,
                        normalize=normalize,
                        output_width=max(width, 320),
                        output_height=max(height, 320),
                        order=im_inds,
                        fancy_text=True,
                    )
                    video_frames.append(subplot_image)

                # save dists from previous step or else on reset they will be overwritten
                distances = pt_util.to_numpy(obs["goal_geodesic_distance"])

                start_t = time.time()
                obs, rewards, dones, infos = self.envs.step(
                    translated_action_space)
                timers[0] += time.time() - start_t
                obs["prev_action_one_hot"] = obs[
                    "prev_action_one_hot"][:, ACTION_SPACE].to(torch.float32)
                rewards *= REWARD_SCALAR
                rewards = np.clip(rewards, -10, 10)

                if self.shell_args.record_video and not dones[0]:
                    obs["top_down_map"] = infos[0]["top_down_map"]

                if self.compute_surface_normals:
                    obs["surface_normals"] = pt_util.depth_to_surface_normals(
                        obs["depth"].to(self.device))

                current_episode_rewards += pt_util.to_numpy(rewards).squeeze()
                current_episode_lengths += 1
                to_pause = []
                for ii, done_e in enumerate(dones):
                    if done_e:
                        num_episodes += 1

                        if self.shell_args.record_video:
                            if "top_down_map" in infos[ii]:
                                video_dir = os.path.join(
                                    self.shell_args.log_prefix, "videos")
                                if not os.path.exists(video_dir):
                                    os.makedirs(video_dir)
                                im_path = os.path.join(
                                    self.shell_args.log_prefix, "videos",
                                    "total_steps_%d.png" % total_num_steps)
                                top_down_map = maps.colorize_topdown_map(
                                    infos[ii]["top_down_map"]["map"])
                                imageio.imsave(im_path, top_down_map)

                            images_to_video(
                                video_frames,
                                os.path.join(self.shell_args.log_prefix,
                                             "videos"),
                                "total_steps_%d" % total_num_steps,
                            )
                            video_frames = []

                        eval_stats["episode_ids"][ii] = infos[ii]["episode_id"]

                        if self.shell_args.task == "pointnav":
                            print(
                                "FINISHED EPISODE %d Length %d Reward %.3f SPL %.4f"
                                % (
                                    num_episodes,
                                    current_episode_lengths[ii],
                                    current_episode_rewards[ii],
                                    infos[ii]["spl"],
                                ))
                            eval_stats["spl"][ii] = infos[ii]["spl"]
                            eval_stats["success"][
                                ii] = eval_stats["spl"][ii] > 0
                            eval_stats["num_steps"][
                                ii] = current_episode_lengths[ii]
                            eval_stats["end_geodesic_distance"][ii] = (
                                infos[ii]["final_distance"] if
                                eval_stats["success"][ii] else distances[ii])
                            eval_stats["delta_geodesic_distance"][ii] = (
                                eval_stats["start_geodesic_distance"][ii] -
                                eval_stats["end_geodesic_distance"][ii])
                        elif self.shell_args.task == "exploration":
                            print(
                                "FINISHED EPISODE %d Reward %.3f States Visited %d"
                                % (num_episodes, current_episode_rewards[ii],
                                   infos[ii]["visited_states"]))
                            eval_stats["visited_states"][ii] = infos[ii][
                                "visited_states"]
                        elif self.shell_args.task == "flee":
                            print(
                                "FINISHED EPISODE %d Reward %.3f Distance from start %.4f"
                                % (num_episodes, current_episode_rewards[ii],
                                   infos[ii]["distance_from_start"]))
                            eval_stats["distance_from_start"][ii] = infos[ii][
                                "distance_from_start"]

                        eval_stats["num_episodes"][ii] += 1
                        eval_stats["reward"][ii] = current_episode_rewards[ii]

                        if eval_stats["num_episodes"][ii] <= dataset_sizes[ii]:
                            progress_bar.update(1)
                            eval_stats_means["num_episodes"] += 1
                            eval_stats_means["reward"] += eval_stats["reward"][
                                ii]
                            if self.shell_args.task == "pointnav":
                                eval_output_file.write(
                                    "%s,%d,%f,%f,%d,%f,%f,%f\n" % (
                                        eval_stats["episode_ids"][ii],
                                        eval_stats["num_steps"][ii],
                                        eval_stats["reward"][ii],
                                        eval_stats["spl"][ii],
                                        eval_stats["success"][ii],
                                        eval_stats["start_geodesic_distance"]
                                        [ii],
                                        eval_stats["end_geodesic_distance"]
                                        [ii],
                                        eval_stats["delta_geodesic_distance"]
                                        [ii],
                                    ))
                                eval_stats_means["num_steps"] += eval_stats[
                                    "num_steps"][ii]
                                eval_stats_means["spl"] += eval_stats["spl"][
                                    ii]
                                eval_stats_means["success"] += eval_stats[
                                    "success"][ii]
                                eval_stats_means[
                                    "start_geodesic_distance"] += eval_stats[
                                        "start_geodesic_distance"][ii]
                                eval_stats_means[
                                    "end_geodesic_distance"] += eval_stats[
                                        "end_geodesic_distance"][ii]
                                eval_stats_means[
                                    "delta_geodesic_distance"] += eval_stats[
                                        "delta_geodesic_distance"][ii]
                            elif self.shell_args.task == "exploration":
                                eval_output_file.write("%s,%f,%d\n" % (
                                    eval_stats["episode_ids"][ii],
                                    eval_stats["reward"][ii],
                                    eval_stats["visited_states"][ii],
                                ))
                                eval_stats_means[
                                    "visited_states"] += eval_stats[
                                        "visited_states"][ii]
                            elif self.shell_args.task == "flee":
                                eval_output_file.write("%s,%f,%f\n" % (
                                    eval_stats["episode_ids"][ii],
                                    eval_stats["reward"][ii],
                                    eval_stats["distance_from_start"][ii],
                                ))
                                eval_stats_means[
                                    "distance_from_start"] += eval_stats[
                                        "distance_from_start"][ii]
                            eval_output_file.flush()
                            if eval_stats["num_episodes"][ii] == dataset_sizes[
                                    ii]:
                                to_pause.append(ii)

                        episode_rewards.append(current_episode_rewards[ii])
                        current_episode_rewards[ii] = 0
                        episode_lengths.append(current_episode_lengths[ii])
                        current_episode_lengths[ii] = 0
                        eval_stats["start_geodesic_distance"][ii] = obs[
                            "goal_geodesic_distance"][ii]

                # If done then clean the history of observations.
                masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                           for done_ in dones]).to(self.device)

                # Reverse in order to maintain order in case of multiple.
                to_pause.reverse()
                for ii in to_pause:
                    # Pause the environments that are done from the vectorenv.
                    print("Pausing env", ii)
                    self.envs.unwrapped.pause_at(ii)
                    current_episode_rewards = np.concatenate(
                        (current_episode_rewards[:ii],
                         current_episode_rewards[ii + 1:]))
                    current_episode_lengths = np.concatenate(
                        (current_episode_lengths[:ii],
                         current_episode_lengths[ii + 1:]))
                    for key in eval_stats:
                        eval_stats[key] = np.concatenate(
                            (eval_stats[key][:ii], eval_stats[key][ii + 1:]))
                    dataset_sizes = np.concatenate(
                        (dataset_sizes[:ii], dataset_sizes[ii + 1:]))

                    for key in obs:
                        if type(obs[key]) == torch.Tensor:
                            obs[key] = torch.cat(
                                (obs[key][:ii], obs[key][ii + 1:]), dim=0)
                        else:
                            obs[key] = np.concatenate(
                                (obs[key][:ii], obs[key][ii + 1:]), axis=0)

                    recurrent_hidden_states = torch.cat(
                        (recurrent_hidden_states[:ii],
                         recurrent_hidden_states[ii + 1:]),
                        dim=0)
                    masks = torch.cat((masks[:ii], masks[ii + 1:]), dim=0)

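                # Once every per-process dataset is exhausted, evaluation is complete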
                if len(dataset_sizes) == 0:
                    progress_bar.close()
                    all_done = True

            total_num_steps += self.shell_args.num_processes

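            # Periodically log throughput and episode statistics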
            if iter_count % (self.shell_args.log_interval * 100) == 0:
                log_dict = {}
                if len(episode_rewards) > 1:
                    end = time.time()
                    nsteps = total_num_steps - fps_timer[1]
                    fps = int((total_num_steps - fps_timer[1]) /
                              (end - fps_timer[0]))
                    timers /= nsteps
                    env_spf = timers[0]
                    forward_spf = timers[1]
                    print((
                        "{} Updates {}, num timesteps {}, FPS {}, Env FPS {}, "
                        "\n Last {} training episodes: mean/median reward {:.3f}/{:.3f}, "
                        "min/max reward {:.3f}/{:.3f}\n").format(
                            datetime.datetime.now(),
                            iter_count,
                            total_num_steps,
                            fps,
                            int(1.0 / env_spf),
                            len(episode_rewards),
                            np.mean(episode_rewards),
                            np.median(episode_rewards),
                            np.min(episode_rewards),
                            np.max(episode_rewards),
                        ))

                    if self.shell_args.tensorboard:
                        log_dict.update({
                            "stats/full_spf": 1.0 / (fps + 1e-10),
                            "stats/env_spf": env_spf,
                            "stats/forward_spf": forward_spf,
                            "stats/full_fps": fps,
                            "stats/env_fps": 1.0 / (env_spf + 1e-10),
                            "stats/forward_fps": 1.0 / (forward_spf + 1e-10),
                            "episode/mean_rewards": np.mean(episode_rewards),
                            "episode/median_rewards": np.median(episode_rewards),
                            "episode/min_rewards": np.min(episode_rewards),
                            "episode/max_rewards": np.max(episode_rewards),
                            "episode/mean_lengths": np.mean(episode_lengths),
                            "episode/median_lengths": np.median(episode_lengths),
                            "episode/min_lengths": np.min(episode_lengths),
                            "episode/max_lengths": np.max(episode_lengths),
                        })
                        self.eval_logger.dict_log(log_dict, step=self.log_iter)
                    fps_timer[0] = time.time()
                    fps_timer[1] = total_num_steps
                    timers[:] = 0
            iter_count += 1
        print("Finished testing")
        print("Wrote results to", eval_output_file.name)

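        # Convert the accumulated sums into per-episode means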
        eval_stats_means = {
            key: val / eval_stats_means["num_episodes"]
            for key, val in eval_stats_means.items()
        }
        if self.shell_args.tensorboard:
            log_dict = {"single_episode/reward": eval_stats_means["reward"]}
            if self.shell_args.task == "pointnav":
                log_dict.update({
                    "single_episode/num_steps": eval_stats_means["num_steps"],
                    "single_episode/spl": eval_stats_means["spl"],
                    "single_episode/success": eval_stats_means["success"],
                    "single_episode/start_geodesic_distance":
                        eval_stats_means["start_geodesic_distance"],
                    "single_episode/end_geodesic_distance":
                        eval_stats_means["end_geodesic_distance"],
                    "single_episode/delta_geodesic_distance":
                        eval_stats_means["delta_geodesic_distance"],
                })
            elif self.shell_args.task == "exploration":
                log_dict["single_episode/visited_states"] = eval_stats_means[
                    "visited_states"]
            elif self.shell_args.task == "flee":
                log_dict[
                    "single_episode/distance_from_start"] = eval_stats_means[
                        "distance_from_start"]
            self.eval_logger.dict_log(log_dict, step=self.log_iter)
        self.envs.unwrapped.resume_all()
Example #6
def main():
    torch_devices = [
        int(gpu_id.strip()) for gpu_id in args.pytorch_gpu_ids.split(",")
    ]
    render_gpus = [
        int(gpu_id.strip()) for gpu_id in args.render_gpu_ids.split(",")
    ]
    device = "cuda:" + str(torch_devices[0])

    decoder_output_info = [("reconstruction", 3), ("depth", 1),
                           ("surface_normals", 3)]
    if USE_SEMANTIC:
        decoder_output_info.append(("semantic", 41))

    model = ShallowVisualEncoder(decoder_output_info)
    model = pt_util.get_data_parallel(model, torch_devices)
    model = pt_util.DummyScope(model, ["base", "visual_encoder"])
    model.to(device)

    print("Model constructed")
    print(model)

    train_transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(224)
    ])

    train_transforms_depth = transforms.Compose([
        PIL.Image.fromarray,
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(224), np.array
    ])

    train_transforms_semantic = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(224)
    ])

    sensors = ["RGB_SENSOR", "DEPTH_SENSOR"]
    if USE_SEMANTIC:
        sensors.append("SEMANTIC_SENSOR")
    if args.dataset == "suncg":
        data_train = HabitatImageGenerator(
            render_gpus,
            "suncg",
            args.data_subset,
            "data/dumps/suncg/{split}/dataset_one_ep_per_scene.json.gz",
            images_before_reset=1000,
            sensors=sensors,
            transform=train_transforms,
            depth_transform=train_transforms_depth,
            semantic_transform=train_transforms_semantic,
        )
        print("Num train images", len(data_train))

        data_test = HabitatImageGenerator(
            render_gpus,
            "suncg",
            "val",
            "data/dumps/suncg/{split}/dataset_one_ep_per_scene.json.gz",
            images_before_reset=1000,
            sensors=sensors,
        )
    elif args.dataset == "mp3d":
        data_train = HabitatImageGenerator(
            render_gpus,
            "mp3d",
            args.data_subset,
            "data/dumps/mp3d/{split}/dataset_one_ep_per_scene.json.gz",
            images_before_reset=1000,
            sensors=sensors,
            transform=train_transforms,
            depth_transform=train_transforms_depth,
            semantic_transform=train_transforms_semantic,
        )
        print("Num train images", len(data_train))

        data_test = HabitatImageGenerator(
            render_gpus,
            "mp3d",
            "val",
            "data/dumps/mp3d/{split}/dataset_one_ep_per_scene.json.gz",
            images_before_reset=1000,
            sensors=sensors,
        )
    elif args.dataset == "gibson":
        data_train = HabitatImageGenerator(
            render_gpus,
            "gibson",
            args.data_subset,
            "data/datasets/pointnav/gibson/v1/{split}/{split}.json.gz",
            images_before_reset=1000,
            sensors=sensors,
            transform=train_transforms,
            depth_transform=train_transforms_depth,
            semantic_transform=train_transforms_semantic,
        )
        print("Num train images", len(data_train))

        data_test = HabitatImageGenerator(
            render_gpus,
            "gibson",
            "val",
            "data/datasets/pointnav/gibson/v1/{split}/{split}.json.gz",
            images_before_reset=1000,
            sensors=sensors,
        )
    else:
        raise NotImplementedError("No rule for this dataset.")

    print("Num train images", len(data_train))
    print("Num val images", len(data_test))

    print("Using device", device)
    print("num cpus:", args.num_processes)

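    # Build data loaders over the online Habitat image generators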
    train_loader = torch.utils.data.DataLoader(
        data_train,
        batch_size=BATCH_SIZE,
        num_workers=args.num_processes,
        worker_init_fn=data_train.worker_init_fn,
        shuffle=False,
        pin_memory=True,
    )
    test_loader = torch.utils.data.DataLoader(
        data_test,
        batch_size=TEST_BATCH_SIZE,
        num_workers=len(render_gpus) if args.num_processes > 0 else 0,
        worker_init_fn=data_test.worker_init_fn,
        shuffle=False,
        pin_memory=True,
    )

    log_prefix = args.log_prefix
    time_str = misc_util.get_time_str()
    checkpoint_dir = os.path.join(log_prefix, args.checkpoint_dirname,
                                  time_str)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    start_iter = 0
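    # Optionally restore weights and the iteration counter from an existing checkpoint folder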
    if args.load_model:
        start_iter = pt_util.restore_from_folder(
            model, os.path.join(log_prefix, args.checkpoint_dirname, "*"))

    train_logger = None
    test_logger = None
    if args.tensorboard:
        train_logger = tensorboard_logger.Logger(
            os.path.join(log_prefix, args.tensorboard_dirname,
                         time_str + "_train"))
        test_logger = tensorboard_logger.Logger(
            os.path.join(log_prefix, args.tensorboard_dirname,
                         time_str + "_test"))

    total_num_steps = start_iter

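    # Save the starting weights before any training or evaluation happens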
    if args.save_checkpoints and not args.no_weight_update:
        pt_util.save(model,
                     checkpoint_dir,
                     num_to_keep=5,
                     iteration=total_num_steps)

    evaluate_model(model, device, test_loader, total_num_steps, test_logger,
                   decoder_output_info)

    for epoch in range(0, EPOCHS + 1):
        total_num_steps = train_model(model, device, train_loader, optimizer,
                                      total_num_steps, train_logger,
                                      decoder_output_info, checkpoint_dir)
        evaluate_model(model, device, test_loader, total_num_steps,
                       test_logger, decoder_output_info)