Esempio n. 1
0
def evaluate_endpoint_resegmentation(filename, seg_volume,
                                     resegmentation_radius,
                                     threshold=0.5):
  """Scores an endpoint resegmentation against the base segmentation.

  The segment ID and the (x, y, z) start point are parsed from `filename`.
  The stored probability map is thresholded into a binary object, and the
  overlap of that object with every base segment inside the resegmentation
  subvolume is recorded in the result.

  Args:
    filename: path to the file containing resegmentation results
    seg_volume: volume object with the original segmentation
    resegmentation_radius: (z, y, x) radius of the resegmentation subvolume
    threshold: threshold at which to create objects from the predicted
        object map

  Returns:
    resegmentation_pb2.EndpointSegmentationResult proto

  Raises:
    InvalidBaseSegmentatonError: when no base segmentation object with the
        expected ID matches the resegmentation data
  """
  segment_id, _, x, y, z = parse_resegmentation_filename(filename)

  result = resegmentation_pb2.EndpointSegmentationResult()
  result.id = segment_id
  start = result.start
  start.x, start.y, start.z = x, y, z

  radius = result.segmentation_radius
  radius.z, radius.y, radius.x = resegmentation_radius

  with gfile.Open(filename, 'rb') as f:
    data = np.load(f)
    # NaNs indicate unvisited voxels; nan_to_num turns them into zeros.
    prob = np.nan_to_num(storage.dequantize_probability(data['probs']))

  # Window of the base segmentation centered on the start point.
  z_span = slice(z - radius.z, z + radius.z + 1)
  y_span = slice(y - radius.y, y + radius.y + 1)
  x_span = slice(x - radius.x, x + radius.x + 1)
  orig_seg = seg_volume[0, z_span, y_span, x_span][0, ...]

  if not np.any(orig_seg == segment_id):
    raise InvalidBaseSegmentatonError()

  new_seg = prob[0, ...] >= threshold
  result.num_voxels = int(np.sum(new_seg))

  overlaps = pywrapsegment_util.ComputeOverlapCounts(
      orig_seg.ravel(), new_seg.astype(np.uint64).ravel())
  for (old_id, new_id), count in overlaps.items():
    # Pairs whose new ID is 0 lie outside the thresholded object.
    if not new_id:
      continue

    entry = result.overlaps[old_id]
    entry.num_overlapping = count
    entry.num_original = int(np.sum(orig_seg == old_id))

    if old_id == segment_id:
      result.source.CopyFrom(entry)

  return result
Esempio n. 2
0
    def load(cls, path):
        """Build an instance from a JSON file of camera parameters.

        Paths containing '/cns/' are opened through `gfile.Open`; all other
        paths use the builtin `open`. Every top-level entry except
        'metadata' becomes one `Camera`, visited in sorted key order.

        Args:
          path: location of the JSON file.

        Returns:
          `cls(cams)`, where `cams` is the list of constructed cameras.
        """
        opener = gfile.Open if '/cns/' in path else open
        with opener(path, 'r') as f:
            data = json.load(f)

        cams = []
        for key in sorted(data.keys()):
            if key == 'metadata':
                continue

            entry = data[key]
            cam_name = entry['name']
            cam_size = entry['size']
            cam_matrix = np.array(entry['matrix'])
            cam_dist = np.array(entry['distortions'])
            cam_rvec = np.array(entry['rotation'])
            cam_tvec = np.array(entry['translation'])
            cams.append(
                Camera(cam_rvec, cam_tvec, cam_matrix, cam_dist, cam_name,
                       cam_size))

        return cls(cams)
Esempio n. 3
0
 def evaluate(self):
     """Return the mean of the per-class mean IoUs, or dump raw data.

     When `self.slave` is set, `self.iou_per_class` is pickled to
     `self.path` and nothing is returned. Otherwise classes with no
     recorded values are skipped and the remaining class means are
     averaged.
     """
     if not self.slave:
         class_means = [
             np.mean(values)
             for values in self.iou_per_class.values()
             if values
         ]
         return np.mean(class_means)

     payload = self.iou_per_class
     with gfile.Open(self.path, 'wb') as file:
         pickle.dump(payload, file)
     logging.info(file)
     return None
Esempio n. 4
0
def _CreateCsv(config_dict, tdb, output_root, output_file, file_type,
               batch_size):
    """Writes data from db to CSV based on file type.

    The file starts with the header line supplied by the metadata reader;
    the rows are then written by `tdb.ExportCsvData`.

    Args:
      config_dict: Config file dict.
      tdb: Instance of TrackingDB.
      output_root: Path to directory of output CSV.
      output_file: Name of output CSV file.
      file_type: Type of NetCDF file.
      batch_size: Number of rows written to CSV at once.
    """
    csv_path = os.path.join(output_root, output_file)
    metadata_reader = reader.NetCdfMetadataReader(config_dict, file_type)
    with gfile.Open(csv_path, 'w') as csv_file:
        _PrintStderr('Saving CSV file: %s', csv_path)
        header_line = metadata_reader.GetCsvHeader() + '\n'
        csv_file.write(header_line)
        tdb.ExportCsvData(csv_file, file_type, batch_size)
Esempio n. 5
0
 def evaluate(self):
     """Return the sum of recorded collisions, or dump raw data.

     When `self.slave` is set, the collisions, intersections and IoUs are
     pickled together to `self.path` and nothing is returned. Otherwise
     the sum over `self.collisions` is returned.
     """
     if not self.slave:
         return np.sum(self.collisions)

     payload = {
         'collisions': self.collisions,
         'intersections': self.intersections,
         'ious': self.ious
     }
     with gfile.Open(self.path, 'wb') as file:
         pickle.dump(payload, file)
     logging.info(file)
     return None
