Example #1
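draw_outputs renders each batch element's predicted RGB, depth, surface normals, and (optionally) semantic segmentation next to the ground-truth labels in a 2x4 OpenCV grid.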
def draw_outputs(output, labels, mode):
    # L2-normalize the predicted surface-normal channels.
    output[:, 4:7] = output[:, 4:7] / output[:, 4:7].norm(dim=1, keepdim=True)
    output = pt_util.to_numpy(output)
    labels = {key: pt_util.to_numpy(val) for key, val in labels.items()}
    if USE_SEMANTIC:
        # Pin two corner pixels to the extreme class ids, presumably so the
        # normalized semantic colormap covers the same range in every frame.
        labels["semantic"][:, 0, 1] = 0
        labels["semantic"][:, 0, 0] = 40
    for bb in range(output.shape[0]):
        output_on = output[bb]
        labels_on = {key: val[bb] for key, val in labels.items()}
        output_semantic = None
        if USE_SEMANTIC:
            output_semantic = np.argmax(output_on[7:], axis=0)
            # Same corner-pixel anchoring as above, for the predicted semantics.
            output_semantic[0, 0] = 40
            output_semantic[0, 1] = 0
        images = [
            labels_on["rgb"].transpose(1, 2, 0),
            255 - np.clip((labels_on["depth"] + 0.5).squeeze() * 255, 0, 255),
            (np.clip(labels_on["surface_normals"] + 1, 0, 2) * 127).astype(np.uint8).transpose(1, 2, 0),
            labels_on["semantic"].squeeze().astype(np.uint8) if USE_SEMANTIC else None,
            np.clip((output_on[:3] + 0.5) * 255, 0, 255).astype(np.uint8).transpose(1, 2, 0),
            255 - np.clip((output_on[3] + 0.5).squeeze() * 255, 0, 255),
            (np.clip(output_on[4:7] + 1, 0, 2) * 127).astype(np.uint8).transpose(1, 2, 0),
            output_semantic.astype(np.uint8) if USE_SEMANTIC else None,
        ]
        titles = ["rgb", "depth", "normals", "semantic", "rgb_pred", "depth_pred", "normals_pred", "semantic_pred"]

        image = drawing.subplot(
            images, 2, 4, 256, 256, titles=titles, normalize=[False, False, False, True, False, False, False, True]
        )
        cv2.imshow("im_" + mode, image[:, :, ::-1])  # RGB -> BGR for OpenCV
        cv2.waitKey(0)
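
Every example in this file tiles images through drawing.subplot. The real helper lives in the repo; the sketch below is an assumption reconstructed from the call sites (the actual helper also handles titles=, normalize=, order=, and mixed dtypes) and only shows the core idea: resize each image to a fixed cell size and paste it into one canvas.

import cv2
import numpy as np


def subplot_sketch(images, n_rows, n_cols, cell_w, cell_h, border=0):
    # Hypothetical stand-in for drawing.subplot, inferred from its call sites.
    canvas = np.zeros((n_rows * cell_h, n_cols * cell_w, 3), dtype=np.uint8)
    for idx, image in enumerate(images[: n_rows * n_cols]):
        if image is None:
            continue  # call sites pass None for panels that should stay blank
        image = np.asarray(image)
        if image.ndim == 2:
            image = cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_GRAY2RGB)
        image = cv2.resize(image.astype(np.uint8), (cell_w - 2 * border, cell_h - 2 * border))
        row, col = divmod(idx, n_cols)
        canvas[
            row * cell_h + border : (row + 1) * cell_h - border,
            col * cell_w + border : (col + 1) * cell_w - border,
        ] = image
    return canvas

Example #2
A get_image_output method for a sequence classifier: it draws green/red borders around correctly/incorrectly classified sequences, stamps the predicted (and, on errors, ground-truth) class name on the first frame, and tiles the sequences into one grid.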
    def get_image_output(self, network_outputs):
        with torch.no_grad():
            image_output = {}
            predictions = torch.argmax(network_outputs["outputs"], dim=1)
            labels = network_outputs["labels"]

            batch_size = network_outputs["batch_size"]
            seq_len = network_outputs["num_frames"]

            acc = pt_util.to_numpy(predictions == labels)

            inputs = network_outputs["data"]
            inputs = to_uint8(inputs)
            im_height, im_width = inputs.shape[1:3]

            inputs = pt_util.split_dim(inputs, 0, batch_size, seq_len)

            rand_order = np.random.choice(len(inputs), min(len(inputs), seq_len), replace=False)

            scale_factor = im_width / 320.0
            images = []
            for bb in rand_order:
                correct = acc[bb]
                image_seq = inputs[bb].copy()
                pred_cls = self.ind_to_label_func(predictions[bb])
                gt_cls = self.ind_to_label_func(labels[bb])
                for ii, image in enumerate(image_seq):
                    # Green border = correct prediction, red = incorrect.
                    if correct:
                        image[:10, :, :] = (0, 255, 0)
                        image[-10:, :, :] = (0, 255, 0)
                        image[:, :10, :] = (0, 255, 0)
                        image[:, -10:, :] = (0, 255, 0)
                    else:
                        image[:10, :, :] = (255, 0, 0)
                        image[-10:, :, :] = (255, 0, 0)
                        image[:, :10, :] = (255, 0, 0)
                        image[:, -10:, :] = (255, 0, 0)
                    if ii == 0:
                        image = drawing.draw_contrast_text_cv2(
                            image, "P: " + pred_cls, (10, 10 + int(30 * scale_factor))
                        )
                        if not correct:
                            image = drawing.draw_contrast_text_cv2(
                                image, "GT: " + gt_cls, (10, 10 + int(2 * 30 * scale_factor))
                            )
                    images.append(image)

            n_cols = seq_len
            n_rows = len(images) // n_cols

            subplot = drawing.subplot(images, n_rows, n_cols, im_width, im_height)
            image_output["images/classifier_outputs"] = subplot
            return image_output
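
Two repo helpers appear above without definitions: to_uint8 and pt_util.split_dim. Their real implementations are repo-specific; the sketches below are assumptions reconstructed from how they are called (the (x + 0.5) * 255 arithmetic in Example #1 and split_dim(inputs, 0, batch_size, seq_len) above).

import numpy as np
import torch


def split_dim_sketch(tensor, dim, *sizes):
    # Assumed behavior: reshape dimension `dim` into the given sizes
    # (-1 allowed, as in torch.reshape), leaving the other dims intact.
    shape = list(tensor.shape)
    return tensor.reshape(shape[:dim] + list(sizes) + shape[dim + 1:])


def to_uint8_sketch(images, padding=0):
    # Assumed behavior: map a float (B, C, H, W) batch in roughly [-0.5, 0.5]
    # to uint8 (B, H, W, C) images, with an optional constant border.
    if isinstance(images, torch.Tensor):
        images = images.detach().cpu().numpy()
    images = np.clip((images + 0.5) * 255, 0, 255).astype(np.uint8).transpose(0, 2, 3, 1)
    if padding > 0:
        images = np.pad(images, ((0, 0), (padding, padding), (padding, padding), (0, 0)), "constant")
    return images

Example #3
draw_nns saves grids of nearest-neighbor retrievals: for each of NUM_QUERIES randomly chosen queries, the top NUM_NEIGHBORS matches by feature similarity are written out as one row image.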
def draw_nns(source_features, source_images, source_name, target_features=None, target_images=None, target_name=None):
    skip_first = False
    if target_features is None:
        target_features = source_features
        target_images = source_images
        target_name = source_name
        skip_first = True

    num_to_compare = target_features.shape[0]
    # Fix seeds so the same queries are chosen on every run.
    torch.manual_seed(0)
    random.seed(0)
    np.random.seed(0)
    rand_selection = np.sort(np.random.choice(source_features.shape[0], NUM_QUERIES, replace=False))

    query_features = source_features[rand_selection]

    # Dot-product similarity; with L2-normalized features this is cosine similarity.
    dists = torch.mm(query_features, target_features.T)
    val, neighbors = torch.topk(dists, k=(NUM_NEIGHBORS + int(skip_first)), dim=1, sorted=True, largest=True)
    if skip_first:
        # Drop the trivial self-match when a set is queried against itself.
        neighbors = neighbors[:, 1:]

    neighbors = target_images[pt_util.to_numpy(neighbors)]
    out_dir = os.path.join(args.checkpoint_dir, "neighbors_from_%s_to_%s" % (source_name, target_name))
    os.makedirs(out_dir, exist_ok=True)

    # Get images
    for ii in tqdm.tqdm(range(neighbors.shape[0])):
        images = []
        image = source_images[rand_selection[ii]].copy()
        image = np.pad(image, ((10, 10), (10, 10), (0, 0)), "constant")
        images.append(image)
        for jj in range(neighbors.shape[1]):
            image = neighbors[ii, jj].copy()
            images.append(image)

        subplot = drawing.subplot(images, 1, neighbors.shape[1] + 1, args.input_width, args.input_height, border=5)
        cv2.imwrite(
            os.path.join(out_dir, "bsize_%06d_%03d.jpg" % (num_to_compare, ii)),
            subplot[:, :, ::-1],  # RGB -> BGR for OpenCV
        )
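
A hedged usage sketch for draw_nns (NUM_QUERIES, NUM_NEIGHBORS, and args are module-level state in the original script; the feature and image variables below are placeholders):

# Features should be L2-normalized (N, D) torch tensors; images are aligned
# (N, H, W, 3) uint8 arrays.
draw_nns(train_features, train_images, "train")  # within-set; skips the self-match
draw_nns(train_features, train_images, "train", val_features, val_images, "val")  # cross-set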
Example #4
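get_image_output for a tracking model: it tiles exemplar/search-image pairs with their correlation response maps and target label maps into an 8x8 grid.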
    def get_image_output(self, network_outputs) -> Dict[str, np.ndarray]:
        image_output = {}
        exemplar_images = to_uint8(network_outputs["data"])
        track_images = to_uint8(network_outputs["track_data"])
        responses = pt_util.to_numpy(network_outputs["responses"].squeeze(1))
        labels = pt_util.to_numpy(self._create_labels(responses.shape))
        batch_size, _, im_height, im_width = network_outputs["track_data"].shape

        images = []
        for exemplar_image, track_image, response, label in zip(exemplar_images, track_images, responses, labels):
            images.extend([exemplar_image, track_image, response, label])
            if len(images) >= (4 * 2) ** 2:
                # Stop once the 8x8 grid below is full.
                break

        subplot = drawing.subplot(images, 4 * 2, 4 * 2, im_width, im_height)
        image_output["images/tracks"] = subplot

        return image_output
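
_create_labels is not shown here. For correlation-style trackers the target is typically a response-sized map that is positive near the center and zero elsewhere; the sketch below is an assumption in that spirit, not the repo's implementation (the real radius and any negative weighting may differ).

import numpy as np


def create_labels_sketch(shape, radius=2.0):
    # Hypothetical stand-in for self._create_labels: a centered disc of
    # positives on an otherwise-zero map, repeated across the batch.
    batch, height, width = shape
    yy, xx = np.mgrid[:height, :width]
    dist = np.hypot(yy - (height - 1) / 2.0, xx - (width - 1) / 2.0)
    label = (dist <= radius).astype(np.float32)
    return np.broadcast_to(label, (batch, height, width)).copy()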
Example #5
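get_image_output for the VINCE contrastive model: it renders paired batch/queue inputs, nearest-neighbor retrievals from the queue (border colors encode match type and data source), ImageNet classification results, and attention-mask overlays.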
    def get_image_output(self, network_outputs) -> Dict[str, np.ndarray]:
        with torch.no_grad():
            image_output = {}

            # matching image
            batch_size, _, im_height, im_width = network_outputs["data"].shape

            inputs = network_outputs["data"]
            queue_inputs = network_outputs["queue_data"]
            inputs = to_uint8(inputs, padding=10)
            queue_inputs = to_uint8(queue_inputs, padding=10)
            num_frames = 1 if self.num_frames is None else self.num_frames
            inputs = pt_util.split_dim(inputs, 0, -1, num_frames)
            queue_inputs = pt_util.split_dim(queue_inputs, 0, -1, num_frames)
            images = []
            color = (255, 128, 0)  # Orange border marks the paired queue frames.
            for bb in range(min(len(inputs), max(2 * num_frames, int(32 / num_frames)))):
                for ss in range(num_frames):
                    image = inputs[bb, ss]
                    images.append(image)
                for ss in range(num_frames):
                    image = queue_inputs[bb, ss].copy()
                    image[:10, :, :] = color
                    image[-10:, :, :] = color
                    image[:, :10, :] = color
                    image[:, -10:, :] = color
                    images.append(image)

            n_cols = max(2 * num_frames, 8)
            n_rows = len(images) // n_cols
            subplot = drawing.subplot(images, n_rows, n_cols, im_width, im_height)
            image_output["images/inputs"] = subplot

            if "vince_similarities" in network_outputs:
                # Nearest neighbor image
                inputs = network_outputs["data"]
                queue_inputs = network_outputs["queue_data"]

                inputs = to_uint8(inputs, padding=10)
                queue_inputs = to_uint8(queue_inputs, padding=10)

                vince_similarities = network_outputs["vince_similarities"]
                logits = vince_similarities / self.args.vince_temperature
                vince_softmax = F.softmax(logits, dim=1)

                queue_images = network_outputs["queue_images"]

                n_neighbors = 9
                topk_val, topk_ind = torch.topk(vince_softmax, n_neighbors, dim=1, largest=True, sorted=True)
                topk_ind = pt_util.to_numpy(topk_ind)
                topk_val = pt_util.to_numpy(topk_val)

                label = network_outputs["vince_similarities_mask"]

                images = []
                rand_order = np.random.choice(batch_size, min(batch_size, n_neighbors + 1), replace=False)
                for bb in rand_order:
                    query_image = inputs[bb].copy()
                    color = (90, 46, 158)
                    if network_outputs["batch_type"] == "images":
                        # Different colors for imagenet vs videos.
                        color = (24, 178, 24)
                    query_image[:10, :, :] = color
                    query_image[-10:, :, :] = color
                    query_image[:, :10, :] = color
                    query_image[:, -10:, :] = color
                    images.append(query_image)
                    found_neighbor = False
                    for nn, neighbor in enumerate(topk_ind[bb]):
                        color = (128, 128, 128)
                        score = topk_val[bb, nn]

                        if self.args.inter_batch_comparison:
                            if neighbor < batch_size:
                                image = queue_inputs[neighbor].copy()
                                data_source = network_outputs["data_source"]
                            else:
                                # Offset by batch_size for the inter-batch negatives
                                offset = batch_size
                                image = to_uint8(queue_images[neighbor - offset], padding=10)
                                data_source = network_outputs["queue_data_sources"][neighbor - offset]
                        else:
                            if neighbor == 0:
                                image = queue_inputs[bb].copy()
                                data_source = network_outputs["data_source"]
                            else:
                                # Offset by 1 for the positive examples
                                image = to_uint8(queue_images[neighbor - 1], padding=10)
                                data_source = network_outputs["queue_data_sources"][neighbor - 1]

                        if label[bb, neighbor]:
                            if self.args.inter_batch_comparison and neighbor < batch_size:
                                found_neighbor = True
                                color = (255, 128, 0)  # Orange: retrieved the true positive.
                            elif neighbor == 0:
                                found_neighbor = True
                                color = (255, 128, 0)  # Orange: retrieved the true positive.
                            elif data_source == "self":
                                color = (144, 72, 0)  # Dark orange: positive from the same source.
                            else:
                                color = (0, 0, 203)  # Blue: some other positive.
                        elif data_source == "self":
                            color = (255, 0, 193)  # Pink: same-source retrieval not labeled positive.
                        if not found_neighbor and nn == n_neighbors - 1:
                            # Last one in row, couldn't match proper, put in just to show what it looks like.
                            image = queue_inputs[bb].copy()
                            color = (255, 0, 0)  # Red: the positive was never retrieved.

                        if color == (128, 128, 128):
                            # No special case applied; color by data source instead.
                            color = (90, 46, 158)
                            if data_source == "IN":
                                # Different colors for imagenet vs videos.
                                color = (24, 178, 24)
                        image[:10, :, :] = color
                        image[-10:, :, :] = color
                        image[:, :10, :] = color
                        image[:, -10:, :] = color
                        images.append(image)

                n_rows = n_neighbors + 1
                n_cols = n_neighbors + 1
                subplot = drawing.subplot(images, n_rows, n_cols, im_width, im_height)
                image_output["images/outputs"] = subplot

            if network_outputs["data_source"] == "IN":
                # imagenet image
                predictions = torch.argmax(network_outputs["imagenet_decoder_0"], dim=1)
                labels = network_outputs["imagenet_labels"]
                acc = pt_util.to_numpy(predictions == labels)
                batch_size = acc.shape[0]

                inputs = network_outputs["data"][:batch_size]
                inputs = to_uint8(inputs, padding=10)

                images = []
                rand_order = np.random.choice(len(inputs), min(len(inputs), 25), replace=False)
                scale_factor = im_width / 320.0

                for bb in rand_order:
                    correct = acc[bb]
                    image = inputs[bb].copy()
                    pred_cls = util_functions.imagenet_label_to_class(predictions[bb])
                    gt_cls = util_functions.imagenet_label_to_class(labels[bb])
                    if correct:
                        cls_str = pred_cls
                    else:
                        cls_str = "Pred: %s Actual %s" % (pred_cls, gt_cls)

                    if correct:
                        image[:10, :, :] = (0, 255, 0)
                        image[-10:, :, :] = (0, 255, 0)
                        image[:, :10, :] = (0, 255, 0)
                        image[:, -10:, :] = (0, 255, 0)
                    else:
                        image[:10, :, :] = (255, 0, 0)
                        image[-10:, :, :] = (255, 0, 0)
                        image[:, :10, :] = (255, 0, 0)
                        image[:, -10:, :] = (255, 0, 0)
                    image = drawing.draw_contrast_text_cv2(image, "P: " + pred_cls, (10, 10 + int(30 * scale_factor)))
                    if not correct:
                        image = drawing.draw_contrast_text_cv2(
                            image, "GT: " + gt_cls, (10, 10 + int(2 * 30 * scale_factor))
                        )
                    images.append(image)

                n_cols = int(np.sqrt(len(images)))
                n_rows = len(images) // n_cols

                subplot = drawing.subplot(images, n_rows, n_cols, im_width, im_height)
                image_output["images/imagenet_outputs"] = subplot

            if "attention_masks" in network_outputs:
                # Attention image
                inputs = network_outputs["data"]
                inputs = to_uint8(inputs, padding=10)

                queue_inputs = network_outputs["queue_data"]
                queue_inputs = to_uint8(queue_inputs, padding=10)

                attention_masks = network_outputs["attention_masks"]
                attention_masks = pt_util.to_numpy(
                    F.interpolate(attention_masks, (im_height, im_width),
                                  mode="bilinear",
                                  align_corners=False).permute(0, 2, 3, 1))
                attention_masks = np.pad(attention_masks,
                                         ((0, 0), (10, 10), (10, 10), (0, 0)),
                                         "constant")

                queue_attention_masks = network_outputs[
                    "queue_attention_masks"]
                queue_attention_masks = pt_util.to_numpy(
                    F.interpolate(queue_attention_masks, (im_height, im_width),
                                  mode="bilinear",
                                  align_corners=False).permute(0, 2, 3, 1))
                queue_attention_masks = np.pad(queue_attention_masks,
                                               ((0, 0), (10, 10), (10, 10),
                                                (0, 0)), "constant")

                rand_order = np.random.choice(len(inputs), min(len(inputs), 25), replace=False)

                subplots = []
                attention_color = np.array([255, 0, 0], dtype=np.float32)
                for bb in rand_order:
                    images = []
                    for img_src, mask_src in ((inputs, attention_masks), (queue_inputs, queue_attention_masks)):
                        image = img_src[bb].copy()
                        attention_mask = mask_src[bb].copy()
                        # Normalize the mask to [0, 1], then alpha-blend it over the image in red.
                        attention_mask -= attention_mask.min()
                        attention_mask /= attention_mask.max() + 1e-8
                        output = attention_mask * attention_color + (1 - attention_mask) * image
                        output = output.astype(np.uint8)
                        images.append(image)
                        images.append(output)
                    subplot = drawing.subplot(images, 2, 2, im_width, im_height)
                    subplots.append(subplot)

                n_cols = int(np.sqrt(len(subplots)))
                n_rows = len(subplots) // n_cols

                subplot = drawing.subplot(subplots, n_rows, n_cols, im_width * 2, im_height * 2, border=5)
                image_output["images/attention"] = subplot

        return image_output
Example #6
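The main training loop: it rolls out parallel Habitat environments, optionally records annotated videos, tracks per-episode statistics for the pointnav/exploration/flee tasks, runs PPO or supervised updates, and periodically logs, checkpoints, and evaluates.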
    def train_model(self):
        episode_rewards = deque(maxlen=10)
        current_episode_rewards = np.zeros(self.shell_args.num_processes)
        episode_lengths = deque(maxlen=10)
        current_episode_lengths = np.zeros(self.shell_args.num_processes)
        current_rewards = np.zeros(self.shell_args.num_processes)

        total_num_steps = self.start_iter
        fps_timer = [time.time(), total_num_steps]
        timers = np.zeros(3)
        egomotion_loss = 0

        video_frames = []
        num_episodes = 0
        # self.evaluate_model()

        obs = self.envs.reset()
        if self.compute_surface_normals:
            obs["surface_normals"] = pt_util.depth_to_surface_normals(
                obs["depth"].to(self.device))
        obs["prev_action_one_hot"] = obs[
            "prev_action_one_hot"][:, ACTION_SPACE].to(torch.float32)
        if self.shell_args.algo == "supervised":
            obs["best_next_action"] = pt_util.from_numpy(
                obs["best_next_action"][:, ACTION_SPACE])
        self.rollouts.copy_obs(obs, 0)
        distances = pt_util.to_numpy_array(obs["goal_geodesic_distance"])
        self.train_stats["start_geodesic_distance"][:] = distances
        previous_visual_features = None
        egomotion_pred = None
        prev_action = None
        prev_action_probs = None
        # Total optimizer updates: env steps divided by the steps collected per
        # update (rollout length times number of parallel processes).
        num_updates = (
            int(self.shell_args.num_env_steps)
            // self.shell_args.num_forward_rollout_steps
            // self.shell_args.num_processes
        )

        try:
            for iter_count in range(num_updates):
                if self.shell_args.tensorboard:
                    if iter_count % 500 == 0:
                        print("Logging conv summaries")
                        self.logger.network_conv_summary(
                            self.agent, total_num_steps)
                    elif iter_count % 100 == 0:
                        print("Logging variable summaries")
                        self.logger.network_variable_summary(
                            self.agent, total_num_steps)

                if self.shell_args.use_linear_lr_decay:
                    # decrease learning rate linearly
                    update_linear_schedule(self.optimizer.optimizer, iter_count, num_updates, self.shell_args.lr)

                if self.shell_args.algo == "ppo" and self.shell_args.use_linear_clip_decay:
                    self.optimizer.clip_param = self.shell_args.clip_param * (
                        1 - iter_count / float(num_updates))

                if hasattr(self.agent.base, "enable_decoder"):
                    if self.shell_args.record_video:
                        self.agent.base.enable_decoder()
                    else:
                        self.agent.base.disable_decoder()

                for step in range(self.shell_args.num_forward_rollout_steps):
                    with torch.no_grad():
                        start_t = time.time()
                        value, action, action_log_prob, recurrent_hidden_states = self.agent.act(
                            {
                                "images": self.rollouts.obs[step],
                                "target_vector": self.rollouts.additional_observations_dict["pointgoal"][step],
                                "prev_action_one_hot": self.rollouts.additional_observations_dict["prev_action_one_hot"][step],
                            },
                            self.rollouts.recurrent_hidden_states[step],
                            self.rollouts.masks[step],
                        )
                        action_cpu = pt_util.to_numpy_array(action.squeeze(1))
                        translated_action_space = ACTION_SPACE[action_cpu]
                        if not self.shell_args.end_to_end:
                            self.rollouts.additional_observations_dict["visual_encoder_features"][
                                self.rollouts.step
                            ].copy_(self.agent.base.visual_encoder_features)

                        if self.shell_args.use_motion_loss:
                            if self.shell_args.record_video:
                                if previous_visual_features is not None:
                                    egomotion_pred = self.agent.base.predict_egomotion(
                                        self.agent.base.visual_features, previous_visual_features
                                    )
                            previous_visual_features = self.agent.base.visual_features.detach()

                        timers[1] += time.time() - start_t

                        if self.shell_args.record_video:
                            # Copy so we don't mess with obs itself
                            draw_obs = OrderedDict()
                            for key, val in obs.items():
                                draw_obs[key] = pt_util.to_numpy_array(val).copy()
                            best_next_action = draw_obs.pop("best_next_action", None)

                            if prev_action is not None:
                                draw_obs["action_taken"] = pt_util.to_numpy_array(self.agent.last_dist.probs).copy()
                                draw_obs["action_taken"][:] = 0
                                draw_obs["action_taken"][np.arange(self.shell_args.num_processes), prev_action] = 1
                                draw_obs["action_taken_name"] = SIM_ACTION_TO_NAME[
                                    ACTION_SPACE_TO_SIM_ACTION[ACTION_SPACE[prev_action.squeeze()]]
                                ]
                                draw_obs["action_prob"] = pt_util.to_numpy_array(prev_action_probs).copy()
                            else:
                                draw_obs["action_taken"] = None
                                draw_obs["action_taken_name"] = SIM_ACTION_TO_NAME[SimulatorActions.STOP]
                                draw_obs["action_prob"] = None
                            prev_action = action_cpu
                            prev_action_probs = self.agent.last_dist.probs.detach()
                            if hasattr(self.agent.base, "decoder_outputs") and self.agent.base.decoder_outputs is not None:
                                min_channel = 0
                                for key, num_channels in self.agent.base.decoder_output_info:
                                    outputs = self.agent.base.decoder_outputs[:, min_channel : min_channel + num_channels, ...]
                                    draw_obs["output_" + key] = pt_util.to_numpy_array(outputs).copy()
                                    min_channel += num_channels
                            draw_obs["rewards"] = current_rewards.copy()
                            draw_obs["step"] = current_episode_lengths.copy()
                            draw_obs["method"] = self.shell_args.method_name
                            if best_next_action is not None:
                                draw_obs["best_next_action"] = best_next_action
                            if self.shell_args.use_motion_loss:
                                if egomotion_pred is not None:
                                    draw_obs["egomotion_pred"] = pt_util.to_numpy_array(F.softmax(egomotion_pred, dim=1)).copy()
                                else:
                                    draw_obs["egomotion_pred"] = None
                            images, titles, normalize = draw_outputs.obs_to_images(draw_obs)
                            if self.shell_args.algo == "supervised":
                                im_inds = [0, 2, 3, 1, 9, 6, 7, 8, 5, 4]
                            else:
                                im_inds = [0, 2, 3, 1, 6, 7, 8, 5]
                            height, width = images[0].shape[:2]
                            subplot_image = drawing.subplot(
                                images,
                                2,
                                5,
                                titles=titles,
                                normalize=normalize,
                                order=im_inds,
                                output_width=max(width, 320),
                                output_height=max(height, 320),
                            )
                            video_frames.append(subplot_image)

                        # save dists from previous step or else on reset they will be overwritten
                        distances = pt_util.to_numpy_array(obs["goal_geodesic_distance"])

                        start_t = time.time()
                        obs, rewards, dones, infos = self.envs.step(translated_action_space)
                        timers[0] += time.time() - start_t
                        obs["reward"] = rewards
                        if self.shell_args.algo == "supervised":
                            obs["best_next_action"] = pt_util.from_numpy(
                                obs["best_next_action"][:, ACTION_SPACE]
                            ).to(torch.float32)
                        obs["prev_action_one_hot"] = obs["prev_action_one_hot"][:, ACTION_SPACE].to(torch.float32)
                        rewards *= REWARD_SCALAR
                        rewards = np.clip(rewards, -10, 10)

                        if self.shell_args.record_video and not dones[0]:
                            obs["top_down_map"] = infos[0]["top_down_map"]

                        if self.compute_surface_normals:
                            obs["surface_normals"] = pt_util.depth_to_surface_normals(obs["depth"].to(self.device))

                        current_rewards = pt_util.to_numpy_array(rewards)
                        current_episode_rewards += pt_util.to_numpy_array(rewards).squeeze()
                        current_episode_lengths += 1
                        for ii, done_e in enumerate(dones):
                            if done_e:
                                num_episodes += 1
                                if self.shell_args.record_video:
                                    final_rgb = draw_obs["rgb"].transpose(
                                        0, 2, 3, 1).squeeze(0)
                                    if self.shell_args.task == "pointnav":
                                        if infos[ii]["spl"] > 0:
                                            draw_obs[
                                                "action_taken_name"] = "Stop. Success"
                                            draw_obs["reward"] = [
                                                self.configs[0].TASK.
                                                SUCCESS_REWARD
                                            ]
                                            final_rgb[:] = final_rgb * np.float32(
                                                0.5) + np.tile(
                                                    np.array([0, 128, 0],
                                                             dtype=np.uint8),
                                                    (final_rgb.shape[0],
                                                     final_rgb.shape[1], 1),
                                                )
                                        else:
                                            draw_obs[
                                                "action_taken_name"] = "Timeout. Failed"
                                            final_rgb[:] = final_rgb * np.float32(
                                                0.5) + np.tile(
                                                    np.array([128, 0, 0],
                                                             dtype=np.uint8),
                                                    (final_rgb.shape[0],
                                                     final_rgb.shape[1], 1),
                                                )
                                    elif self.shell_args.task == "exploration" or self.shell_args.task == "flee":
                                        draw_obs[
                                            "action_taken_name"] = "End of episode."
                                    final_rgb = final_rgb[np.newaxis,
                                                          ...].transpose(
                                                              0, 3, 1, 2)
                                    draw_obs["rgb"] = final_rgb

                                    images, titles, normalize = draw_outputs.obs_to_images(draw_obs)
                                    im_inds = [0, 2, 3, 1, 6, 7, 8, 5]
                                    height, width = images[0].shape[:2]
                                    subplot_image = drawing.subplot(
                                        images,
                                        2,
                                        5,
                                        titles=titles,
                                        normalize=normalize,
                                        order=im_inds,
                                        output_width=max(width, 320),
                                        output_height=max(height, 320),
                                    )
                                    video_frames.extend(
                                        [subplot_image]
                                        * (self.configs[0].ENVIRONMENT.MAX_EPISODE_STEPS + 30 - len(video_frames))
                                    )

                                    if "top_down_map" in infos[0]:
                                        video_dir = os.path.join(
                                            self.shell_args.log_prefix,
                                            "videos")
                                        if not os.path.exists(video_dir):
                                            os.makedirs(video_dir)
                                        im_path = os.path.join(
                                            self.shell_args.log_prefix,
                                            "videos", "total_steps_%d.png" %
                                            total_num_steps)
                                        from habitat.utils.visualizations import maps
                                        import imageio

                                        top_down_map = maps.colorize_topdown_map(
                                            infos[0]["top_down_map"]["map"])
                                        imageio.imsave(im_path, top_down_map)

                                    images_to_video(
                                        video_frames,
                                        os.path.join(
                                            self.shell_args.log_prefix,
                                            "videos"),
                                        "total_steps_%d" % total_num_steps,
                                    )
                                    video_frames = []

                                if self.shell_args.task == "pointnav":
                                    print(
                                        "FINISHED EPISODE %d Length %d Reward %.3f SPL %.4f"
                                        % (
                                            num_episodes,
                                            current_episode_lengths[ii],
                                            current_episode_rewards[ii],
                                            infos[ii]["spl"],
                                        ))
                                    self.train_stats["spl"][ii] = infos[ii][
                                        "spl"]
                                    self.train_stats["success"][
                                        ii] = self.train_stats["spl"][ii] > 0
                                    self.train_stats["end_geodesic_distance"][
                                        ii] = (distances[ii] - self.configs[0].
                                               SIMULATOR.FORWARD_STEP_SIZE)
                                    self.train_stats[
                                        "delta_geodesic_distance"][ii] = (
                                            self.train_stats[
                                                "start_geodesic_distance"][ii]
                                            - self.train_stats[
                                                "end_geodesic_distance"][ii])
                                    self.train_stats["num_steps"][
                                        ii] = current_episode_lengths[ii]
                                elif self.shell_args.task == "exploration":
                                    print(
                                        "FINISHED EPISODE %d Reward %.3f States Visited %d"
                                        % (num_episodes,
                                           current_episode_rewards[ii],
                                           infos[ii]["visited_states"]))
                                    self.train_stats["visited_states"][
                                        ii] = infos[ii]["visited_states"]
                                elif self.shell_args.task == "flee":
                                    print(
                                        "FINISHED EPISODE %d Reward %.3f Distance from start %.4f"
                                        % (num_episodes,
                                           current_episode_rewards[ii],
                                           infos[ii]["distance_from_start"]))
                                    self.train_stats["distance_from_start"][
                                        ii] = infos[ii]["distance_from_start"]

                                self.train_stats["num_episodes"][ii] += 1
                                self.train_stats["reward"][
                                    ii] = current_episode_rewards[ii]

                                if self.shell_args.tensorboard:
                                    log_dict = {
                                        "single_episode/reward":
                                        self.train_stats["reward"][ii]
                                    }
                                    if self.shell_args.task == "pointnav":
                                        log_dict.update({
                                            "single_episode/num_steps":
                                            self.train_stats["num_steps"][ii],
                                            "single_episode/spl":
                                            self.train_stats["spl"][ii],
                                            "single_episode/success":
                                            self.train_stats["success"][ii],
                                            "single_episode/start_geodesic_distance":
                                            self.train_stats[
                                                "start_geodesic_distance"][ii],
                                            "single_episode/end_geodesic_distance":
                                            self.train_stats[
                                                "end_geodesic_distance"][ii],
                                            "single_episode/delta_geodesic_distance":
                                            self.train_stats[
                                                "delta_geodesic_distance"][ii],
                                        })
                                    elif self.shell_args.task == "exploration":
                                        log_dict[
                                            "single_episode/visited_states"] = self.train_stats[
                                                "visited_states"][ii]
                                    elif self.shell_args.task == "flee":
                                        log_dict[
                                            "single_episode/distance_from_start"] = self.train_stats[
                                                "distance_from_start"][ii]
                                    self.logger.dict_log(
                                        log_dict,
                                        step=(total_num_steps +
                                              self.shell_args.num_processes *
                                              step + ii))

                                episode_rewards.append(
                                    current_episode_rewards[ii])
                                current_episode_rewards[ii] = 0
                                episode_lengths.append(
                                    current_episode_lengths[ii])
                                current_episode_lengths[ii] = 0
                                self.train_stats["start_geodesic_distance"][
                                    ii] = obs["goal_geodesic_distance"][ii]

                        # If done then clean the history of observations.
                        masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in dones])
                        bad_masks = torch.FloatTensor(
                            [[0.0] if "bad_transition" in info.keys() else [1.0] for info in infos]
                        )

                        self.rollouts.insert(
                            obs, recurrent_hidden_states, action, action_log_prob, value, rewards, masks, bad_masks
                        )

                with torch.no_grad():
                    start_t = time.time()
                    next_value = self.agent.get_value(
                        {
                            "images": self.rollouts.obs[-1],
                            "target_vector": self.rollouts.additional_observations_dict["pointgoal"][-1],
                            "prev_action_one_hot": self.rollouts.additional_observations_dict["prev_action_one_hot"][-1],
                        },
                        self.rollouts.recurrent_hidden_states[-1],
                        self.rollouts.masks[-1],
                    ).detach()
                    timers[1] += time.time() - start_t

                self.rollouts.compute_returns(
                    next_value, self.shell_args.use_gae, self.shell_args.gamma, self.shell_args.tau
                )

                if not self.shell_args.no_weight_update:
                    start_t = time.time()
                    if self.shell_args.algo == "supervised":
                        (
                            total_loss,
                            action_loss,
                            visual_loss_total,
                            visual_loss_dict,
                            egomotion_loss,
                            forward_model_loss,
                        ) = self.optimizer.update(self.rollouts, self.shell_args)
                    else:
                        (
                            total_loss,
                            value_loss,
                            action_loss,
                            dist_entropy,
                            visual_loss_total,
                            visual_loss_dict,
                            egomotion_loss,
                            forward_model_loss,
                        ) = self.optimizer.update(self.rollouts, self.shell_args)

                    timers[2] += time.time() - start_t

                self.rollouts.after_update()

                # save for every interval-th episode or for the last epoch
                if iter_count % self.shell_args.save_interval == 0 or iter_count == num_updates - 1:
                    self.save_checkpoint(5, total_num_steps)

                total_num_steps += self.shell_args.num_processes * self.shell_args.num_forward_rollout_steps

                if not self.shell_args.no_weight_update and iter_count % self.shell_args.log_interval == 0:
                    log_dict = {}
                    if len(episode_rewards) > 1:
                        end = time.time()
                        nsteps = total_num_steps - fps_timer[1]
                        fps = int(nsteps / (end - fps_timer[0]))
                        timers /= nsteps
                        env_spf = timers[0]
                        forward_spf = timers[1]
                        backward_spf = timers[2]
                        print((
                            "{} Updates {}, num timesteps {}, FPS {}, Env FPS "
                            "{}, \n Last {} training episodes: mean/median reward "
                            "{:.3f}/{:.3f}, min/max reward {:.3f}/{:.3f}\n"
                        ).format(
                            datetime.datetime.now(),
                            iter_count,
                            total_num_steps,
                            fps,
                            int(1.0 / env_spf),
                            len(episode_rewards),
                            np.mean(episode_rewards),
                            np.median(episode_rewards),
                            np.min(episode_rewards),
                            np.max(episode_rewards),
                        ))

                        if self.shell_args.tensorboard:
                            log_dict.update({
                                "stats/full_spf": 1.0 / (fps + 1e-10),
                                "stats/env_spf": env_spf,
                                "stats/forward_spf": forward_spf,
                                "stats/backward_spf": backward_spf,
                                "stats/full_fps": fps,
                                "stats/env_fps": 1.0 / (env_spf + 1e-10),
                                "stats/forward_fps": 1.0 / (forward_spf + 1e-10),
                                "stats/backward_fps": 1.0 / (backward_spf + 1e-10),
                                "episode/mean_rewards": np.mean(episode_rewards),
                                "episode/median_rewards": np.median(episode_rewards),
                                "episode/min_rewards": np.min(episode_rewards),
                                "episode/max_rewards": np.max(episode_rewards),
                                "episode/mean_lengths": np.mean(episode_lengths),
                                "episode/median_lengths": np.median(episode_lengths),
                                "episode/min_lengths": np.min(episode_lengths),
                                "episode/max_lengths": np.max(episode_lengths),
                            })
                        fps_timer[0] = time.time()
                        fps_timer[1] = total_num_steps
                        timers[:] = 0
                    if self.shell_args.tensorboard:
                        log_dict.update({
                            "loss/action": action_loss,
                            "loss/0_total": total_loss,
                            "loss/visual/0_total": visual_loss_total,
                            "loss/exploration/egomotion": egomotion_loss,
                            "loss/exploration/forward_model": forward_model_loss,
                        })
                        if self.shell_args.algo != "supervised":
                            log_dict.update({"loss/entropy": dist_entropy, "loss/value": value_loss})
                        for key, val in visual_loss_dict.items():
                            log_dict["loss/visual/" + key] = val
                        self.logger.dict_log(log_dict, step=total_num_steps)
                        self.logger.dict_log(log_dict, step=total_num_steps)

                if self.shell_args.eval_interval is not None and total_num_steps % self.shell_args.eval_interval < (
                    self.shell_args.num_processes * self.shell_args.num_forward_rollout_steps
                ):
                    self.save_checkpoint(-1, total_num_steps)
                    self.set_log_iter(total_num_steps)
                    self.evaluate_model()
                    # reset the env datasets
                    self.envs.unwrapped.call(
                        ["switch_dataset"] * self.shell_args.num_processes,
                        [("train",)] * self.shell_args.num_processes,
                    )
                    obs = self.envs.reset()
                    if self.compute_surface_normals:
                        obs["surface_normals"] = pt_util.depth_to_surface_normals(obs["depth"].to(self.device))
                    obs["prev_action_one_hot"] = obs["prev_action_one_hot"][:, ACTION_SPACE].to(torch.float32)
                    if self.shell_args.algo == "supervised":
                        obs["best_next_action"] = pt_util.from_numpy(obs["best_next_action"][:, ACTION_SPACE])
                    self.rollouts.copy_obs(obs, 0)
                    distances = pt_util.to_numpy_array(obs["goal_geodesic_distance"])
                    self.train_stats["start_geodesic_distance"][:] = distances
                    previous_visual_features = None
                    egomotion_pred = None
                    prev_action = None
                    prev_action_probs = None
        except:
            # Catch all exceptions so a final save can be performed
            import traceback

            traceback.print_exc()
        finally:
            self.save_checkpoint(-1, total_num_steps)
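
Note the broad except/finally pairing at the end: any exception during training is printed rather than propagated, and the finally block guarantees one last checkpoint is written either way.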
Example #7
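The validation loop: it runs each val loader for at most five minutes, accumulates rolling and per-epoch loss/metric meters, periodically logs images and scalars, then embeds the CIFAR set for a k-NN probe.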
    def run_val(self):
        with torch.no_grad():
            self.model.eval()
            time_meters = dict(
                total_time=RollingAverageMeter(self.args.log_frequency),
                data_cache_time=RollingAverageMeter(self.args.log_frequency),
                forward_time=RollingAverageMeter(self.args.log_frequency),
                metrics_time=RollingAverageMeter(self.args.log_frequency),
            )
            # loss(None) / get_metrics(None) apparently just enumerate the available keys.
            loss_meters = {
                key: RollingAverageMeter(self.args.log_frequency)
                for key in self.model.loss(None).keys()
            }
            if len(loss_meters) > 1:
                loss_meters["total_loss"] = RollingAverageMeter(
                    self.args.log_frequency)
            metric_meters = {
                metric: RollingAverageMeter(self.args.log_frequency)
                for metric in self.model.get_metrics(None).keys()
            }

            epoch_loss_meters = {
                "epoch_" + key: AverageMeter()
                for key in loss_meters.keys()
            }
            epoch_metric_meters = {
                "epoch_" + key: AverageMeter()
                for key in metric_meters.keys()
            }

            updated_epoch_loss_meters = set()
            updated_epoch_metric_meters = set()

            step_on = self.iteration

            for val_name, val_loader, data_processor in zip(
                    self.val_data_names, self.val_loaders, self.val_batch_fns):
                print("Running val for", val_name)
                total_t_start = time.time()
                test_t_start = time.time()
                for ii, image_batch in enumerate(tqdm.tqdm(val_loader)):
                    if time.time() - test_t_start > 5 * 60:
                        # Break after 5 minutes.
                        break
                    image_batch = data_processor(image_batch)
                    image_batch["batch_types"] = [image_batch["batch_type"]]
                    del image_batch["batch_type"]
                    image_batch["batch_sizes"] = [image_batch["batch_size"]]
                    del image_batch["batch_size"]
                    image_batch = {
                        key: (val.to(self.model.device, non_blocking=True)
                              if isinstance(val, torch.Tensor) else val)
                        for key, val in image_batch.items()
                    }

                    batch_size = image_batch["data"].shape[0]

                    t_end = time.time()
                    time_meters["data_cache_time"].update(t_end -
                                                          total_t_start)
                    t_start = time.time()

                    image_batch.update(self.queue_model(image_batch)[0])
                    image_batch.update(self.model.get_embeddings(image_batch)[0])
                    image_batch.update(self.vince_queue.dequeue())

                    image_batch.update(self.model(image_batch))
                    output = image_batch
                    loss_dict = self.model.loss(output)

                    t_end = time.time()
                    time_meters["forward_time"].update(t_end - t_start)
                    t_start = time.time()

                    metrics = self.model.get_metrics(output)
                    if ii % self.args.image_log_frequency == 0:
                        image_output = self.model.get_image_output(output)

                    updated_loss_meters = set()
                    total_loss = 0
                    for key, val in loss_dict.items():
                        weighted_loss = val[0] * val[1]
                        total_loss = total_loss + weighted_loss
                        loss_meters[key].update(weighted_loss)
                        epoch_loss_meters["epoch_" + key].update(
                            weighted_loss, batch_size)
                        updated_loss_meters.add(key)
                        updated_epoch_loss_meters.add("epoch_" + key)
                    if "total_loss" in loss_meters:
                        loss_meters["total_loss"].update(total_loss)
                        epoch_loss_meters["epoch_total_loss"].update(
                            total_loss, batch_size)
                        updated_loss_meters.add("total_loss")
                        updated_epoch_loss_meters.add("epoch_total_loss")
                    loss = total_loss

                    if not torch.isfinite(loss):
                        print("Nan loss", loss_dict)

                    updated_metric_meters = set()
                    for key, val in metrics.items():
                        metric_meters[key].update(val)
                        updated_metric_meters.add(key)
                        epoch_metric_meters["epoch_" + key].update(
                            val, batch_size)
                        updated_epoch_metric_meters.add("epoch_" + key)

                    t_end = time.time()
                    time_meters["metrics_time"].update(t_end - t_start)

                    if ii % self.args.image_log_frequency == 0:
                        if self.val_logger is not None:
                            for key, val in image_output.items():
                                if isinstance(val, list):
                                    for vv, item in enumerate(val):
                                        self.val_logger.image_summary(
                                            self.full_name + "_" +
                                            key[len("images/"):], item,
                                            step_on + vv, False)
                                else:
                                    self.val_logger.image_summary(
                                        self.full_name + "_" +
                                        key[len("images/"):], val, step_on,
                                        False)

                    if ii % self.args.log_frequency == 0:
                        log_dict = {
                            "times/%s/%s" % (self.full_name, key): val.val
                            for key, val in time_meters.items()
                        }
                        log_dict.update({
                            "losses/%s/%s" % (self.full_name, key):
                            loss_meters[key].val
                            for key in updated_loss_meters
                        })
                        log_dict.update({
                            "metrics/%s/%s" % (self.full_name, key):
                            metric_meters[key].val
                            for key in updated_metric_meters
                        })
                        if self.val_logger is not None:
                            self.val_logger.dict_log(log_dict, step_on)

                    step_on += self.args.batch_size
                    total_t_end = time.time()
                    time_meters["total_time"].update(total_t_end -
                                                     total_t_start)
                    total_t_start = time.time()

            ##### CIFAR #####
            epoch_metric_meters["epoch_knn_cifar"] = AverageMeter()

            all_features = []
            imagenet_mean = pt_util.from_numpy(constants.IMAGENET_MEAN).to(
                self.model.device).view(1, -1, 1, 1)
            imagenet_std = pt_util.from_numpy(constants.IMAGENET_STD).to(
                self.model.device).view(1, -1, 1, 1)

            print("Running CIFAR")
            for start_ind in tqdm.tqdm(
                    range(0, len(self.cifar_dataset), self.args.batch_size)):
                data = self.cifar_dataset.data[
                    start_ind:min(len(self.cifar_dataset), start_ind +
                                  self.args.batch_size)]
                data = data.to(device=self.model.device, dtype=torch.float32)
                data = data - imagenet_mean
                data.div_(imagenet_std)
                features = self.model.get_embeddings({"data":
                                                      data})["embeddings"]
                all_features.append(pt_util.to_numpy(features))
            all_images = np.transpose(
                pt_util.to_numpy(self.cifar_dataset.data), (0, 2, 3, 1))
            labels = pt_util.to_numpy(self.cifar_dataset.labels)
            all_features = np.concatenate(all_features, axis=0)
            if len(all_features.shape) == 4:
                # all_features = pt_util.remove_dim(all_features, dim=(2, 3))
                all_features = np.mean(all_features, axis=(2, 3))

            if self.val_logger is not None:
                kdt = KDTree(all_features, leaf_size=40, metric="euclidean")
                neighbors = kdt.query(all_features, k=11)[1]
                # remove self match
                neighbors = neighbors[:, 1:]
                preds_all = labels[neighbors]
                preds = scipy.stats.mode(preds_all, axis=1)[0].squeeze(1)
                acc = np.mean(preds == labels)
                epoch_metric_meters["epoch_knn_cifar"].update(acc)
                updated_epoch_metric_meters.add("epoch_knn_cifar")

                nn_inds = kdt.query(all_features[0:100:10], k=10)[1]
                image = drawing.subplot(all_images[nn_inds.reshape(-1)],
                                        10,
                                        10,
                                        self.args.input_width,
                                        self.args.input_height,
                                        border=10)

                self.val_logger.image_summary(self.full_name + "_kNN/cifar",
                                              image,
                                              step_on,
                                              increment_counter=False,
                                              max_size=1000)

        log_dict = {
            "epoch/losses/%s/%s" % (self.full_name, key):
            epoch_loss_meters[key].avg
            for key in updated_epoch_loss_meters
        }
        log_dict.update({
            "epoch/metrics/%s/%s" % (self.full_name, key):
            epoch_metric_meters[key].avg
            for key in updated_epoch_metric_meters
        })
        if self.val_logger is not None:
            self.val_logger.dict_log(log_dict, step_on)
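The CIFAR kNN probe above is a self-contained recipe: build a KDTree over the embeddings, query k + 1 neighbors so the self-match can be dropped, and majority-vote over the neighbors' labels. A minimal sketch of that evaluation, assuming sklearn and the pre-1.11 scipy.stats.mode signature; knn_accuracy is an illustrative name, not part of the codebase:

import numpy as np
import scipy.stats
from sklearn.neighbors import KDTree

def knn_accuracy(features, labels, k=10):
    # Leave-one-out kNN accuracy over an (N, D) feature matrix.
    kdt = KDTree(features, leaf_size=40, metric="euclidean")
    # Query k + 1 neighbors; the nearest match is the query point itself.
    neighbors = kdt.query(features, k=k + 1)[1][:, 1:]
    # Majority vote over neighbor labels (pre-1.11 scipy.stats.mode signature).
    preds = scipy.stats.mode(labels[neighbors], axis=1)[0].squeeze(1)
    return float(np.mean(preds == labels))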
Example #8
    def update(self, img):
        # set to evaluation mode
        self.net.eval()

        # search images
        x = [
            ops.crop_and_resize(img,
                                self.center,
                                self.x_sz * f,
                                out_size=self.cfg["instance_sz"],
                                border_value=self.avg_color)
            for f in self.scale_factors
        ]
        if self.visualize:
            search_images = drawing.subplot(x, 1, 3, x[0].shape[1],
                                            x[0].shape[0], 5)
            cv2.imshow("search_image", search_images[:, :, ::-1])
            cv2.waitKey(1)
        x = np.stack(x, axis=0)
        x = self.image_to_torch(x)

        # responses
        x = self.net.extract_features(x)
        responses = self.net.head(self.kernel, x)
        # responses = torch.sigmoid(responses)
        # responses = responses.squeeze(1).cpu().numpy()

        # upsample responses and penalize scale changes
        responses = F.interpolate(responses,
                                  size=(self.upscale_sz, self.upscale_sz),
                                  mode="bicubic",
                                  align_corners=False)
        responses = pt_util.to_numpy(responses.squeeze(1))
        if self.visualize:
            response_image = drawing.subplot(-responses, 1, 3,
                                             responses.shape[2],
                                             responses.shape[1], 5)
            cv2.imshow("response image", response_image)

        responses[:self.cfg["scale_num"] // 2] *= self.cfg["scale_penalty"]
        responses[self.cfg["scale_num"] // 2 + 1:] *= self.cfg["scale_penalty"]

        # peak scale
        scale_id = np.argmax(np.amax(responses, axis=(1, 2)))

        # peak location
        response = responses[scale_id]
        response -= response.min()
        response /= response.sum() + 1e-16
        response = (1 - self.cfg["window_influence"]) * response + self.cfg[
            "window_influence"] * self.hann_window
        loc = np.unravel_index(response.argmax(), response.shape)
        if self.visualize:
            loc_result = response == response.max()
            loc_result = loc_result.astype(np.uint8) * 255
            loc_result = cv2.dilate(loc_result,
                                    np.ones((9, 9), dtype=loc_result.dtype),
                                    iterations=1)
            loc_result = np.tile(loc_result[..., np.newaxis], (1, 1, 1, 3))
            cv2.imshow(
                "response max",
                drawing.subplot([-response, loc_result], 1, 2,
                                loc_result.shape[2], loc_result.shape[1], 5),
            )

        # locate target center
        disp_in_response = np.array(loc) - (self.upscale_sz - 1) / 2
        disp_in_instance = disp_in_response * self.cfg[
            "total_stride"] * 1.0 / self.cfg["response_up"]
        disp_in_image = disp_in_instance * self.x_sz * self.scale_factors[
            scale_id] / self.cfg["instance_sz"]
        # disp_in_image = disp_in_response * self.x_sz * self.scale_factors[scale_id] / self.upscale_sz
        self.center += disp_in_image
        if self.visualize:
            print(
                "loc",
                loc,
                "original center change",
                disp_in_response,
                "center change",
                disp_in_image,
                "new center",
                self.center,
            )

        # update target size
        scale = (1 - self.cfg["scale_lr"]
                 ) * 1.0 + self.cfg["scale_lr"] * self.scale_factors[scale_id]
        self.target_sz *= scale
        self.z_sz *= scale
        self.x_sz *= scale

        # return 1-indexed and left-top based bounding box
        box = np.array([
            self.center[1] + 1 - (self.target_sz[1] - 1) / 2,
            self.center[0] + 1 - (self.target_sz[0] - 1) / 2,
            self.target_sz[1],
            self.target_sz[0],
        ])
        if self.visualize:
            cv2.waitKey(0)

        return box
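The coordinate bookkeeping in update() is the easiest part to get wrong: the peak location is mapped from the upsampled response map back through the network stride to the instance crop, and from the crop to full-image pixels. The sketch below redoes just that chain in isolation; the parameter names mirror the cfg keys above, but it is an illustration, not the tracker's API:

import numpy as np

def peak_to_image_disp(loc, upscale_sz, total_stride, response_up, x_sz, scale_factor, instance_sz):
    # Peak offset from the center of the upsampled response map.
    disp_in_response = np.array(loc, dtype=np.float64) - (upscale_sz - 1) / 2
    # Undo the response upsampling, then account for the network stride.
    disp_in_instance = disp_in_response * total_stride / response_up
    # Rescale from the instance crop to full-image pixels.
    return disp_in_instance * x_sz * scale_factor / instance_sz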
Example #9
    def evaluate_model(self):
        self.envs.unwrapped.call(["switch_dataset"] *
                                 self.shell_args.num_processes,
                                 [("val", )] * self.shell_args.num_processes)

        if not os.path.exists(self.eval_dir):
            os.makedirs(self.eval_dir)
        try:
            eval_net_file_name = sorted(
                glob.glob(
                    os.path.join(self.shell_args.log_prefix,
                                 self.shell_args.checkpoint_dirname, "*") +
                    "/*.pt"),
                key=os.path.getmtime,
            )[-1]
            eval_net_file_name = (
                self.shell_args.log_prefix.replace(os.sep, "_") + "_" +
                "_".join(eval_net_file_name.split(os.sep)[-2:])[:-3])
        except IndexError:
            print("Warning, no weights found")
            eval_net_file_name = "random_weights"
        eval_output_file = open(
            os.path.join(self.eval_dir, eval_net_file_name + ".csv"), "w")
        print("Writing results to", eval_output_file.name)

        # Save the evaled net for posterity
        if self.shell_args.save_checkpoints:
            save_model = self.agent
            pt_util.save(
                save_model,
                os.path.join(self.shell_args.log_prefix,
                             self.shell_args.checkpoint_dirname,
                             "eval_weights"),
                num_to_keep=-1,
                iteration=self.log_iter,
            )
            print("Wrote model to file for safe keeping")

        obs = self.envs.reset()
        if self.compute_surface_normals:
            obs["surface_normals"] = pt_util.depth_to_surface_normals(
                obs["depth"].to(self.device))
        obs["prev_action_one_hot"] = obs[
            "prev_action_one_hot"][:, ACTION_SPACE].to(torch.float32)
        recurrent_hidden_states = torch.zeros(
            self.shell_args.num_processes,
            self.agent.recurrent_hidden_state_size,
            dtype=torch.float32,
            device=self.device,
        )
        masks = torch.ones(self.shell_args.num_processes,
                           1,
                           dtype=torch.float32,
                           device=self.device)

        episode_rewards = deque(maxlen=10)
        current_episode_rewards = np.zeros(self.shell_args.num_processes)
        episode_lengths = deque(maxlen=10)
        current_episode_lengths = np.zeros(self.shell_args.num_processes)

        total_num_steps = self.log_iter
        fps_timer = [time.time(), total_num_steps]
        timers = np.zeros(3)

        num_episodes = 0

        print("Config\n", self.configs[0])

        # Initialize every time eval is run rather than just at the start
        dataset_sizes = np.array(
            [len(dataset.episodes) for dataset in self.eval_datasets])

        eval_stats = dict(
            episode_ids=[None for _ in range(self.shell_args.num_processes)],
            num_episodes=np.zeros(self.shell_args.num_processes,
                                  dtype=np.int32),
            num_steps=np.zeros(self.shell_args.num_processes, dtype=np.int32),
            reward=np.zeros(self.shell_args.num_processes, dtype=np.float32),
            spl=np.zeros(self.shell_args.num_processes, dtype=np.float32),
            visited_states=np.zeros(self.shell_args.num_processes,
                                    dtype=np.int32),
            success=np.zeros(self.shell_args.num_processes, dtype=np.int32),
            end_geodesic_distance=np.zeros(self.shell_args.num_processes,
                                           dtype=np.float32),
            start_geodesic_distance=np.zeros(self.shell_args.num_processes,
                                             dtype=np.float32),
            delta_geodesic_distance=np.zeros(self.shell_args.num_processes,
                                             dtype=np.float32),
            distance_from_start=np.zeros(self.shell_args.num_processes,
                                         dtype=np.float32),
        )
        eval_stats_means = dict(
            num_episodes=0,
            num_steps=0,
            reward=0,
            spl=0,
            visited_states=0,
            success=0,
            end_geodesic_distance=0,
            start_geodesic_distance=0,
            delta_geodesic_distance=0,
            distance_from_start=0,
        )
        eval_output_file.write("name,%s,iter,%d\n\n" %
                               (eval_net_file_name, self.log_iter))
        if self.shell_args.task == "pointnav":
            eval_output_file.write((
                "episode_id,num_steps,reward,spl,success,start_geodesic_distance,"
                "end_geodesic_distance,delta_geodesic_distance\n"))
        elif self.shell_args.task == "exploration":
            eval_output_file.write("episode_id,reward,visited_states\n")
        elif self.shell_args.task == "flee":
            eval_output_file.write("episode_id,reward,distance_from_start\n")
        distances = pt_util.to_numpy(obs["goal_geodesic_distance"])
        eval_stats["start_geodesic_distance"][:] = distances
        progress_bar = tqdm.tqdm(total=self.num_eval_episodes_total)
        all_done = False
        iter_count = 0
        video_frames = []
        previous_visual_features = None
        egomotion_pred = None
        prev_action = None
        prev_action_probs = None
        if hasattr(self.agent.base, "enable_decoder"):
            if self.shell_args.record_video:
                self.agent.base.enable_decoder()
            else:
                self.agent.base.disable_decoder()
        while not all_done:
            with torch.no_grad():
                start_t = time.time()
                value, action, action_log_prob, recurrent_hidden_states = self.agent.act(
                    {
                        "images":
                        obs["rgb"].to(self.device),
                        "target_vector":
                        obs["pointgoal"].to(self.device),
                        "prev_action_one_hot":
                        obs["prev_action_one_hot"].to(self.device),
                    },
                    recurrent_hidden_states,
                    masks,
                )
                action_cpu = pt_util.to_numpy(action.squeeze(1))
                translated_action_space = ACTION_SPACE[action_cpu]

                timers[1] += time.time() - start_t

                if self.shell_args.record_video:
                    if self.shell_args.use_motion_loss:
                        if previous_visual_features is not None:
                            egomotion_pred = self.agent.base.predict_egomotion(
                                self.agent.base.visual_features,
                                previous_visual_features)
                        previous_visual_features = self.agent.base.visual_features.detach()

                    # Copy so we don't mess with obs itself
                    draw_obs = OrderedDict()
                    for key, val in obs.items():
                        draw_obs[key] = pt_util.to_numpy(val).copy()
                    best_next_action = draw_obs.pop("best_next_action", None)

                    if prev_action is not None:
                        draw_obs["action_taken"] = pt_util.to_numpy(
                            self.agent.last_dist.probs).copy()
                        draw_obs["action_taken"][:] = 0
                        draw_obs["action_taken"][
                            np.arange(self.shell_args.num_processes),
                            prev_action] = 1
                        draw_obs["action_taken_name"] = SIM_ACTION_TO_NAME[
                            draw_obs['prev_action'].item()]
                        draw_obs["action_prob"] = pt_util.to_numpy(
                            prev_action_probs).copy()
                    else:
                        draw_obs["action_taken"] = None
                        draw_obs["action_taken_name"] = SIM_ACTION_TO_NAME[
                            SimulatorActions.STOP]
                        draw_obs["action_prob"] = None
                    prev_action = action_cpu
                    prev_action_probs = self.agent.last_dist.probs.detach()
                    if hasattr(
                            self.agent.base, "decoder_outputs"
                    ) and self.agent.base.decoder_outputs is not None:
                        min_channel = 0
                        for key, num_channels in self.agent.base.decoder_output_info:
                            outputs = self.agent.base.decoder_outputs[
                                :, min_channel:min_channel + num_channels, ...]
                            draw_obs["output_" +
                                     key] = pt_util.to_numpy(outputs).copy()
                            min_channel += num_channels
                    draw_obs["rewards"] = eval_stats["reward"]
                    draw_obs["step"] = current_episode_lengths.copy()
                    draw_obs["method"] = self.shell_args.method_name
                    if best_next_action is not None:
                        draw_obs["best_next_action"] = best_next_action
                    if self.shell_args.use_motion_loss:
                        if egomotion_pred is not None:
                            draw_obs["egomotion_pred"] = pt_util.to_numpy(
                                F.softmax(egomotion_pred, dim=1)).copy()
                        else:
                            draw_obs["egomotion_pred"] = None
                    images, titles, normalize = draw_outputs.obs_to_images(
                        draw_obs)
                    im_inds = [0, 2, 3, 1, 6, 7, 8, 5]
                    height, width = images[0].shape[:2]
                    subplot_image = drawing.subplot(
                        images,
                        2,
                        4,
                        titles=titles,
                        normalize=normalize,
                        output_width=max(width, 320),
                        output_height=max(height, 320),
                        order=im_inds,
                        fancy_text=True,
                    )
                    video_frames.append(subplot_image)

                # save dists from previous step or else on reset they will be overwritten
                distances = pt_util.to_numpy(obs["goal_geodesic_distance"])

                start_t = time.time()
                obs, rewards, dones, infos = self.envs.step(
                    translated_action_space)
                timers[0] += time.time() - start_t
                obs["prev_action_one_hot"] = obs[
                    "prev_action_one_hot"][:, ACTION_SPACE].to(torch.float32)
                rewards *= REWARD_SCALAR
                rewards = np.clip(rewards, -10, 10)

                if self.shell_args.record_video and not dones[0]:
                    obs["top_down_map"] = infos[0]["top_down_map"]

                if self.compute_surface_normals:
                    obs["surface_normals"] = pt_util.depth_to_surface_normals(
                        obs["depth"].to(self.device))

                current_episode_rewards += pt_util.to_numpy(rewards).squeeze()
                current_episode_lengths += 1
                to_pause = []
                for ii, done_e in enumerate(dones):
                    if done_e:
                        num_episodes += 1

                        if self.shell_args.record_video:
                            if "top_down_map" in infos[ii]:
                                video_dir = os.path.join(
                                    self.shell_args.log_prefix, "videos")
                                if not os.path.exists(video_dir):
                                    os.makedirs(video_dir)
                                im_path = os.path.join(
                                    self.shell_args.log_prefix, "videos",
                                    "total_steps_%d.png" % total_num_steps)
                                top_down_map = maps.colorize_topdown_map(
                                    infos[ii]["top_down_map"]["map"])
                                imageio.imsave(im_path, top_down_map)

                            images_to_video(
                                video_frames,
                                os.path.join(self.shell_args.log_prefix,
                                             "videos"),
                                "total_steps_%d" % total_num_steps,
                            )
                            video_frames = []

                        eval_stats["episode_ids"][ii] = infos[ii]["episode_id"]

                        if self.shell_args.task == "pointnav":
                            print(
                                "FINISHED EPISODE %d Length %d Reward %.3f SPL %.4f"
                                % (
                                    num_episodes,
                                    current_episode_lengths[ii],
                                    current_episode_rewards[ii],
                                    infos[ii]["spl"],
                                ))
                            eval_stats["spl"][ii] = infos[ii]["spl"]
                            eval_stats["success"][
                                ii] = eval_stats["spl"][ii] > 0
                            eval_stats["num_steps"][
                                ii] = current_episode_lengths[ii]
                            eval_stats["end_geodesic_distance"][ii] = (
                                infos[ii]["final_distance"] if
                                eval_stats["success"][ii] else distances[ii])
                            eval_stats["delta_geodesic_distance"][ii] = (
                                eval_stats["start_geodesic_distance"][ii] -
                                eval_stats["end_geodesic_distance"][ii])
                        elif self.shell_args.task == "exploration":
                            print(
                                "FINISHED EPISODE %d Reward %.3f States Visited %d"
                                % (num_episodes, current_episode_rewards[ii],
                                   infos[ii]["visited_states"]))
                            eval_stats["visited_states"][ii] = infos[ii][
                                "visited_states"]
                        elif self.shell_args.task == "flee":
                            print(
                                "FINISHED EPISODE %d Reward %.3f Distance from start %.4f"
                                % (num_episodes, current_episode_rewards[ii],
                                   infos[ii]["distance_from_start"]))
                            eval_stats["distance_from_start"][ii] = infos[ii][
                                "distance_from_start"]

                        eval_stats["num_episodes"][ii] += 1
                        eval_stats["reward"][ii] = current_episode_rewards[ii]

                        if eval_stats["num_episodes"][ii] <= dataset_sizes[ii]:
                            progress_bar.update(1)
                            eval_stats_means["num_episodes"] += 1
                            eval_stats_means["reward"] += eval_stats["reward"][
                                ii]
                            if self.shell_args.task == "pointnav":
                                eval_output_file.write(
                                    "%s,%d,%f,%f,%d,%f,%f,%f\n" % (
                                        eval_stats["episode_ids"][ii],
                                        eval_stats["num_steps"][ii],
                                        eval_stats["reward"][ii],
                                        eval_stats["spl"][ii],
                                        eval_stats["success"][ii],
                                        eval_stats["start_geodesic_distance"]
                                        [ii],
                                        eval_stats["end_geodesic_distance"]
                                        [ii],
                                        eval_stats["delta_geodesic_distance"]
                                        [ii],
                                    ))
                                eval_stats_means["num_steps"] += eval_stats[
                                    "num_steps"][ii]
                                eval_stats_means["spl"] += eval_stats["spl"][
                                    ii]
                                eval_stats_means["success"] += eval_stats[
                                    "success"][ii]
                                eval_stats_means[
                                    "start_geodesic_distance"] += eval_stats[
                                        "start_geodesic_distance"][ii]
                                eval_stats_means[
                                    "end_geodesic_distance"] += eval_stats[
                                        "end_geodesic_distance"][ii]
                                eval_stats_means[
                                    "delta_geodesic_distance"] += eval_stats[
                                        "delta_geodesic_distance"][ii]
                            elif self.shell_args.task == "exploration":
                                eval_output_file.write("%s,%f,%d\n" % (
                                    eval_stats["episode_ids"][ii],
                                    eval_stats["reward"][ii],
                                    eval_stats["visited_states"][ii],
                                ))
                                eval_stats_means[
                                    "visited_states"] += eval_stats[
                                        "visited_states"][ii]
                            elif self.shell_args.task == "flee":
                                eval_output_file.write("%s,%f,%f\n" % (
                                    eval_stats["episode_ids"][ii],
                                    eval_stats["reward"][ii],
                                    eval_stats["distance_from_start"][ii],
                                ))
                                eval_stats_means[
                                    "distance_from_start"] += eval_stats[
                                        "distance_from_start"][ii]
                            eval_output_file.flush()
                            if eval_stats["num_episodes"][ii] == dataset_sizes[
                                    ii]:
                                to_pause.append(ii)

                        episode_rewards.append(current_episode_rewards[ii])
                        current_episode_rewards[ii] = 0
                        episode_lengths.append(current_episode_lengths[ii])
                        current_episode_lengths[ii] = 0
                        eval_stats["start_geodesic_distance"][ii] = obs[
                            "goal_geodesic_distance"][ii]

                # If done then clean the history of observations.
                masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                           for done_ in dones]).to(self.device)

                # Reverse in order to maintain order in case of multiple.
                to_pause.reverse()
                for ii in to_pause:
                    # Pause the environments that are done from the vectorenv.
                    print("Pausing env", ii)
                    self.envs.unwrapped.pause_at(ii)
                    current_episode_rewards = np.concatenate(
                        (current_episode_rewards[:ii],
                         current_episode_rewards[ii + 1:]))
                    current_episode_lengths = np.concatenate(
                        (current_episode_lengths[:ii],
                         current_episode_lengths[ii + 1:]))
                    for key in eval_stats:
                        eval_stats[key] = np.concatenate(
                            (eval_stats[key][:ii], eval_stats[key][ii + 1:]))
                    dataset_sizes = np.concatenate(
                        (dataset_sizes[:ii], dataset_sizes[ii + 1:]))

                    for key in obs:
                        if type(obs[key]) == torch.Tensor:
                            obs[key] = torch.cat(
                                (obs[key][:ii], obs[key][ii + 1:]), dim=0)
                        else:
                            obs[key] = np.concatenate(
                                (obs[key][:ii], obs[key][ii + 1:]), axis=0)

                    recurrent_hidden_states = torch.cat(
                        (recurrent_hidden_states[:ii],
                         recurrent_hidden_states[ii + 1:]),
                        dim=0)
                    masks = torch.cat((masks[:ii], masks[ii + 1:]), dim=0)

                if len(dataset_sizes) == 0:
                    progress_bar.close()
                    all_done = True

            total_num_steps += self.shell_args.num_processes

            if iter_count % (self.shell_args.log_interval * 100) == 0:
                log_dict = {}
                if len(episode_rewards) > 1:
                    end = time.time()
                    nsteps = total_num_steps - fps_timer[1]
                    fps = int(nsteps / (end - fps_timer[0]))
                    timers /= nsteps
                    env_spf = timers[0]
                    forward_spf = timers[1]
                    print((
                        "{} Updates {}, num timesteps {}, FPS {}, Env FPS {}, "
                        "\n Last {} training episodes: mean/median reward {:.3f}/{:.3f}, "
                        "min/max reward {:.3f}/{:.3f}\n").format(
                            datetime.datetime.now(),
                            iter_count,
                            total_num_steps,
                            fps,
                            int(1.0 / env_spf),
                            len(episode_rewards),
                            np.mean(episode_rewards),
                            np.median(episode_rewards),
                            np.min(episode_rewards),
                            np.max(episode_rewards),
                        ))

                    if self.shell_args.tensorboard:
                        log_dict.update({
                            "stats/full_spf":
                            1.0 / (fps + 1e-10),
                            "stats/env_spf":
                            env_spf,
                            "stats/forward_spf":
                            forward_spf,
                            "stats/full_fps":
                            fps,
                            "stats/env_fps":
                            1.0 / (env_spf + 1e-10),
                            "stats/forward_fps":
                            1.0 / (forward_spf + 1e-10),
                            "episode/mean_rewards":
                            np.mean(episode_rewards),
                            "episode/median_rewards":
                            np.median(episode_rewards),
                            "episode/min_rewards":
                            np.min(episode_rewards),
                            "episode/max_rewards":
                            np.max(episode_rewards),
                            "episode/mean_lengths":
                            np.mean(episode_lengths),
                            "episode/median_lengths":
                            np.median(episode_lengths),
                            "episode/min_lengths":
                            np.min(episode_lengths),
                            "episode/max_lengths":
                            np.max(episode_lengths),
                        })
                        self.eval_logger.dict_log(log_dict, step=self.log_iter)
                    fps_timer[0] = time.time()
                    fps_timer[1] = total_num_steps
                    timers[:] = 0
            iter_count += 1
        print("Finished testing")
        print("Wrote results to", eval_output_file.name)

        # Guard against dividing by zero when no episodes finished.
        num_finished = max(eval_stats_means["num_episodes"], 1)
        eval_stats_means = {
            key: val / num_finished
            for key, val in eval_stats_means.items()
        }
        if self.shell_args.tensorboard:
            log_dict = {"single_episode/reward": eval_stats_means["reward"]}
            if self.shell_args.task == "pointnav":
                log_dict.update({
                    "single_episode/num_steps":
                    eval_stats_means["num_steps"],
                    "single_episode/spl":
                    eval_stats_means["spl"],
                    "single_episode/success":
                    eval_stats_means["success"],
                    "single_episode/start_geodesic_distance":
                    eval_stats_means["start_geodesic_distance"],
                    "single_episode/end_geodesic_distance":
                    eval_stats_means["end_geodesic_distance"],
                    "single_episode/delta_geodesic_distance":
                    eval_stats_means["delta_geodesic_distance"],
                })
            elif self.shell_args.task == "exploration":
                log_dict["single_episode/visited_states"] = eval_stats_means[
                    "visited_states"]
            elif self.shell_args.task == "flee":
                log_dict[
                    "single_episode/distance_from_start"] = eval_stats_means[
                        "distance_from_start"]
            self.eval_logger.dict_log(log_dict, step=self.log_iter)
        self.envs.unwrapped.resume_all()
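When an environment finishes, every per-process structure above drops row ii so indices stay aligned across the obs tensors, the numpy stat arrays, the recurrent hidden states, and the masks. A small helper capturing that pattern (illustrative; the code above inlines it):

import numpy as np
import torch

def drop_row(arr, ii):
    # Remove row ii from a per-process tensor or array without reordering the rest.
    if isinstance(arr, torch.Tensor):
        return torch.cat((arr[:ii], arr[ii + 1:]), dim=0)
    return np.concatenate((arr[:ii], arr[ii + 1:]), axis=0)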
Example #10
    def get_image_output(self, network_outputs):
        with torch.no_grad():
            image_output = {}
            predictions = torch.argmax(network_outputs["classifier_output_0"],
                                       dim=1)
            labels = network_outputs["classifier_labels"]
            acc = pt_util.to_numpy(predictions == labels)
            batch_size = acc.shape[0]

            if "attention_masks" in network_outputs:
                inputs = network_outputs["data"]
                im_height, im_width = inputs.shape[2:]
                inputs = to_uint8(inputs)

                attention_masks = network_outputs["attention_masks"]
                attention_masks = pt_util.to_numpy(
                    F.interpolate(attention_masks, (im_height, im_width),
                                  mode="bilinear",
                                  align_corners=False).permute(0, 2, 3, 1))
                attention_masks = np.pad(attention_masks,
                                         ((0, 0), (10, 10), (10, 10), (0, 0)),
                                         "constant")

                rand_order = np.random.choice(len(inputs),
                                              min(len(inputs), 50),
                                              replace=False)

                images = []
                attention_color = np.array([255, 0, 0], dtype=np.float32)
                for bb in rand_order:
                    image = inputs[bb].copy()
                    attention_mask = attention_masks[bb].copy()
                    attention_mask -= attention_mask.min()
                    attention_mask /= attention_mask.max() + 1e-8
                    output = (attention_mask *
                              attention_color) + (1 - attention_mask) * image
                    output = output.astype(np.uint8)
                    images.append(image)
                    images.append(output)

                n_cols = int(np.sqrt(len(images)))
                n_rows = len(images) // n_cols

                subplot = drawing.subplot(images,
                                          n_rows,
                                          n_cols,
                                          im_width * 2,
                                          im_height * 2,
                                          border=5)
                image_output["images/attention"] = subplot

            images = []
            inputs = network_outputs["data"][:batch_size]
            inputs = to_uint8(inputs)
            im_height, im_width = inputs.shape[1:3]
            rand_order = np.random.choice(len(inputs),
                                          min(len(inputs), 25),
                                          replace=False)
            scale_factor = im_width / 320.0
            for bb in rand_order:
                correct = acc[bb]
                image = inputs[bb].copy()
                pred_cls = self.ind_to_label_func(predictions[bb])
                gt_cls = self.ind_to_label_func(labels[bb])
                if correct:
                    image[:10, :, :] = (0, 255, 0)
                    image[-10:, :, :] = (0, 255, 0)
                    image[:, :10, :] = (0, 255, 0)
                    image[:, -10:, :] = (0, 255, 0)
                else:
                    image[:10, :, :] = (255, 0, 0)
                    image[-10:, :, :] = (255, 0, 0)
                    image[:, :10, :] = (255, 0, 0)
                    image[:, -10:, :] = (255, 0, 0)
                image = drawing.draw_contrast_text_cv2(
                    image, "P: " + pred_cls, (10, 10 + int(30 * scale_factor)))
                if not correct:
                    image = drawing.draw_contrast_text_cv2(
                        image, "GT: " + gt_cls,
                        (10, 10 + int(2 * 30 * scale_factor)))
                images.append(image)

            n_cols = int(np.sqrt(len(images)))
            n_rows = len(images) // n_cols

            subplot = drawing.subplot(images, n_rows, n_cols, im_width,
                                      im_height)
            image_output["images/classifier_outputs"] = subplot
            return image_output
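The attention visualization above boils down to normalizing a mask to [0, 1] and alpha-blending a solid color over the frame. A self-contained sketch of that blend; overlay_attention is an illustrative helper, not part of the codebase:

import numpy as np

def overlay_attention(image, mask, color=(255, 0, 0)):
    # Normalize the mask to [0, 1]; the epsilon guards flat (all-equal) masks.
    mask = mask.astype(np.float32)
    mask -= mask.min()
    mask /= mask.max() + 1e-8
    if mask.ndim == 2:
        mask = mask[..., np.newaxis]  # broadcast over the color channel
    color = np.asarray(color, dtype=np.float32)
    blended = mask * color + (1 - mask) * image.astype(np.float32)
    return blended.astype(np.uint8)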
Example #11
    collate_fn=R2V2Dataset.collate_fn,
)

all_images = []

for data in tqdm.tqdm(data_loader, total=NUM_IMAGES_PER_ROW ** 2 // args.num_frames):
    images = to_uint8(data["data"])
    for image in images:
        all_images.append(image)
    if len(all_images) >= NUM_IMAGES_PER_ROW ** 2:
        break

del data_loader

mosaic = drawing.subplot(
    all_images, NUM_IMAGES_PER_ROW, NUM_IMAGES_PER_ROW, args.input_width, args.input_height, border=5
)
cv2.imwrite("mosaic_%s.jpg" % args.title, mosaic[:, :, ::-1])
print("done with mosaic")

print("starting TSNE")
dataset = R2V2Dataset(args, "val", transform=StandardVideoTransform(args.input_size, "val"), num_images_to_return=1)

data_loader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.num_workers,
    pin_memory=True,
    collate_fn=R2V2Dataset.collate_fn,
)
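The t-SNE step itself is cut off in this example. A typical continuation gathers embeddings from the loader and projects them to 2-D; the sketch below assumes sklearn's TSNE and is a guess at the missing code rather than a reproduction of it:

from sklearn.manifold import TSNE

def tsne_embed(features, perplexity=30.0):
    # Project an (N, D) feature matrix to 2-D for plotting.
    return TSNE(n_components=2, init="pca", perplexity=perplexity).fit_transform(features)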
Example #12
def get_shots(
    frames: Union[np.ndarray, torch.Tensor, List[np.ndarray]], return_inds=False
) -> Union[
    List[np.ndarray],
    Tuple[List[np.ndarray], np.ndarray],
    List[List[np.ndarray]],
    Tuple[List[List[np.ndarray]], np.ndarray],
]:
    if len(frames) < 2:
        # Honor return_inds even in the degenerate single-shot case.
        if return_inds:
            return [frames], np.array([0, len(frames)])
        return [frames]
    last_image = frames[0]
    all_edges = None
    all_edges_inverted = None
    if isinstance(frames, np.ndarray):
        all_edges, all_edges_inverted = get_edges(frames)
    elif isinstance(frames, torch.Tensor):
        all_edges, all_edges_inverted = get_edges_pt(frames)
    if all_edges is not None:
        ecrs = batch_ECR(all_edges[:-1], all_edges_inverted[:-1], all_edges[1:], all_edges_inverted[1:], crop=False)
        threshold = np.percentile(ecrs[1:], 5)
        # Unset the really short shots (probably false alarms) by raising
        # their scores back above the cut threshold.
        shots = np.where(ecrs < threshold)[0]
        shot_lengths = shots[1:] - shots[:-1]
        ecrs[shots[1:][shot_lengths < MIN_SHOT_LENGTH]] = threshold + 1e-5
        if DEBUG:
            print("threshold", threshold)
    else:
        threshold = 0.6

    if all_edges is None:
        prev_edges, prev_edges_inverted = get_edges(last_image)
    else:
        prev_edges = all_edges[0]
        prev_edges_inverted = all_edges_inverted[0]
    shot_borders = [0]
    start_t = time.time()

    for ff in range(1, len(frames) - 1):
        curr_frame = frames[ff]
        if all_edges is None:
            new_edges, new_edges_inverted = get_edges(curr_frame)
            ecr = ECR(prev_edges, prev_edges_inverted, new_edges, new_edges_inverted, crop=False)
        else:
            new_edges = all_edges[ff]
            new_edges_inverted = all_edges_inverted[ff]
            ecr = ecrs[ff]

        if DEBUG:
            if isinstance(frames, torch.Tensor):
                last_image_draw = pt_util.to_numpy(last_image).transpose(1, 2, 0)
                curr_frame_draw = pt_util.to_numpy(curr_frame).transpose(1, 2, 0)
            else:
                last_image_draw = last_image
                curr_frame_draw = curr_frame
            intersection = ~prev_edges_inverted & ~new_edges_inverted
            union = prev_edges | new_edges
            print("intersection", intersection.sum())
            print("union", union.sum())
            print("iou", intersection.sum() * 1.0 / union.sum())
            print("ecr", ecr)
            images = [
                last_image_draw,
                curr_frame_draw,
                prev_edges,
                new_edges,
                intersection,
                union,
                prev_edges & new_edges_inverted,
                new_edges & prev_edges_inverted,
            ]

            titles = [
                "last image", "curr image", "last edges", "curr edges",
                "intersection", "union", "last minus curr", "curr minus last",
            ]
            image = drawing.subplot(images, 4, 2, last_image_draw.shape[1], last_image_draw.shape[0], titles=titles)
            cv2.imshow("image", image[:, :, ::-1])
            cv2.waitKey(0)

        if ecr < threshold:
            # Call this a shot change.
            last_image = curr_frame
            if DEBUG:
                cv2.waitKey(0)
                pdb.set_trace()
            shot_borders.append(ff)
        prev_edges = new_edges
        prev_edges_inverted = new_edges_inverted

    shot_borders.append(len(frames))
    shot_borders = np.array(shot_borders)
    shots = []
    for ii in range(len(shot_borders) - 1):
        shots.append(frames[shot_borders[ii] : shot_borders[ii + 1]])
    if return_inds:
        return shots, shot_borders
    else:
        return shots
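The DEBUG block above prints the edge IoU between consecutive frames. The sketch below isolates that overlap measure; it is an illustration only, since the actual ECR and batch_ECR implementations are defined elsewhere in the codebase:

def edge_overlap(prev_edges, new_edges):
    # IoU between two boolean edge maps, as printed in the DEBUG block above.
    intersection = (prev_edges & new_edges).sum()
    union = (prev_edges | new_edges).sum()
    return float(intersection) / max(float(union), 1.0)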