def evaluate_endpoint_resegmentation(filename, seg_volume,
                                     resegmentation_radius,
                                     threshold=0.5):
  """Evaluates endpoint resegmentation.

  Args:
    filename: path to the file containing resegmentation results
    seg_volume: volume object with the original segmentation
    resegmentation_radius: (z, y, x) radius of the resegmentation subvolume
    threshold: threshold at which to create objects from the predicted
        object map

  Returns:
    EndpointSegmentationResult proto

  Raises:
    InvalidBaseSegmentatonError: when no base segmentation object with the
        expected ID matches the resegmentation data
  """
  id1, _, x, y, z = parse_resegmentation_filename(filename)

  result = resegmentation_pb2.EndpointSegmentationResult()
  result.id = id1
  start = result.start
  start.x, start.y, start.z = x, y, z

  sr = result.segmentation_radius
  sr.z, sr.y, sr.x = resegmentation_radius

  with gfile.Open(filename, 'rb') as f:
    data = np.load(f)
    prob = storage.dequantize_probability(data['probs'])
    prob = np.nan_to_num(prob)  # NaNs indicate unvisited voxels.

  orig_seg = seg_volume[0,
                        (z - sr.z):(z + sr.z + 1),
                        (y - sr.y):(y + sr.y + 1),
                        (x - sr.x):(x + sr.x + 1)][0, ...]
  seg1 = orig_seg == id1
  if not np.any(seg1):
    raise InvalidBaseSegmentatonError()

  new_seg = prob[0, ...] >= threshold
  result.num_voxels = int(np.sum(new_seg))

  overlaps = pywrapsegment_util.ComputeOverlapCounts(
      orig_seg.ravel(), new_seg.astype(np.uint64).ravel())
  for k, v in overlaps.items():
    old, new = k
    if not new:
      continue
    result.overlaps[old].num_overlapping = v
    result.overlaps[old].num_original = int(np.sum(orig_seg == old))

    if old == id1:
      result.source.CopyFrom(result.overlaps[old])

  return result
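# pywrapsegment_util.ComputeOverlapCounts is a compiled helper that is not
# shown here. The sketch below is an *assumed* pure-NumPy equivalent,
# inferred from how its result is consumed above: keys are
# (original_id, resegmented_id) pairs and values are voxel counts.
def _overlap_counts_sketch(orig_flat, new_flat):
  counts = {}
  for pair in zip(orig_flat.tolist(), new_flat.tolist()):
    counts[pair] = counts.get(pair, 0) + 1
  return counts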
@classmethod
def load(cls, path):
  """Load camera parameters from a json file."""
  if '/cns/' in path:
    with gfile.Open(path, 'r') as f:
      data = json.load(f)
  else:
    with open(path, 'r') as f:
      data = json.load(f)

  keys = sorted(data.keys())
  cams = []
  for key in keys:
    if key == 'metadata':
      continue
    name = data[key]['name']
    size = data[key]['size']
    matrix = np.array(data[key]['matrix'])
    dist = np.array(data[key]['distortions'])
    rvec = np.array(data[key]['rotation'])
    tvec = np.array(data[key]['translation'])
    cams.append(Camera(rvec, tvec, matrix, dist, name, size))
  return cls(cams)
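# A sketch of the calibration json layout that load() expects, inferred from
# the fields read above; the numeric values are illustrative only.
#
# {
#   "metadata": {...},                 # skipped by load()
#   "cam01": {
#     "name": "cam01",
#     "size": [1920, 1080],
#     "matrix": [[1000.0, 0.0, 960.0],
#                [0.0, 1000.0, 540.0],
#                [0.0, 0.0, 1.0]],     # 3x3 intrinsics
#     "distortions": [0.0, 0.0, 0.0, 0.0, 0.0],
#     "rotation": [0.0, 0.0, 0.0],     # rvec
#     "translation": [0.0, 0.0, 0.0]   # tvec
#   },
#   ...
# }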
def evaluate(self):
  """Evaluate."""
  if self.slave:
    data = self.iou_per_class
    with gfile.Open(self.path, 'wb') as f:
      pickle.dump(data, f)
    logging.info('Wrote per-class IoU data to %s', self.path)
    return
  iou_per_class_means = []
  for v in self.iou_per_class.values():
    if v:
      iou_per_class_means.append(np.mean(v))
  return np.mean(iou_per_class_means)
def _CreateCsv(config_dict, tdb, output_root, output_file, file_type,
               batch_size):
  """Writes data from db to CSV based on file type.

  Args:
    config_dict: Config file dict.
    tdb: Instance of TrackingDB.
    output_root: Path to directory of output CSV.
    output_file: Name of output CSV file.
    file_type: Type of NetCDF file.
    batch_size: Number of rows written to CSV at once.
  """
  output_file = os.path.join(output_root, output_file)
  processor = reader.NetCdfMetadataReader(config_dict, file_type)
  with gfile.Open(output_file, 'w') as csv_file:
    _PrintStderr('Saving CSV file: %s', output_file)
    csv_file.write(processor.GetCsvHeader() + '\n')
    tdb.ExportCsvData(csv_file, file_type, batch_size)
def evaluate(self):
  """Evaluate."""
  if self.slave:
    data = {
        'collisions': self.collisions,
        'intersections': self.intersections,
        'ious': self.ious,
    }
    with gfile.Open(self.path, 'wb') as f:
      pickle.dump(data, f)
    logging.info('Wrote collision data to %s', self.path)
    return
  return np.sum(self.collisions)
def main(_):
  models_data_filename = 'models_per_split.pkl'
  with gfile.Open(models_data_filename, 'rb') as f:
    models_data_per_class = pickle.load(f)

  path_prefix = '/datasets/shapenet/raw/'
  local_prefix = '/occluded_primitives/meshes/'
  for split in ['val', 'train', 'test']:
    for class_name in models_data_per_class[split]:
      for model in models_data_per_class[split][class_name]:
        mesh_file = os.path.join(path_prefix, class_name, model, 'models',
                                 'model_normalized.obj')
        local_dir = os.path.join(local_prefix, split, class_name, model)
        if not os.path.exists(local_dir):
          os.makedirs(local_dir)
        print(mesh_file)
        res = subprocess.call(['fileutil', 'cp', mesh_file, local_dir])
        print(res)
def evaluate_pair_resegmentation(filename, seg_volume, resegmentation_radius,
                                 analysis_radius, voxel_size, threshold=0.5):
  """Evaluates segment pair resegmentation.

  Args:
    filename: path to the file containing resegmentation results
    seg_volume: VolumeStore object with the original segmentation
    resegmentation_radius: (z, y, x) radius of the resegmentation subvolume
    analysis_radius: (z, y, x) radius of the subvolume in which to perform
        analysis
    voxel_size: (z, y, x) voxel size in physical units
    threshold: threshold at which to create objects from the predicted
        object map

  Returns:
    PairResegmentationResult proto

  Raises:
    IncompleteResegmentationError: when the resegmentation data does not
        represent two finished segments
    InvalidBaseSegmentatonError: when no base segmentation object with the
        expected ID matches the resegmentation data
  """
  id1, id2, x, y, z = parse_resegmentation_filename(filename)

  result = resegmentation_pb2.PairResegmentationResult()
  result.id_a, result.id_b = id1, id2
  p = result.point
  p.x, p.y, p.z = x, y, z

  sr = result.segmentation_radius
  sr.z, sr.y, sr.x = resegmentation_radius

  with gfile.Open(filename, 'rb') as f:
    data = np.load(f)
    prob = storage.dequantize_probability(data['probs'])
    prob = np.nan_to_num(prob)  # NaNs indicate unvisited voxels.
    dels = data['deletes']
    moves = data['histories']  # z, y, x
    start_points = data['start_points']  # x, y, z

  if prob.shape[0] != 2:
    raise IncompleteResegmentationError()

  assert prob.ndim == 4

  # Corner of the resegmentation subvolume in the global coordinate system.
  corner = np.array([p.x - sr.x, p.y - sr.y, p.z - sr.z])

  # In case of multiple segmentation attempts, the last recorded start
  # point is the one we care about.
  origin_a = np.array(start_points[0][-1], dtype=int) + corner
  origin_b = np.array(start_points[1][-1], dtype=int) + corner

  oa = result.eval.from_a.origin
  oa.x, oa.y, oa.z = origin_a
  ob = result.eval.from_b.origin
  ob.x, ob.y, ob.z = origin_b

  # Record basic information about the resegmentation run.
  analysis_r = np.array(analysis_radius)
  r = result.eval.radius
  r.z, r.y, r.x = analysis_r

  seg = seg_volume[0,
                   (z - analysis_r[0]):(z + analysis_r[0] + 1),
                   (y - analysis_r[1]):(y + analysis_r[1] + 1),
                   (x - analysis_r[2]):(x + analysis_r[2] + 1)][0, ...]
  seg1 = seg == id1
  seg2 = seg == id2
  result.eval.num_voxels_a = int(np.sum(seg1))
  result.eval.num_voxels_b = int(np.sum(seg2))
  if result.eval.num_voxels_a == 0 or result.eval.num_voxels_b == 0:
    raise InvalidBaseSegmentatonError()

  # Record information about the size of the original segments.
  result.eval.max_edt_a = float(
      ndimage.distance_transform_edt(seg1, sampling=voxel_size).max())
  result.eval.max_edt_b = float(
      ndimage.distance_transform_edt(seg2, sampling=voxel_size).max())

  # Offset of the analysis subvolume within the resegmentation subvolume.
  delta = np.array(resegmentation_radius) - analysis_r
  prob = prob[:,
              delta[0]:(delta[0] + 2 * analysis_r[0] + 1),
              delta[1]:(delta[1] + 2 * analysis_r[1] + 1),
              delta[2]:(delta[2] + 2 * analysis_r[2] + 1)]
  reseg = prob >= threshold
  result.eval.iou = compute_iou(reseg)

  # Record information about the size of the reconstructed segments.
  evaluate_segmentation_result(
      reseg[0, ...], dels[0], moves[0], delta, analysis_r, seg1, seg2,
      voxel_size, result.eval.from_a)
  evaluate_segmentation_result(
      reseg[1, ...], dels[1], moves[1], delta, analysis_r, seg1, seg2,
      voxel_size, result.eval.from_b)

  return result
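# compute_iou() is defined elsewhere in this module. Below is a minimal
# sketch consistent with its use above, assuming `reseg` is a [2, z, y, x]
# boolean array holding the two resegmented objects (an assumption based on
# how `reseg` is built from `prob`).
def _compute_iou_sketch(reseg):
  # Jaccard index of the two channels: |A & B| / |A | B|.
  intersection = np.sum(np.all(reseg, axis=0))
  union = np.sum(np.any(reseg, axis=0))
  return float(intersection) / max(float(union), 1.0)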
def load_pkl(path, keys=None, **kwargs):
  """Load AIST++ annotations from pkl file."""
  if '/cns/' in path:
    with gfile.Open(path, 'rb') as f:
      data = pickle.load(f)
  else:
    with open(path, 'rb') as f:
      data = pickle.load(f)
  annos = data['pred_results']
  assert annos, f'data {path} has empty annotations'

  def _wanted(key):
    # An empty or None `keys` selects everything.
    return not keys or key in keys

  out = {}
  # SMPL-related outputs.
  if 'smpl_loss' in data and _wanted('smpl_loss'):
    # A single float.
    out['smpl_loss'] = data['smpl_loss']
  if 'smpl_joints' in annos[0] and _wanted('smpl_joints'):
    # [nframes, 24, 3]
    out['smpl_joints'] = np.stack(
        [anno['smpl_joints'] for anno in annos])[:, :24, :].astype(np.float32)
  if 'smpl_pose' in annos[0] and _wanted('smpl_poses'):
    # [nframes, 24, 3]
    out['smpl_poses'] = np.stack(
        [anno['smpl_pose'] for anno in annos]).reshape(
            -1, 24, 3).astype(np.float32)
  if 'smpl_shape' in annos[0] and _wanted('smpl_shape'):
    # [nframes, 10]
    out['smpl_shape'] = np.stack(
        [anno['smpl_shape'] for anno in annos]).astype(np.float32)
  if 'scaling' in annos[0] and _wanted('smpl_scaling'):
    # [nframes, 1]
    out['smpl_scaling'] = np.stack(
        [anno['scaling'] for anno in annos]).astype(np.float32)
  if 'transl' in annos[0] and _wanted('smpl_trans'):
    # [nframes, 3]
    out['smpl_trans'] = np.stack(
        [anno['transl'] for anno in annos]).astype(np.float32)
  if 'verts' in annos[0] and _wanted('smpl_verts'):
    # [nframes, 6890, 3]
    out['smpl_verts'] = np.stack(
        [anno['verts'] for anno in annos]).astype(np.float32)

  # 2D and 3D keypoints.
  if 'keypoints2d' in annos[0] and _wanted('keypoints2d'):
    # [9, nframes, 17, 3]
    out['keypoints2d'] = np.stack(
        [anno['keypoints2d'] for anno in annos], axis=1).astype(np.float32)
  if 'keypoints3d' in annos[0] and _wanted('keypoints3d'):
    # [nframes, 17, 3]
    out['keypoints3d'] = np.stack(
        [anno['keypoints3d'] for anno in annos]).astype(np.float32)
  if 'keypoints3d_optim' in annos[0] and _wanted('keypoints3d_optim'):
    # [nframes, 17, 3]
    out['keypoints3d_optim'] = np.stack(
        [anno['keypoints3d_optim'] for anno in annos]).astype(np.float32)

  # Timestamps for each frame, in ms.
  if 'timestamp' in annos[0] and _wanted('timestamps'):
    # [nframes,]
    out['timestamps'] = np.stack(
        [anno['timestamp'] for anno in annos]).astype(np.int32)

  # Human detection scores.
  if 'det_scores' in annos[0] and _wanted('det_scores'):
    # [9, nframes]
    out['det_scores'] = np.stack(
        [anno['det_scores'] for anno in annos], axis=1).astype(np.int32)
  return out
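# A minimal usage sketch for load_pkl(); the path is illustrative only.
#
#   motion = load_pkl('/tmp/aist_sample.pkl', keys=['smpl_poses', 'smpl_trans'])
#   motion['smpl_poses']  # [nframes, 24, 3] per-joint axis-angle rotations
#   motion['smpl_trans']  # [nframes, 3] root translations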
def obj_read_for_gl(filename, texture_size=(32, 32)):
  """Read vertex and part information from OBJ file."""
  del texture_size  # Unused.
  with gfile.Open(filename, 'r') as f:
    content = f.readlines()

  vertices = []
  texture_coords = []
  vertex_normals = []
  group_name = None
  material_name = None

  faces = []
  faces_tex = []
  faces_normals = []
  face_groups = []
  material_ids = []

  for line in content:
    parts = re.split(r'\s+', line)
    # if parts[0] == 'mtllib':
    #   material_file = parts[1]

    # Vertex information -----------------------------------------------------
    if parts[0] == 'v':
      vertices.append([float(v) for v in parts[1:4]])
    if parts[0] == 'vt':
      texture_coords.append([float(v) for v in parts[1:4]])
    if parts[0] == 'vn':
      vertex_normals.append([float(v) for v in parts[1:4]])

    if parts[0] == 'g':
      group_name = parts[1]
    if parts[0] == 'usemtl':
      material_name = parts[1]

    # Face information ------------------------------------------------------
    if parts[0] == 'f':
      vertex_index, tex_index, normal_index = 0, 0, 0
      current_face, current_face_tex, current_face_norm = [], [], []
      for j in range(1, 4):
        face_info = parts[j]
        if face_info.count('/') == 2:
          # Format: v/vt/vn (vt may be empty, i.e. v//vn).
          vertex_index, tex_index, normal_index = face_info.split('/')
          if not tex_index:
            tex_index = 0
        elif face_info.count('/') == 1:
          # Format: v/vt.
          vertex_index, tex_index = face_info.split('/')
        elif face_info.count('/') == 0:
          # Format: v.
          vertex_index = face_info
        # OBJ indices are 1-based; convert to 0-based.
        current_face.append(int(vertex_index) - 1)
        current_face_tex.append(int(tex_index) - 1)
        current_face_norm.append(int(normal_index) - 1)
      faces.append(current_face)
      faces_tex.append(current_face_tex)
      faces_normals.append(current_face_norm)
      face_groups.append(group_name)
      material_ids.append(material_name)

  vertices = np.array(vertices)
  texture_coords = np.array(texture_coords)
  vertex_normals = np.array(vertex_normals)
  has_tex_coord = texture_coords.shape[0] > 0
  has_normals = vertex_normals.shape[0] > 0

  faces = np.array(faces)
  faces_tex = np.array(faces_tex)
  faces_normals = np.array(faces_normals)

  # Gather per-face-corner attributes from the indexed arrays.
  n_faces = faces.shape[0]
  vertex_positions = np.zeros((n_faces, 3, 3), dtype=np.float32)
  tex_coords = np.zeros((n_faces, 3, 2), dtype=np.float32)
  normals = np.zeros((n_faces, 3, 3), dtype=np.float32)
  for i in range(n_faces):
    for j in range(3):
      vertex_positions[i, j, :] = vertices[faces[i, j], :]
      if has_tex_coord:
        tex_coords[i, j, :] = texture_coords[faces_tex[i, j], :2]
      if has_normals:
        normals[i, j, :] = vertex_normals[faces_normals[i, j], :]

  # Material info ------------------------------------------------------------
  return (vertex_positions, tex_coords, normals, material_ids, vertices,
          faces)
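# For reference, the OBJ face-index layouts handled above (indices are
# 1-based in the file, hence the `- 1` when appending):
#
#   f 1 2 3              # v        positions only
#   f 1/1 2/2 3/3        # v/vt     positions + texture coords
#   f 1/1/1 2/2/2 3/3/3  # v/vt/vn  positions + texture coords + normals
#   f 1//1 2//2 3//3     # v//vn    an empty vt is recorded as index -1
#
# Note that only the first three corners of each face are read, so quad
# faces are silently truncated to triangles.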
def save_for_blender(detections, sample, log_dir, dict_clusters,
                     shape_pointclouds, class_id_to_name=CLASSES):
  """Save for blender."""
  # VisualDebugging uses the OpenCV coordinate representation while the
  # dataset uses OpenGL (left-hand), so make sure to convert y and z.
  batch_id = 0
  prefix = '/cns/lu-d/home/giotto3d/datasets/shapenet/raw/'
  suffix = 'models/model_normalized.obj'

  blender_dict = {}
  blender_dict['image'] = tf.io.decode_image(
      sample['image_data'][batch_id]).numpy()
  blender_dict['world_to_cam'] = sample['rt'].numpy()

  num_predicted_shapes = int(detections['sizes_3d'].shape[0])
  blender_dict['num_predicted_shapes'] = num_predicted_shapes
  blender_dict['predicted_rotations_3d'] = tf.reshape(
      detections['rotations_3d'], [-1, 3, 3]).numpy()
  blender_dict['predicted_rotations_y'] = [
      tf_utils.euler_from_rotation_matrix(
          tf.reshape(detections['rotations_3d'][i], [3, 3]), 1).numpy()
      for i in range(num_predicted_shapes)
  ]
  blender_dict['predicted_translations_3d'] = (
      detections['translations_3d'].numpy())
  blender_dict['predicted_sizes_3d'] = detections['sizes_3d'].numpy()

  predicted_shapes_path = []
  for i in range(num_predicted_shapes):
    shape = detections['shapes'][i].numpy()
    _, class_str, model_str = dict_clusters[shape]
    filename = os.path.join(prefix, class_str, model_str, suffix)
    predicted_shapes_path.append(filename)
  blender_dict['predicted_shapes_path'] = predicted_shapes_path
  blender_dict['predicted_class'] = [
      class_id_to_name[int(detections['detection_classes'][i].numpy())]
      for i in range(num_predicted_shapes)
  ]
  blender_dict['predicted_pointcloud'] = [
      shape_pointclouds[int(detections['shapes'][i].numpy())]
      for i in range(num_predicted_shapes)
  ]

  num_groundtruth_shapes = int(sample['sizes_3d'][batch_id].shape[0])
  blender_dict['num_groundtruth_shapes'] = num_groundtruth_shapes
  blender_dict['groundtruth_rotations_3d'] = tf.reshape(
      sample['rotations_3d'][batch_id], [-1, 3, 3]).numpy()
  blender_dict['groundtruth_rotations_y'] = [
      tf_utils.euler_from_rotation_matrix(
          tf.reshape(sample['rotations_3d'][batch_id][i], [3, 3]), 1).numpy()
      for i in range(sample['num_boxes'][batch_id].numpy())
  ]
  blender_dict['groundtruth_translations_3d'] = (
      sample['translations_3d'][batch_id].numpy())
  blender_dict['groundtruth_sizes_3d'] = sample['sizes_3d'][batch_id].numpy()

  groundtruth_shapes_path = []
  for i in range(num_groundtruth_shapes):
    class_str = str(sample['classes'][batch_id, i].numpy()).zfill(8)
    model_str = str(sample['mesh_names'][batch_id, i].numpy())[2:-1]
    filename = os.path.join(prefix, class_str, model_str, suffix)
    groundtruth_shapes_path.append(filename)
  blender_dict['groundtruth_shapes_path'] = groundtruth_shapes_path
  blender_dict['groundtruth_classes'] = (
      sample['groundtruth_valid_classes'].numpy())

  path = log_dir + '.pkl'
  with gfile.Open(path, 'wb') as f:
    pickle.dump(blender_dict, f)
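# A sketch of reading the dump back on the Blender side; only standard
# pickle is assumed, and the path is illustrative.
#
#   with open('/tmp/scene.pkl', 'rb') as f:
#     blender_dict = pickle.load(f)
#   for path in blender_dict['predicted_shapes_path']:
#     ...  # import each OBJ and apply the corresponding predicted pose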
def add_detections(self, sample, detections):
  """Add detections to evaluation.

  Args:
    sample: the ground truth information
    detections: the predicted detections

  Returns:
    dict of intermediate results.
  """
  result_dict = {
      'iou_mean': -1,
      'iou_min': -1,
      'collisions': 0,
      'collision_intersection': 0,
      'collision_iou': 0
  }
  num_boxes = sample['num_boxes'].numpy()
  for _, metric in self.metrics.items():
    if isinstance(metric, ShapeAccuracyMetric):
      labels = sample['shapes']
      weights = tf.math.sign(labels + 1)  # -1 is mapped to zero, else 1
      metric.update(labels, detections['shapes_logits'], weights)
    elif isinstance(metric, BoxIoUMetric):
      scene_id = str(sample['scene_filename'].numpy(), 'utf-8')

      # Get ground truth boxes.
      labeled_boxes = tf.gather(
          sample['groundtruth_boxes'], axis=1, indices=[1, 0, 3, 2]) * 256.0
      if metric.threed:
        rotations_y = tf.concat([
            tf_utils.euler_from_rotation_matrix(
                tf.reshape(detections['rotations_3d'][i], [3, 3]), 1)
            for i in range(num_boxes)
        ], axis=0)
        rotations_y = tf.reshape(rotations_y, [-1, 1])
        labeled_boxes = tf.concat(
            [sample['translations_3d'], sample['sizes_3d'], rotations_y],
            axis=1)

      # Get predicted boxes.
      predicted_boxes = detections['detection_boxes']
      if metric.threed:
        rotations_y = tf.concat([
            tf_utils.euler_from_rotation_matrix(
                tf.reshape(detections['rotations_3d'][i], [3, 3]), 1)
            for i in range(num_boxes)
        ], axis=0)
        rotations_y = tf.reshape(rotations_y, [-1, 1])
        predicted_boxes = tf.concat([
            detections['translations_3d'], detections['sizes_3d'],
            rotations_y
        ], axis=1)

      labeled_classes = tf.cast(sample['groundtruth_valid_classes'], tf.int64)
      predicted_classes = tf.cast(detections['detection_classes'], tf.int64)
      confidences = detections['detection_scores']
      metric.update(scene_id, labeled_boxes, labeled_classes,
                    predicted_boxes, predicted_classes, confidences)
    elif isinstance(metric, IoUMetric):
      classes = sample['classes']
      mesh_names = sample['mesh_names']
      labeled_sdfs = []
      for i in range(num_boxes):
        class_id = str(classes[i].numpy()).zfill(8)
        model_name = str(mesh_names[i].numpy(), 'utf-8')
        path_prefix = os.path.join(self.shapenet_dir, class_id, model_name)
        file_sdf = os.path.join(path_prefix, 'model_normalized_sdf.npy')
        with gfile.Open(file_sdf, 'rb') as f:
          labeled_sdfs.append(
              tf.expand_dims(np.load(f).astype(np.float32), 0))
      labeled_sdfs = tf.concat(labeled_sdfs, axis=0)

      labeled_classes = tf.cast(sample['groundtruth_valid_classes'], tf.int64)
      labeled_permutation = np.argsort(labeled_classes)

      labeled_sdfs = labeled_sdfs.numpy()[labeled_permutation]
      labeled_classes = labeled_classes.numpy()[labeled_permutation]
      labeled_rotations_3d = (
          sample['rotations_3d'].numpy()[labeled_permutation])
      labeled_translations_3d = (
          sample['translations_3d'].numpy()[labeled_permutation])
      labeled_sizes_3d = sample['sizes_3d'].numpy()[labeled_permutation]
      labeled_poses = (labeled_rotations_3d, labeled_translations_3d,
                       labeled_sizes_3d)

      # Predictions.
      predicted_classes = tf.cast(detections['detection_classes'], tf.int64)
      predicted_permutation = np.argsort(predicted_classes)
      predicted_classes = predicted_classes.numpy()[predicted_permutation]
      predicted_sdfs = (
          detections['predicted_sdfs'].numpy()[predicted_permutation])
      predicted_rotations_3d = (
          detections['rotations_3d'].numpy()[predicted_permutation])
      predicted_translations_3d = (
          detections['translations_3d'].numpy()[predicted_permutation])
      predicted_sizes_3d = (
          detections['sizes_3d'].numpy()[predicted_permutation])
      predicted_poses = (predicted_rotations_3d, predicted_translations_3d,
                         predicted_sizes_3d)

      full_oracle = False
      if full_oracle:
        predicted_sdfs = detections['groundtruth_sdfs'].numpy()
        predicted_sdfs = predicted_sdfs[labeled_permutation]
        predicted_classes = labeled_classes
        predicted_poses = labeled_poses
        print('----------------------------')
        print(predicted_sdfs.shape)
        print(predicted_classes.shape)
        print(predicted_poses[0].shape)
        print(predicted_poses[1].shape)
        print(predicted_poses[2].shape)

      pose_oracle = False
      if pose_oracle:
        predicted_sdfs = detections['predicted_sdfs'].numpy()
        predicted_sdfs = predicted_sdfs[predicted_permutation]
        predicted_poses = (labeled_rotations_3d, labeled_translations_3d,
                           labeled_sizes_3d)

      class_oracle = True
      if class_oracle:
        predicted_classes *= 0
        labeled_classes *= 0

      iou_mean, iou_min = metric.update(
          labeled_sdfs, labeled_classes, labeled_poses, predicted_sdfs,
          predicted_classes, predicted_poses, sample['dot'])
      result_dict['iou_mean'] = iou_mean
      result_dict['iou_min'] = iou_min
    elif isinstance(metric, CollisionMetric):
      labeled_sdfs = detections['groundtruth_sdfs']
      labeled_classes = tf.cast(sample['groundtruth_valid_classes'], tf.int64)
      labeled_poses = (sample['rotations_3d'], sample['translations_3d'],
                       sample['sizes_3d'])

      predicted_classes = tf.cast(detections['detection_classes'], tf.int64)
      predicted_sdfs = detections['predicted_sdfs']
      predicted_poses = (detections['rotations_3d'],
                         detections['translations_3d'],
                         detections['sizes_3d'])

      full_oracle = False
      if full_oracle:
        predicted_sdfs = detections['groundtruth_sdfs'].numpy()
        predicted_classes = labeled_classes
        predicted_poses = labeled_poses

      num_collisions, intersection, iou = metric.update(
          labeled_sdfs, labeled_classes, labeled_poses,
          predicted_sdfs, predicted_classes, predicted_poses)
      result_dict['collisions'] = num_collisions
      result_dict['collision_intersection'] = intersection
      result_dict['collision_iou'] = iou

  return result_dict
def run_experiment(study_hparams=None, trial_handle=None, tuner=None):
  """Runs one meta-Q-learning experiment on cartpole."""
  FLAGS = deepcopy(tf.app.flags.FLAGS)

  if FLAGS.use_vizier:
    for key, val in study_hparams.values().items():
      setattr(FLAGS, key, val)

  tf.reset_default_graph()
  np.random.seed(FLAGS.random_seed)
  tf.set_random_seed(FLAGS.random_seed)

  # Initialize env.
  env_kwargs = {
      'goal_x': FLAGS.goal_x,
      'min_goal_x': FLAGS.min_goal_x,
      'max_goal_x': FLAGS.max_goal_x,
      'x_threshold': FLAGS.x_threshold,
      'max_reward_for_dist': FLAGS.max_reward_for_dist,
      'reward_per_time_step': FLAGS.reward_per_time_step,
      'fixed_initial_state': FLAGS.fixed_initial_state,
      'reweight_rewards': FLAGS.reweight_rewards
  }
  env = cartpole.make_env(env_kwargs)
  eval_env = cartpole.make_env(env_kwargs)
  if not FLAGS.fixed_env:
    env.env.randomize()

  if trial_handle:
    tensorboard_path = os.path.join(FLAGS.output_dir, trial_handle)
  else:
    tensorboard_path = FLAGS.output_dir
  tf.gfile.MakeDirs(tensorboard_path)

  kwargs = dict(
      observation_shape=[None] + list(env.observation_space.shape),
      action_dim=1)
  default_hps = MetaQ.get_default_config().values()
  for key in flags_def:
    if key in default_hps:
      kwargs[key] = getattr(FLAGS, key)
  hps = tf.HParams(**kwargs)

  meta_q = MetaQ(hps, fully_connected_net(FLAGS.nn_arch, FLAGS.activation))
  meta_q.build_graph()

  init_op = tf.global_variables_initializer()
  logger = TensorBoardLogger(tensorboard_path)

  with tf.Session() as sess:
    sess.run(init_op)
    meta_q.init_session(sess)

    inner_loop_buffer = MultiTaskReplayBuffer(
        len(env.env.goal_positions), 200000, FLAGS.random_seed)
    outer_loop_buffer = MultiTaskReplayBuffer(
        len(env.env.goal_positions), 200000, FLAGS.random_seed)

    pre_update_rewards = []
    post_update_rewards = []
    post_update_greedy_rewards = []
    post_update_q_func = None
    for outer_step in range(FLAGS.outer_loop_steps):
      print('State is ', env.env.state)
      if outer_step % FLAGS.on_policy_steps == 0:
        if FLAGS.fixed_env:
          goal_positions = [env.env.goal_x]
        else:
          goal_positions = env.env.goal_positions
        # NOTE: Approximately ~30 to 60 states per trajectory.
        inner_loop_buffer = collect_off_policy_data(
            env, goal_positions, meta_q, post_update_q_func,
            inner_loop_buffer, FLAGS.inner_loop_n_trajs,
            FLAGS.inner_loop_data_collection,
            FLAGS.inner_loop_greedy_epsilon, FLAGS.inner_loop_bolzmann_temp)
        outer_loop_buffer = collect_off_policy_data(
            env, goal_positions, meta_q, post_update_q_func,
            outer_loop_buffer, FLAGS.outer_loop_n_trajs,
            FLAGS.outer_loop_data_collection,
            FLAGS.outer_loop_greedy_epsilon, FLAGS.outer_loop_bolzmann_temp)

      post_update_greedy_rewards = []
      finetuned_policy = None
      for task_id in range(FLAGS.n_meta_tasks):
        # print('Task: {}'.format(task_id))
        if not FLAGS.fixed_env:
          env.env.randomize()
        (inner_observations, inner_actions, inner_rewards,
         inner_next_observations, inner_dones) = inner_loop_buffer.sample(
             env.env.task_id, FLAGS.inner_loop_n_states)

        # Evaluating true rewards.
        post_update_q_func = meta_q.get_post_update_q_function(
            inner_observations, inner_actions, inner_rewards,
            inner_next_observations, inner_dones)
        policy = QPolicy(post_update_q_func, epsilon=0.0)
        if outer_step % FLAGS.report_steps == 0 or outer_step >= (
            FLAGS.outer_loop_steps - 1):
          _, _, greedy_rewards, _, _ = cartpole_utils.collect_data(
              env, n_trajs=FLAGS.outer_loop_greedy_eval_n_trajs,
              policy=policy)
          post_update_greedy_rewards.append(
              np.sum(greedy_rewards) / FLAGS.outer_loop_greedy_eval_n_trajs)
        finetuned_policy = policy

        (outer_observations, outer_actions, outer_rewards,
         outer_next_observations, outer_dones) = outer_loop_buffer.sample(
             env.env.task_id, FLAGS.outer_loop_n_states)

        meta_q.accumulate_gradient(
            inner_observations, inner_actions, inner_rewards,
            inner_next_observations, inner_dones, outer_observations,
            outer_actions, outer_rewards, outer_next_observations,
            outer_dones)

      pre_update_loss, post_update_loss = meta_q.run_train_step()
      if (not FLAGS.outer_loop_online_target and
          outer_step % FLAGS.target_update_freq == 0):
        print('updating target network')
        meta_q.update_target_network()

      log_data = dict(
          pre_update_loss=pre_update_loss,
          post_update_loss=post_update_loss,
          goal_x=env.env.goal_x,
      )
      # TODO(hkannan): uncomment this later!!!
      if outer_step % FLAGS.report_steps == 0 or outer_step >= (
          FLAGS.outer_loop_steps - 1):
        # reward_across_20_tasks = evaluate(
        #     policy, eval_env, meta_q,
        #     inner_loop_n_trajs=FLAGS.inner_loop_n_trajs,
        #     outer_loop_n_trajs=FLAGS.outer_loop_n_trajs, n=21,
        #     weight_rewards=FLAGS.weight_rewards)
        # log_data['reward_mean'] = np.mean(reward_across_20_tasks)
        # log_data['reward_variance'] = np.var(reward_across_20_tasks)
        log_data['post_update_greedy_reward'] = np.mean(
            post_update_greedy_rewards)
        log_data['post_update_greedy_reward_variance'] = np.var(
            post_update_greedy_rewards)
        print('Outer step: {}, '.format(outer_step), log_data)
        logger.log_dict(outer_step, log_data)

      # if outer_step % FLAGS.video_report_steps == 0 or outer_step >= (
      #     FLAGS.outer_loop_steps - 1):
      #   video_data = {
      #       'env_kwargs': env_kwargs,
      #       'inner_loop_data_collection': FLAGS.inner_loop_data_collection,
      #       'inner_loop_greedy_epsilon': FLAGS.inner_loop_greedy_epsilon,
      #       'inner_loop_bolzmann_temp': FLAGS.inner_loop_bolzmann_temp,
      #       'inner_loop_n_trajs': FLAGS.inner_loop_n_trajs,
      #       'meta_q_kwargs': kwargs,
      #       'weights': meta_q.get_current_weights(),
      #       'tensorboard_path': tensorboard_path,
      #       'filename': 'random_task'
      #   }
      #   reward_across_20_tasks = evaluate(
      #       policy, eval_env, meta_q,
      #       inner_loop_n_trajs=FLAGS.inner_loop_n_trajs,
      #       outer_loop_n_trajs=FLAGS.outer_loop_n_trajs, n=21,
      #       weight_rewards=FLAGS.weight_rewards, video_data=video_data)
      #   log_data['reward_mean'] = np.mean(reward_across_20_tasks)
      #   log_data['reward_variance'] = np.var(reward_across_20_tasks)
      #   logger.log_dict(outer_step, log_data)

      if outer_step >= (FLAGS.outer_loop_steps - 1):
        greedy_reward_path = os.path.join(tensorboard_path, 'reward')
        with gfile.Open(greedy_reward_path, mode='wb') as f:
          f.write(pickle.dumps(log_data['post_update_greedy_reward']))

      if FLAGS.use_vizier:
        for v in log_data.values():
          if not np.isfinite(v):
            tuner.report_done(
                infeasible=True,
                infeasible_reason='Nan or inf encountered')
            return
        if outer_step % FLAGS.report_steps == 0 or outer_step >= (
            FLAGS.outer_loop_steps - 1):
          if FLAGS.vizier_objective == 'greedy_reward':
            objective_value = log_data['post_update_greedy_reward']
          elif FLAGS.vizier_objective == 'loss':
            objective_value = post_update_loss
          elif FLAGS.vizier_objective == 'reward':
            objective_value = log_data['reward_mean']
          else:
            raise ValueError('Unsupported vizier objective!')
          tuner.report_measure(
              objective_value=objective_value,
              global_step=outer_step,
              metrics=log_data)

  if FLAGS.use_vizier:
    tuner.report_done()