Esempio n. 6
0
def main(_):
    """Copy every listed ShapeNet mesh into a local per-split directory.

    Reads the split -> class -> models mapping from 'models_per_split.pkl',
    then runs `fileutil cp` on each model's 'model_normalized.obj',
    creating the destination directory first when it is missing. The mesh
    path and the exit status of each copy are printed.
    """
    models_data_filename = 'models_per_split.pkl'
    with gfile.Open(models_data_filename, 'rb') as pkl_file:
        models_data_per_class = pickle.load(pkl_file)

    path_prefix = '/datasets/shapenet/raw/'
    local_prefix = '/occluded_primitives/meshes/'

    for split in ['val', 'train', 'test']:
        models_by_class = models_data_per_class[split]
        for class_name in models_by_class:
            for model in models_by_class[class_name]:
                mesh_file = os.path.join(
                    path_prefix, class_name, model, 'models',
                    'model_normalized.obj')
                local_dir = os.path.join(
                    local_prefix, split, class_name, model)
                if not os.path.exists(local_dir):
                    os.makedirs(local_dir)
                print(mesh_file)
                exit_status = subprocess.call(
                    ['fileutil', 'cp', mesh_file, local_dir])
                print(exit_status)
Esempio n. 7
0
def evaluate_pair_resegmentation(filename, seg_volume,
                                 resegmentation_radius,
                                 analysis_radius,
                                 voxel_size,
                                 threshold=0.5):
  """Evaluates segment pair resegmentation.

  The two segment IDs and the (x, y, z) decision point are parsed from
  `filename`. Both the base segmentation and the stored probability maps
  are cropped to the analysis subvolume around that point before any
  statistics are computed.

  Args:
    filename: path to the file containing resegmentation results
    seg_volume: VolumeStore object with the original segmentation
    resegmentation_radius: (z, y, x) radius of the resegmentation subvolume
    analysis_radius: (z, y, x) radius of the subvolume in which to perform
        analysis
    voxel_size: (z, y, x) voxel size in physical units
    threshold: threshold at which to create objects from the predicted
        object map

  Returns:
    PairResegmentationResult proto

  Raises:
    IncompleteResegmentationError: when the resegmentation data does not
        represent two finished segments
    InvalidBaseSegmentatonError: when no base segmentation object with the
        expected ID matches the resegmentation data
  """
  id1, id2, x, y, z = parse_resegmentation_filename(filename)

  result = resegmentation_pb2.PairResegmentationResult()
  result.id_a, result.id_b = id1, id2
  p = result.point
  p.x, p.y, p.z = x, y, z

  sr = result.segmentation_radius
  sr.z, sr.y, sr.x = resegmentation_radius

  with gfile.Open(filename, 'rb') as f:
    data = np.load(f)
    prob = storage.dequantize_probability(data['probs'])
    prob = np.nan_to_num(prob)  # nans indicate unvisited voxels
    dels = data['deletes']
    moves = data['histories']  # z, y, x
    start_points = data['start_points']  # x, y, z

  if prob.shape[0] != 2:
    raise IncompleteResegmentationError()

  assert prob.ndim == 4

  # Corner of the resegmentation subvolume in the global coordinate system.
  corner = np.array([p.x - sr.x, p.y - sr.y, p.z - sr.z])

  # In case of multiple segmentation attempts, the last recorded start
  # point is the one we care about.
  # The builtin `int` replaces the `np.int` alias, which was deprecated in
  # NumPy 1.20 and removed in 1.24; the resulting dtype is the same.
  origin_a = np.array(start_points[0][-1], dtype=int) + corner
  origin_b = np.array(start_points[1][-1], dtype=int) + corner
  oa = result.eval.from_a.origin
  oa.x, oa.y, oa.z = origin_a
  ob = result.eval.from_b.origin
  ob.x, ob.y, ob.z = origin_b

  # Record basic information about the resegmentation run.
  analysis_r = np.array(analysis_radius)
  r = result.eval.radius
  r.z, r.y, r.x = analysis_r

  seg = seg_volume[0,
                   (z - analysis_r[0]):(z + analysis_r[0] + 1),
                   (y - analysis_r[1]):(y + analysis_r[1] + 1),
                   (x - analysis_r[2]):(x + analysis_r[2] + 1)][0, ...]
  seg1 = seg == id1
  seg2 = seg == id2
  result.eval.num_voxels_a = int(np.sum(seg1))
  result.eval.num_voxels_b = int(np.sum(seg2))

  if result.eval.num_voxels_a == 0 or result.eval.num_voxels_b == 0:
    raise InvalidBaseSegmentatonError()

  # Record information about the size of the original segments.
  result.eval.max_edt_a = float(
      ndimage.distance_transform_edt(seg1, sampling=voxel_size).max())
  result.eval.max_edt_b = float(
      ndimage.distance_transform_edt(seg2, sampling=voxel_size).max())

  # Offset of the analysis subvolume within the resegmentation subvolume.
  delta = np.array(resegmentation_radius) - analysis_r
  prob = prob[:,
              delta[0]:(delta[0] + 2 * analysis_r[0] + 1),
              delta[1]:(delta[1] + 2 * analysis_r[1] + 1),
              delta[2]:(delta[2] + 2 * analysis_r[2] + 1)]
  reseg = prob >= threshold
  result.eval.iou = compute_iou(reseg)

  # Record information about the size of the reconstructed segments.
  evaluate_segmentation_result(
      reseg[0, ...], dels[0], moves[0], delta, analysis_r, seg1, seg2,
      voxel_size, result.eval.from_a)
  evaluate_segmentation_result(
      reseg[1, ...], dels[1], moves[1], delta, analysis_r, seg1, seg2,
      voxel_size, result.eval.from_b)

  return result
Esempio n. 8
0
def load_pkl(path, keys=None, **kwargs):
    """Load AIST++ annotations from pkl file.

    The pickle must hold a dict with a non-empty 'pred_results' list of
    per-frame annotation dicts. Each field present in the first frame is
    stacked across frames into one array. Shapes quoted in the comments
    below are the ones stated by the original author.

    Args:
      path: pickle location. Paths containing '/cns/' are opened through
        `gfile.Open`, everything else with the builtin `open`.
      keys: optional collection of output field names to keep. A falsy
        value (None, empty) keeps every available field.
      **kwargs: accepted for call-site compatibility and ignored.

    Returns:
      dict mapping output field name to its value: numpy arrays, plus the
      raw 'smpl_loss' entry when present.

    Raises:
      AssertionError: if 'pred_results' is empty.
    """
    del kwargs  # Unused.

    # NOTE: pickle.load can execute arbitrary code; only use this loader
    # on annotation files from a trusted source.
    opener = gfile.Open if '/cns/' in path else open
    with opener(path, 'rb') as f:
        data = pickle.load(f)
    annos = data['pred_results']
    assert annos, f'data {path} has empty annotations'

    first = annos[0]

    def wanted(name):
        # A falsy `keys` means "no filtering".
        return name in keys if keys else True

    def stacked(anno_key, axis=0):
        return np.stack([anno[anno_key] for anno in annos], axis=axis)

    out = {}

    # smpl related
    if 'smpl_loss' in data and wanted('smpl_loss'):
        # a single float
        out['smpl_loss'] = data['smpl_loss']
    if 'smpl_joints' in first and wanted('smpl_joints'):
        # [nframes, 24, 3]
        out['smpl_joints'] = (
            stacked('smpl_joints')[:, :24, :].astype(np.float32))
    if 'smpl_pose' in first and wanted('smpl_poses'):
        # [nframes, 24, 3]
        out['smpl_poses'] = (
            stacked('smpl_pose').reshape(-1, 24, 3).astype(np.float32))
    if 'smpl_shape' in first and wanted('smpl_shape'):
        # [nframes, 10]
        out['smpl_shape'] = stacked('smpl_shape').astype(np.float32)
    if 'scaling' in first and wanted('smpl_scaling'):
        # [nframes, 1]
        out['smpl_scaling'] = stacked('scaling').astype(np.float32)
    if 'transl' in first and wanted('smpl_trans'):
        # [nframes, 3]
        out['smpl_trans'] = stacked('transl').astype(np.float32)
    if 'verts' in first and wanted('smpl_verts'):
        # [nframes, 6890, 3]
        out['smpl_verts'] = stacked('verts').astype(np.float32)

    # 2D and 3D keypoints
    # The filter name now matches the emitted key; it used to test
    # 'smpl_verts', so requesting 'keypoints2d' never returned it.
    if 'keypoints2d' in first and wanted('keypoints2d'):
        # [9, nframes, 17, 3]
        out['keypoints2d'] = (
            stacked('keypoints2d', axis=1).astype(np.float32))
    if 'keypoints3d' in first and wanted('keypoints3d'):
        # [nframes, 17, 3]
        out['keypoints3d'] = stacked('keypoints3d').astype(np.float32)
    if 'keypoints3d_optim' in first and wanted('keypoints3d_optim'):
        # [nframes, 17, 3]
        out['keypoints3d_optim'] = (
            stacked('keypoints3d_optim').astype(np.float32))

    # timestamps for each frame, in ms.
    if 'timestamp' in first and wanted('timestamps'):
        # [nframes,]
        out['timestamps'] = stacked('timestamp').astype(np.int32)

    # human detection score
    if 'det_scores' in first and wanted('det_scores'):
        # [9, nframes]
        # NOTE(review): the int32 cast truncates fractional scores; left
        # unchanged because callers may rely on the dtype -- confirm.
        out['det_scores'] = stacked('det_scores', axis=1).astype(np.int32)

    return out
Esempio n. 9
0
def obj_read_for_gl(filename, texture_size=(32, 32)):
    """Read per-face vertex, texture and normal data from an OBJ file.

    Only `v`, `vt`, `vn`, `usemtl` and `f` records are interpreted. Each
    `f` record contributes one triangle built from its first three vertex
    references; references may be written as `v`, `v/vt`, `v/vt/vn` or
    `v//vn`. A missing texture or normal reference is treated as index 0,
    which becomes -1 after the conversion to 0-based indices.

    Args:
      filename: path of the OBJ file, opened through `gfile.Open`.
      texture_size: unused; kept so existing call sites keep working.

    Returns:
      Tuple `(vertex_positions, tex_coords, normals, material_ids,
      vertices, faces)`:
        vertex_positions: float32 array [n_faces, 3, 3], the position of
          every face corner.
        tex_coords: float32 array [n_faces, 3, 2], first two texture
          components per corner; zeros when the file has no `vt` records.
        normals: float32 array [n_faces, 3, 3]; zeros when the file has no
          `vn` records.
        material_ids: list with the active `usemtl` name for each face
          (None before the first `usemtl`).
        vertices: array of all `v` records.
        faces: array of 0-based vertex indices per face.

    Raises:
      ValueError: if an `f` record has fewer than three vertex references
        or a record holds a non-numeric value.
    """
    del texture_size  # Unused.

    # Read everything up front so the file is not held open while parsing.
    with gfile.Open(filename, 'r') as f:
        content = f.readlines()

    vertices = []
    texture_coords = []
    vertex_normals = []

    material_name = None

    faces = []
    faces_tex = []
    faces_normals = []
    material_ids = []

    for line in content:
        # str.split() drops the empty strings that re.split(r'\s+', line)
        # leaves behind for the trailing newline; those made float('')
        # raise on ordinary two-component `vt u v` records.
        parts = line.split()
        if not parts:
            continue
        tag = parts[0]

        # Vertex information -------------------------------------------------
        if tag == 'v':
            vertices.append([float(v) for v in parts[1:4]])
        elif tag == 'vt':
            # Only the first two components are consumed further down.
            texture_coords.append([float(v) for v in parts[1:3]])
        elif tag == 'vn':
            vertex_normals.append([float(v) for v in parts[1:4]])
        elif tag == 'usemtl':
            material_name = parts[1] if len(parts) > 1 else ''

        # Face information ---------------------------------------------------
        elif tag == 'f':
            if len(parts) < 4:
                raise ValueError(
                    'face record needs three vertices: %r' % line)
            vertex_index, tex_index, normal_index = 0, 0, 0
            current_face, current_face_tex, current_face_norm = [], [], []
            for face_info in parts[1:4]:
                n_slashes = face_info.count('/')
                if n_slashes == 2:
                    vertex_index, tex_index, normal_index = (
                        face_info.split('/'))
                elif n_slashes == 1:
                    vertex_index, tex_index = face_info.split('/')
                elif n_slashes == 0:
                    vertex_index = face_info
                # OBJ indices are 1-based; an empty field counts as 0.
                current_face.append(int(vertex_index) - 1)
                current_face_tex.append(int(tex_index or 0) - 1)
                current_face_norm.append(int(normal_index or 0) - 1)
            faces.append(current_face)
            faces_tex.append(current_face_tex)
            faces_normals.append(current_face_norm)
            material_ids.append(material_name)

    vertices = np.array(vertices)
    texture_coords = np.array(texture_coords)
    vertex_normals = np.array(vertex_normals)
    has_tex_coord = texture_coords.shape[0] != 0
    has_normals = vertex_normals.shape[0] != 0

    faces = np.array(faces)
    faces_tex = np.array(faces_tex)
    faces_normals = np.array(faces_normals)

    n_faces = faces.shape[0]
    vertex_positions = np.zeros((n_faces, 3, 3), dtype=np.float32)
    tex_coords = np.zeros((n_faces, 3, 2), dtype=np.float32)
    normals = np.zeros((n_faces, 3, 3), dtype=np.float32)
    if n_faces:
        # Gather all face corners in one indexing operation per attribute.
        vertex_positions[:] = vertices[faces]
        if has_tex_coord:
            tex_coords[:] = texture_coords[faces_tex][:, :, :2]
        if has_normals:
            normals[:] = vertex_normals[faces_normals]

    return vertex_positions, \
           tex_coords, \
           normals, \
           material_ids, \
           vertices, \
           faces
Esempio n. 10
0
def save_for_blender(detections,
                     sample,
                     log_dir,
                     dict_clusters,
                     shape_pointclouds,
                     class_id_to_name=CLASSES):
    """Pickle predictions and ground truth of one sample for Blender.

    A dict with the decoded image, the 'rt' entry of the sample, and the
    predicted and ground-truth poses, sizes, classes and mesh paths is
    written to `log_dir + '.pkl'`. Ground-truth entries are read at batch
    index 0.

    NOTE: the original author's comment says VisualDebugging uses the
    OpenCV coordinate representation while the dataset uses OpenGL, so y
    and z need converting. This function performs no such conversion.

    Args:
      detections: dict of predicted tensors ('sizes_3d', 'rotations_3d',
        'translations_3d', 'shapes', 'detection_classes').
      sample: dict of ground-truth tensors for the batch.
      log_dir: output path without the '.pkl' extension.
      dict_clusters: maps a predicted shape id to a 3-item entry whose
        last two items are the class and model directory names.
      shape_pointclouds: indexable by integer shape id; supplies the
        point cloud stored for each predicted shape.
      class_id_to_name: indexable by integer class id; supplies the class
        name stored for each prediction.
    """
    batch_id = 0
    prefix = '/cns/lu-d/home/giotto3d/datasets/shapenet/raw/'
    sufix = 'models/model_normalized.obj'

    def mesh_path(class_dir, model_dir):
        return os.path.join(prefix, class_dir, model_dir, sufix)

    def rotation_y(rotation):
        return tf_utils.euler_from_rotation_matrix(
            tf.reshape(rotation, [3, 3]), 1).numpy()

    blender_dict = {}
    blender_dict['image'] = tf.io.decode_image(
        sample['image_data'][batch_id]).numpy()
    blender_dict['world_to_cam'] = sample['rt'].numpy()

    # Predictions ------------------------------------------------------------
    n_pred = int(detections['sizes_3d'].shape[0])
    blender_dict['num_predicted_shapes'] = n_pred
    blender_dict['predicted_rotations_3d'] = tf.reshape(
        detections['rotations_3d'], [-1, 3, 3]).numpy()
    blender_dict['predicted_rotations_y'] = [
        rotation_y(detections['rotations_3d'][idx]) for idx in range(n_pred)
    ]
    blender_dict['predicted_translations_3d'] = (
        detections['translations_3d'].numpy())
    blender_dict['predicted_sizes_3d'] = detections['sizes_3d'].numpy()

    pred_paths = []
    for idx in range(n_pred):
        shape_id = detections['shapes'][idx].numpy()
        _, class_dir, model_dir = dict_clusters[shape_id]
        pred_paths.append(mesh_path(class_dir, model_dir))
    blender_dict['predicted_shapes_path'] = pred_paths

    blender_dict['predicted_class'] = [
        class_id_to_name[int(detections['detection_classes'][idx].numpy())]
        for idx in range(n_pred)
    ]
    blender_dict['predicted_pointcloud'] = [
        shape_pointclouds[int(detections['shapes'][idx].numpy())]
        for idx in range(n_pred)
    ]

    # Ground truth -----------------------------------------------------------
    n_gt = int(sample['sizes_3d'][batch_id].shape[0])
    blender_dict['num_groundtruth_shapes'] = n_gt
    blender_dict['groundtruth_rotations_3d'] = tf.reshape(
        sample['rotations_3d'][batch_id], [-1, 3, 3]).numpy()
    blender_dict['groundtruth_rotations_y'] = [
        rotation_y(sample['rotations_3d'][batch_id][idx])
        for idx in range(sample['num_boxes'][batch_id].numpy())
    ]
    blender_dict['groundtruth_translations_3d'] = (
        sample['translations_3d'][batch_id].numpy())
    blender_dict['groundtruth_sizes_3d'] = (
        sample['sizes_3d'][batch_id].numpy())

    gt_paths = []
    for idx in range(n_gt):
        class_dir = str(sample['classes'][batch_id, idx].numpy()).zfill(8)
        # [2:-1] strips the leading b' and trailing ' of a bytes repr.
        model_dir = str(sample['mesh_names'][batch_id, idx].numpy())[2:-1]
        gt_paths.append(mesh_path(class_dir, model_dir))
    blender_dict['groundtruth_shapes_path'] = gt_paths
    blender_dict['groundtruth_classes'] = (
        sample['groundtruth_valid_classes'].numpy())

    out_path = log_dir + '.pkl'
    with gfile.Open(out_path, 'wb') as file:
        pickle.dump(blender_dict, file)
Esempio n. 11
0
    def add_detections(self, sample, detections):
        """Add detections to evaluation.

        Every metric in `self.metrics` is updated according to its type
        (ShapeAccuracyMetric, BoxIoUMetric, IoUMetric or CollisionMetric).
        Metrics of any other type are left untouched.

        Args:
          sample: the ground truth information
          detections: the predicted detections

        Returns:
          dict of intermediate results with the keys 'iou_mean', 'iou_min',
          'collisions', 'collision_intersection' and 'collision_iou'.
          Entries keep their defaults (-1 for the IoU values, 0 for the
          collision values) unless an IoUMetric / CollisionMetric is
          present to overwrite them.
        """
        # Defaults returned when the corresponding metric is not registered.
        result_dict = {
            'iou_mean': -1,
            'iou_min': -1,
            'collisions': 0,
            'collision_intersection': 0,
            'collision_iou': 0
        }
        num_boxes = sample['num_boxes'].numpy()

        for _, metric in self.metrics.items():
            if isinstance(metric, ShapeAccuracyMetric):
                labels = sample['shapes']
                weights = tf.math.sign(labels +
                                       1)  # -1 is mapped to zero, else 1
                metric.update(labels, detections['shapes_logits'], weights)
            elif isinstance(metric, BoxIoUMetric):
                scene_id = str(sample['scene_filename'].numpy(), 'utf-8')

                # Get ground truth boxes
                # Columns are reordered to [1, 0, 3, 2] and scaled by 256.
                # Presumably converts normalized (y, x) corners to pixel
                # (x, y) corners of a 256 px image -- TODO confirm.
                labeled_boxes = tf.gather(sample['groundtruth_boxes'],
                                          axis=1,
                                          indices=[1, 0, 3, 2]) * 256.0
                if metric.threed:
                    # NOTE(review): the ground-truth rotation is read from
                    # `detections`, not `sample`, although translations and
                    # sizes below come from `sample` -- confirm intent.
                    rotations_y = tf.concat([
                        tf_utils.euler_from_rotation_matrix(
                            tf.reshape(detections['rotations_3d'][i], [3, 3]),
                            1) for i in range(num_boxes)
                    ],
                                            axis=0)
                    rotations_y = tf.reshape(rotations_y, [-1, 1])
                    # 3D boxes replace the 2D boxes computed above.
                    labeled_boxes = tf.concat([
                        sample['translations_3d'], sample['sizes_3d'],
                        rotations_y
                    ],
                                              axis=1)

                # Get predicted boxes
                predicted_boxes = detections['detection_boxes']
                if metric.threed:
                    rotations_y = tf.concat([
                        tf_utils.euler_from_rotation_matrix(
                            tf.reshape(detections['rotations_3d'][i], [3, 3]),
                            1) for i in range(num_boxes)
                    ],
                                            axis=0)
                    rotations_y = tf.reshape(rotations_y, [-1, 1])
                    predicted_boxes = tf.concat([
                        detections['translations_3d'], detections['sizes_3d'],
                        rotations_y
                    ],
                                                axis=1)

                labeled_classes = tf.cast(sample['groundtruth_valid_classes'],
                                          tf.int64)
                predicted_classes = tf.cast(detections['detection_classes'],
                                            tf.int64)
                confidences = detections['detection_scores']
                metric.update(scene_id, labeled_boxes, labeled_classes,
                              predicted_boxes, predicted_classes, confidences)
            elif isinstance(metric, IoUMetric):
                classes = sample['classes']
                mesh_names = sample['mesh_names']
                # Load one ground-truth SDF per box from
                # <shapenet_dir>/<class id>/<model name>/.
                labeled_sdfs = []
                for i in range(num_boxes):
                    class_id = str(classes[i].numpy()).zfill(8)
                    model_name = str(mesh_names[i].numpy(), 'utf-8')
                    path_prefix = os.path.join(self.shapenet_dir, class_id,
                                               model_name)
                    file_sdf = os.path.join(path_prefix,
                                            'model_normalized_sdf.npy')
                    with gfile.Open(file_sdf, 'rb') as f:
                        labeled_sdfs.append(
                            tf.expand_dims(np.load(f).astype(np.float32), 0))
                labeled_sdfs = tf.concat(labeled_sdfs, axis=0)

                labeled_classes = tf.cast(sample['groundtruth_valid_classes'],
                                          tf.int64)
                # Reorder all ground-truth arrays by ascending class id.
                labeled_permutation = np.argsort(labeled_classes)

                labeled_sdfs = labeled_sdfs.numpy()[labeled_permutation]
                labeled_classes = labeled_classes.numpy()[labeled_permutation]
                labeled_rotations_3d = sample['rotations_3d'].numpy()
                labeled_rotations_3d = labeled_rotations_3d[
                    labeled_permutation]
                labeled_translations_3d = sample['translations_3d'].numpy()
                labeled_translations_3d = labeled_translations_3d[
                    labeled_permutation]
                labeled_sizes_3d = sample['sizes_3d'].numpy(
                )[labeled_permutation]
                labeled_poses = (labeled_rotations_3d, labeled_translations_3d,
                                 labeled_sizes_3d)

                # Predictions
                predicted_classes = tf.cast(detections['detection_classes'],
                                            tf.int64)
                # Same class-sorted ordering for the predictions.
                predicted_permutation = np.argsort(predicted_classes)
                predicted_classes = predicted_classes.numpy(
                )[predicted_permutation]

                predicted_sdfs = \
                  detections['predicted_sdfs'].numpy()[predicted_permutation]
                predicted_rotations_3d = \
                  detections['rotations_3d'].numpy()[predicted_permutation]
                predicted_translations_3d = \
                  detections['translations_3d'].numpy()[predicted_permutation]
                predicted_sizes_3d = \
                  detections['sizes_3d'].numpy()[predicted_permutation]
                predicted_poses = (predicted_rotations_3d,
                                   predicted_translations_3d,
                                   predicted_sizes_3d)

                # Hard-coded switch (off): replace every prediction with the
                # ground truth.
                full_oracle = False
                if full_oracle:
                    predicted_sdfs = detections['groundtruth_sdfs'].numpy()
                    predicted_sdfs = predicted_sdfs[labeled_permutation]
                    predicted_classes = labeled_classes
                    predicted_poses = labeled_poses

                # Debug output of the prediction array shapes.
                print('----------------------------')
                print(predicted_sdfs.shape)
                print(predicted_classes.shape)
                print(predicted_poses[0].shape)
                print(predicted_poses[1].shape)
                print(predicted_poses[2].shape)

                # Hard-coded switch (off): keep predicted SDFs but use the
                # ground-truth poses.
                pose_oracle = False
                if pose_oracle:
                    predicted_sdfs = detections['predicted_sdfs'].numpy()
                    predicted_sdfs = predicted_sdfs[predicted_permutation]
                    predicted_poses = (labeled_rotations_3d,
                                       labeled_translations_3d,
                                       labeled_sizes_3d)

                # Hard-coded switch (ON): zero all class ids in place, so
                # predicted and labeled classes are all 0 when passed to
                # the metric.
                class_oracle = True
                if class_oracle:
                    predicted_classes *= 0
                    labeled_classes *= 0

                iou_mean, iou_min = metric.update(
                    labeled_sdfs, labeled_classes, labeled_poses,
                    predicted_sdfs, predicted_classes, predicted_poses,
                    sample['dot'])
                result_dict['iou_mean'] = iou_mean
                result_dict['iou_min'] = iou_min
            elif isinstance(metric, CollisionMetric):

                # Ground-truth SDFs are taken from `detections` here, not
                # loaded from disk as in the IoUMetric branch.
                labeled_sdfs = detections['groundtruth_sdfs']
                labeled_classes = tf.cast(sample['groundtruth_valid_classes'],
                                          tf.int64)
                labeled_poses = (sample['rotations_3d'],
                                 sample['translations_3d'], sample['sizes_3d'])

                predicted_classes = tf.cast(detections['detection_classes'],
                                            tf.int64)
                predicted_sdfs = detections['predicted_sdfs']
                predicted_poses = (detections['rotations_3d'],
                                   detections['translations_3d'],
                                   detections['sizes_3d'])

                # Hard-coded switch (off): evaluate ground truth against
                # itself.
                full_oracle = False
                if full_oracle:
                    predicted_sdfs = detections['groundtruth_sdfs'].numpy()
                    predicted_classes = labeled_classes
                    predicted_poses = labeled_poses

                num_collisions, intersection, iou = metric.update(
                    labeled_sdfs, labeled_classes, labeled_poses,
                    predicted_sdfs, predicted_classes, predicted_poses)
                result_dict['collisions'] = num_collisions
                result_dict['collision_intersection'] = intersection
                result_dict['collision_iou'] = iou

        return result_dict
Esempio n. 12
0
def run_experiment(study_hparams=None, trial_handle=None, tuner=None):
    """Runs one MetaQ training experiment on the cartpole environment.

    A private copy of the global flags is taken (optionally overridden by
    Vizier hyper-parameters), the TF graph and random seeds are reset, and
    a `MetaQ` model is trained for `FLAGS.outer_loop_steps` outer steps.
    Each outer step:

      * every `FLAGS.on_policy_steps` steps, refills the inner/outer replay
        buffers via `collect_off_policy_data`;
      * for each of `FLAGS.n_meta_tasks` tasks, samples an inner batch,
        builds the post-update Q function, optionally evaluates the greedy
        policy, samples an outer batch and accumulates the meta gradient;
      * runs one training step, periodically updates the target network,
        and logs the losses (plus greedy-reward statistics on report steps).

    On the final outer step the mean greedy reward is pickled to
    `<tensorboard_path>/reward`.

    Args:
      study_hparams: object whose `values()` returns a dict mapping flag
          names to override values. Only read when `FLAGS.use_vizier` is set.
      trial_handle: optional sub-directory name joined onto
          `FLAGS.output_dir` for this run's logs.
      tuner: object providing `report_measure` and `report_done`. Only used
          when `FLAGS.use_vizier` is set.

    Raises:
      ValueError: if `FLAGS.vizier_objective` is not supported, or is
          'reward' while no 'reward_mean' value has been logged.
    """
    FLAGS = deepcopy(tf.app.flags.FLAGS)

    if FLAGS.use_vizier:
        for key, val in study_hparams.values().items():
            setattr(FLAGS, key, val)

    tf.reset_default_graph()
    np.random.seed(FLAGS.random_seed)
    tf.set_random_seed(FLAGS.random_seed)

    # Initialize env

    env_kwargs = {
        'goal_x': FLAGS.goal_x,
        'min_goal_x': FLAGS.min_goal_x,
        'max_goal_x': FLAGS.max_goal_x,
        'x_threshold': FLAGS.x_threshold,
        'max_reward_for_dist': FLAGS.max_reward_for_dist,
        'reward_per_time_step': FLAGS.reward_per_time_step,
        'fixed_initial_state': FLAGS.fixed_initial_state,
        'reweight_rewards': FLAGS.reweight_rewards
    }
    env = cartpole.make_env(env_kwargs)
    # NOTE(review): only referenced by the disabled evaluation code below.
    # Still constructed so any side effects of make_env are unchanged.
    eval_env = cartpole.make_env(env_kwargs)

    if not FLAGS.fixed_env:
        env.env.randomize()

    if trial_handle:
        tensorboard_path = os.path.join(FLAGS.output_dir, trial_handle)
    else:
        tensorboard_path = FLAGS.output_dir
    tf.gfile.MakeDirs(tensorboard_path)

    # Forward to MetaQ only those flags that are part of its default config.
    kwargs = dict(observation_shape=[None] + list(env.observation_space.shape),
                  action_dim=1)
    default_hps = MetaQ.get_default_config().values()

    for key in flags_def:
        if key in default_hps:
            kwargs[key] = getattr(FLAGS, key)

    hps = tf.HParams(**kwargs)

    meta_q = MetaQ(hps, fully_connected_net(FLAGS.nn_arch, FLAGS.activation))
    meta_q.build_graph()

    init_op = tf.global_variables_initializer()

    logger = TensorBoardLogger(tensorboard_path)

    with tf.Session() as sess:
        sess.run(init_op)
        meta_q.init_session(sess)

        inner_loop_buffer = MultiTaskReplayBuffer(len(env.env.goal_positions),
                                                  200000, FLAGS.random_seed)
        outer_loop_buffer = MultiTaskReplayBuffer(len(env.env.goal_positions),
                                                  200000, FLAGS.random_seed)

        # None until the first task has produced a post-update Q function;
        # afterwards it carries over into the next data collection.
        post_update_q_func = None
        last_step = FLAGS.outer_loop_steps - 1
        for outer_step in range(FLAGS.outer_loop_steps):
            print('State is ', env.env.state)
            if outer_step % FLAGS.on_policy_steps == 0:
                if FLAGS.fixed_env:
                    goal_positions = [env.env.goal_x]
                else:
                    goal_positions = env.env.goal_positions
                # NOTE: Approximately ~30 to 60 states per trajectory
                inner_loop_buffer = collect_off_policy_data(
                    env, goal_positions, meta_q, post_update_q_func,
                    inner_loop_buffer, FLAGS.inner_loop_n_trajs,
                    FLAGS.inner_loop_data_collection,
                    FLAGS.inner_loop_greedy_epsilon,
                    FLAGS.inner_loop_bolzmann_temp)
                outer_loop_buffer = collect_off_policy_data(
                    env, goal_positions, meta_q, post_update_q_func,
                    outer_loop_buffer, FLAGS.outer_loop_n_trajs,
                    FLAGS.outer_loop_data_collection,
                    FLAGS.outer_loop_greedy_epsilon,
                    FLAGS.outer_loop_bolzmann_temp)

            # Report on every `report_steps`-th step and always on the last
            # step; the same predicate gates evaluation, logging and Vizier.
            should_report = (outer_step % FLAGS.report_steps == 0
                             or outer_step >= last_step)

            post_update_greedy_rewards = []

            for task_id in range(FLAGS.n_meta_tasks):
                # print('Task: {}'.format(task_id))

                if not FLAGS.fixed_env:
                    env.env.randomize()

                (inner_observations, inner_actions, inner_rewards,
                 inner_next_observations,
                 inner_dones) = inner_loop_buffer.sample(
                     env.env.task_id, FLAGS.inner_loop_n_states)
                # Evaluating true rewards
                post_update_q_func = meta_q.get_post_update_q_function(
                    inner_observations, inner_actions, inner_rewards,
                    inner_next_observations, inner_dones)

                policy = QPolicy(post_update_q_func, epsilon=0.0)

                if should_report:
                    _, _, greedy_rewards, _, _ = cartpole_utils.collect_data(
                        env,
                        n_trajs=FLAGS.outer_loop_greedy_eval_n_trajs,
                        policy=policy)
                    # Average total reward per greedy trajectory.
                    post_update_greedy_rewards.append(
                        np.sum(greedy_rewards) /
                        FLAGS.outer_loop_greedy_eval_n_trajs)

                (outer_observations, outer_actions, outer_rewards,
                 outer_next_observations,
                 outer_dones) = outer_loop_buffer.sample(
                     env.env.task_id, FLAGS.outer_loop_n_states)
                meta_q.accumulate_gradient(
                    inner_observations,
                    inner_actions,
                    inner_rewards,
                    inner_next_observations,
                    inner_dones,
                    outer_observations,
                    outer_actions,
                    outer_rewards,
                    outer_next_observations,
                    outer_dones,
                )

            pre_update_loss, post_update_loss = meta_q.run_train_step()

            if not FLAGS.outer_loop_online_target and outer_step % FLAGS.target_update_freq == 0:
                print("updating target network")
                meta_q.update_target_network()

            log_data = dict(
                pre_update_loss=pre_update_loss,
                post_update_loss=post_update_loss,
                goal_x=env.env.goal_x,
            )

            #TODO(hkannan): uncomment this later!!!
            if should_report:
                # reward_across_20_tasks = evaluate(
                #     policy, eval_env, meta_q,
                #     inner_loop_n_trajs=FLAGS.inner_loop_n_trajs,
                #     outer_loop_n_trajs=FLAGS.outer_loop_n_trajs, n=21,
                #     weight_rewards=FLAGS.weight_rewards)
                # log_data['reward_mean'] = np.mean(reward_across_20_tasks)
                # log_data['reward_variance'] = np.var(reward_across_20_tasks)
                log_data['post_update_greedy_reward'] = np.mean(
                    post_update_greedy_rewards)
                log_data['post_update_greedy_reward_variance'] = np.var(
                    post_update_greedy_rewards)

            print('Outer step: {}, '.format(outer_step), log_data)
            logger.log_dict(outer_step, log_data)
            # if outer_step % FLAGS.video_report_steps == 0 or outer_step >= (FLAGS.outer_loop_steps - 1):
            #   video_data = {
            #       'env_kwargs': env_kwargs,
            #       'inner_loop_data_collection': FLAGS.inner_loop_data_collection,
            #       'inner_loop_greedy_epsilon': FLAGS.inner_loop_greedy_epsilon,
            #       'inner_loop_bolzmann_temp': FLAGS.inner_loop_bolzmann_temp,
            #       'inner_loop_n_trajs': FLAGS.inner_loop_n_trajs,
            #       'meta_q_kwargs': kwargs,
            #       'weights': meta_q.get_current_weights(),
            #       'tensorboard_path': tensorboard_path,
            #       'filename': 'random_task'
            #   }
            #   reward_across_20_tasks = evaluate(
            #       policy, eval_env, meta_q,
            #       inner_loop_n_trajs=FLAGS.inner_loop_n_trajs,
            #       outer_loop_n_trajs=FLAGS.outer_loop_n_trajs, n=21,
            #       weight_rewards=FLAGS.weight_rewards, video_data=video_data)
            #   log_data['reward_mean'] = np.mean(reward_across_20_tasks)
            #   log_data['reward_variance'] = np.var(reward_across_20_tasks)
            #   logger.log_dict(outer_step, log_data)

            if outer_step >= last_step:
                # The last step is always a report step, so the greedy
                # reward entry is guaranteed to be present here.
                greedy_reward_path = os.path.join(tensorboard_path, 'reward')
                with gfile.Open(greedy_reward_path, mode='wb') as f:
                    f.write(pickle.dumps(
                        log_data['post_update_greedy_reward']))
            if FLAGS.use_vizier:
                # Abort the trial as infeasible on the first non-finite metric.
                for v in log_data.values():
                    if not np.isfinite(v):
                        tuner.report_done(
                            infeasible=True,
                            infeasible_reason='Nan or inf encountered')
                        return

                if should_report:
                    if FLAGS.vizier_objective == 'greedy_reward':
                        objective_value = log_data['post_update_greedy_reward']
                    elif FLAGS.vizier_objective == 'loss':
                        objective_value = post_update_loss
                    elif FLAGS.vizier_objective == 'reward':
                        # 'reward_mean' is only produced by the evaluation
                        # code that is currently commented out above; fail
                        # with a clear message instead of a bare KeyError.
                        if 'reward_mean' not in log_data:
                            raise ValueError(
                                "Vizier objective 'reward' requires "
                                "'reward_mean', which is not being logged.")
                        objective_value = log_data['reward_mean']
                    else:
                        raise ValueError('Unsupported vizier objective!')
                    tuner.report_measure(objective_value=objective_value,
                                         global_step=outer_step,
                                         metrics=log_data)

    if FLAGS.use_vizier:
        tuner.report_done()