def run_scene(self, scene_i, save_dir):
    """Build one anchor/other-object scene and record the effect of each action.

    The method resets the simulator, spawns an anchor object (optionally
    lifted into the air when ``self.args.anchor_in_air`` is set), spawns a
    second ("other") object, samples a position for it next to the anchor,
    and then, for every candidate action, moves the other object to that
    position, executes the action and saves contact / force-torque / voxel
    data through ``self.save_scene_data``.

    Args:
        scene_i: Index of the scene; only used in log messages.
        save_dir: Directory handed to ``self.save_scene_data``.

    Returns:
        ``True`` when the loop over all retained actions completed, ``False``
        when no position near the anchor could be sampled or the other
        object did not reach the sampled position.
    """
    self.scene_info = {}
    args = self.args
    action_relative_or_absolute = 'absolute'

    logging.info("Start sim: {}".format(scene_i))
    self.reset()
    self.step()

    # Generate anchor handle
    anchor_args = self.get_anchor_object_args(min_obj_size=0.02,
                                              max_obj_size=0.10)
    anchor_in_air = args.anchor_in_air
    if anchor_in_air:
        # Move the object up so that it is in the air.
        anchor_args['pos'][2] += npu(0.28, 0.3)
    anchor_handle = self.generate_object_with_args(anchor_args)
    self.handles_dict['anchor_obj'] = anchor_handle

    # Place the octree half an object-height above the anchor position.
    octree_pos = [
        anchor_args['pos'][0],
        anchor_args['pos'][1],
        anchor_args['pos'][2] + anchor_args['obj_size'][2] / 2.0,
    ]
    self.update_octree_position(octree_pos)

    # Move it slightly above to account for the fact that there can be
    # inaccuracies in real world voxels.
    # NOTE(review): this changes anchor_args *after* the anchor has been
    # generated, so it only affects later uses of anchor_args (object
    # re-generation and the action direction below) -- confirm intent.
    if npu(0, 1) < 0.3:
        anchor_args['pos'][2] += npu(0.01, 0.02)

    for _ in range(4):
        self.step()

    anchor_pos = self.get_object_position(anchor_handle)
    logging.info("Anchor pos: {}".format(_array_to_str(anchor_pos)))

    # Move the robot to nearby the anchor object.
    if anchor_in_air:
        # If the anchor is in air then there is no point in placing the
        # other object above the anchor. Hence, we always keep it below.
        orig_joint_position = [
            anchor_pos[0], anchor_pos[1], max(anchor_pos[2] - 0.2, 0),
            0, 0, 0]
    else:
        orig_joint_position = [
            anchor_pos[0], anchor_pos[1], anchor_pos[2] + 0.2, 0, 0, 0]
    self.set_all_joint_positions(orig_joint_position)

    # Since the joint positions now have to be below the object move it below.
    if anchor_in_air:
        for _ in range(10):
            self.step()

    # Generate the other object
    other_pos = [0.0, 0.0, 0.0]
    other_obj_args = self.get_other_object_args(
        other_pos, add_obj_z=False, min_size=0.02, max_size=0.10)
    other_handle = self.generate_object_with_args(other_obj_args)
    # NOTE(review): handles_dict['other_obj'] is only set when the scene is
    # re-created inside the action loop, not here -- confirm that is intended.
    for _ in range(5):
        self.step()

    # Result unused; the call is kept in case the query has side effects in
    # the simulator -- TODO confirm and drop if it is a pure read.
    self.get_joint_offsets()

    # Move other obj to the right place
    anchor_abs_bb = self.get_absolute_bb(anchor_handle)
    anchor_lb, anchor_ub = anchor_abs_bb[0], anchor_abs_bb[1]
    new_other_obj_pos, sample_info = \
        self.get_other_obj_location_outside_anchor_cuboid(
            anchor_args, anchor_pos, anchor_lb, anchor_ub, anchor_in_air,
            other_obj_args['obj_size'])
    if new_other_obj_pos is None:
        logging.info("Cannot place other obj near anchor.")
        return False
    self.scene_info['new_other_obj_pos'] = new_other_obj_pos
    self.scene_info['sample_info'] = sample_info

    # What sort of actions to take? Take all actions or only actions that
    # result in orientation change
    action_type = sample_info['action_type']
    # Indices into this list are used as keys in scene_info and are passed
    # to save_scene_data, so the order must stay stable.
    actions = [[-1, 0, 0], [1, 0, 0],
               [0, -1, 0], [0, 1, 0],
               [0, 0, -1], [0, 0, 1]]
    diagonal_actions = [
        [-1, 1, 0], [-1, 0, 1], [-1, -1, 0], [-1, 0, -1],
        [1, 1, 0], [1, 0, 1], [1, -1, 0], [1, 0, -1],
        # The last entry used to repeat [0, -1, 1], which left the
        # [0, -1, -1] diagonal without any sample.
        [0, 1, 1], [0, 1, -1], [0, -1, 1], [0, -1, -1],
    ]
    diagonal_actions_3d = [
        [-1, -1, -1], [-1, -1, 1], [-1, 1, -1], [-1, 1, 1],
        [1, -1, -1], [1, -1, 1], [1, 1, -1], [1, 1, 1],
    ]
    actions += diagonal_actions
    actions += diagonal_actions_3d

    if action_relative_or_absolute == 'absolute':
        action_dist = 0.08
        actions = [[action_dist * x[0], action_dist * x[1], action_dist * x[2]]
                   for x in actions]

    if action_type == 'max_torque':
        # Axis -> (negative action index, positive action index) among the
        # six axis-aligned actions at the head of the list.
        axes_to_action_idx = {0: [0, 1], 1: [2, 3], 2: [4, 5]}
        valid_actions = copy.deepcopy(actions)
        for axes_i, axes_out in enumerate(sample_info['outside_xyz']):
            neg_idx, pos_idx = axes_to_action_idx[axes_i]
            # Only positive actions are taken: when executed below, a
            # positive component is turned into a move towards the anchor,
            # so the negative primitive is always dropped.
            valid_actions[neg_idx] = None
            if not axes_out:
                # Axis is not marked in outside_xyz: drop it entirely.
                valid_actions[pos_idx] = None
        # NOTE(review): only the six axis-aligned actions are filtered; the
        # diagonal actions stay in the list -- confirm this is intended.
        logging.info("Filtered actions from {} to {}".format(
            len(actions), sum(va is not None for va in valid_actions)))
        actions = valid_actions

    has_created_scene = True
    for a_idx, a in enumerate(actions):
        if a is None:
            continue

        self.scene_info[a_idx] = {
            'action': a,
            'action_type': action_type,
            'action_relative_or_absolute': action_relative_or_absolute
        }
        logging.info(bcolors.c_red(
            " ==== Scene {} action: {} ({}/{}) start ====".format(
                scene_i, a, a_idx, len(actions))))

        if args.reset_every_action == 1 and not has_created_scene:
            # The previous iteration reset the simulator: rebuild the scene.
            anchor_handle = self.generate_object_with_args(anchor_args)
            self.handles_dict['anchor_obj'] = anchor_handle
            for _ in range(5):
                self.step()
            self.set_all_joint_positions(orig_joint_position)
            for _ in range(5):
                self.step()
            other_handle = self.generate_object_with_args(other_obj_args)
            self.handles_dict['other_obj'] = other_handle
            for _ in range(5):
                self.step()
        elif args.reset_every_action == 0 and has_created_scene:
            # First go to rest position
            self.set_all_joint_positions(orig_joint_position)
            for _ in range(10):
                self.step()

        joint_pos = self.get_joint_position()
        logging.info("Current joint pos: {}".format(_array_to_str(joint_pos)))

        new_other_obj_waypoints, joint_offsets = \
            self.get_waypoints_for_other_obj_pos(
                new_other_obj_pos, joint_pos, sample_info)
        logging.info(bcolors.c_yellow(
            "Should move other obj\n"
            " \t to: {}\n"
            " \t First move to: {}\n"
            " \t Now move to: {}\n"
            " \t joint offset: {}\n".format(
                _array_to_str(new_other_obj_pos),
                _array_to_str(new_other_obj_waypoints[0]),
                _array_to_str(new_other_obj_waypoints[1]),
                _array_to_str(joint_offsets))))

        # Move to the first waypoint
        logging.info(bcolors.c_green("Move joints to: {}".format(
            _array_to_str(new_other_obj_waypoints[0]))))
        self.set_prismatic_joints(new_other_obj_waypoints[0])
        for _ in range(15):
            self.step()
        _, temp_obj_data = self.get_all_objects_info()
        logging.info(bcolors.c_cyan(
            " \t Curr other obj location: {}\n"
            " \t Curr joint pos: {}\n"
            " \t Curr joint target pos: {}\n".format(
                _array_to_str(temp_obj_data['other_pos']),
                _array_to_str(temp_obj_data['joint_pos']),
                _array_to_str(temp_obj_data['joint_target_pos']),
            )))

        # Move to the second waypoint
        logging.info(bcolors.c_green("Move joints to: {}".format(
            _array_to_str(new_other_obj_waypoints[1]))))
        self.set_prismatic_joints(new_other_obj_waypoints[1])
        for _ in range(15):
            self.step()
        # Result unused; kept for the same reason as the call before the loop.
        self.get_joint_offsets()

        _, first_obj_data_dict = self.get_all_objects_info()
        obj_at_desired_location = are_position_similar(
            first_obj_data_dict['other_pos'], new_other_obj_pos)
        if not obj_at_desired_location:
            logging.error(bcolors.c_cyan(
                "OBJECTS NOT at desired location!!\n"
                " \t curr: {}\n"
                " \t desired: {}\n"
                " \t joint_T curr: {}\n"
                " \t joint_T des: {}\n".format(
                    _array_to_str(first_obj_data_dict['other_pos']),
                    _array_to_str(new_other_obj_pos),
                    _array_to_str(first_obj_data_dict['joint_target_pos']),
                    _array_to_str(new_other_obj_waypoints[1]),
                )))
            return False

        # Let the scene settle, then check that nothing is still moving.
        for _ in range(10):
            self.step()
        _, second_obj_data_dict = self.get_all_objects_info()
        obj_in_place = self.are_pos_and_orientation_similar(
            first_obj_data_dict, second_obj_data_dict)
        if not obj_in_place:
            # Only logged: this variant deliberately keeps going instead of
            # aborting the scene.
            logging.error(bcolors.c_cyan(
                "OBJECTS STILL MOVING!! Will sample again."))

        self.scene_info[a_idx]['before_dist'] = self.get_object_distance()[-1]

        # Save contact info before taking the action; voxels are not saved
        # before the action in this variant.
        before_contact_info = self.get_contacts_info()
        voxels_before_dict = {}

        self.toggle_recording_contact_data()
        self.toggle_record_ft_data()
        self.step()

        # Now perform the action
        if action_relative_or_absolute == 'absolute':
            if action_type == 'max_torque':
                # Flip each component so that a positive action value moves
                # the joint target towards the anchor position on that axis.
                diff = []
                for ai in range(3):
                    if anchor_args['pos'][ai] - new_other_obj_waypoints[-1][ai] >= 0:
                        diff.append(a[ai])
                    else:
                        diff.append(-a[ai])
                after_action_joint_pos = [
                    new_other_obj_waypoints[-1][i] + diff[i] for i in range(3)]
            else:
                after_action_joint_pos = [
                    new_other_obj_waypoints[-1][i] + a[i] for i in range(3)]
        elif action_relative_or_absolute == 'relative':
            # Should we move fixed distance or adaptive distnaces?
            diff = [a[i] * (anchor_args['pos'][i] - new_other_obj_waypoints[-1][i])
                    for i in range(3)]
            after_action_joint_pos = [
                new_other_obj_waypoints[-1][i] + diff[i] for i in range(3)]
        else:
            raise ValueError(f"Invalid action {action_relative_or_absolute}")

        logging.info(f"After action pos: {after_action_joint_pos}")
        self.set_all_joint_positions(after_action_joint_pos + [0, 0, 0])
        for _ in range(25):
            self.step()

        _, third_obj_data_dict = self.get_all_objects_info()
        # Compare the state right before the action with the state after it.
        # (This used to compare the two pre-action snapshots again, so the
        # log line below could never reflect the action.)
        obj_in_place = self.are_pos_and_orientation_similar(
            second_obj_data_dict, third_obj_data_dict)
        self.debug_initial_final_position(
            first_obj_data_dict, second_obj_data_dict)
        self.scene_info[a_idx]['before'] = second_obj_data_dict
        self.scene_info[a_idx]['after'] = third_obj_data_dict
        if not obj_in_place:
            logging.info("Objects changed position AFTER action.")
        self.scene_info[a_idx]['after_dist'] = self.get_object_distance()[-1]

        # Second toggle call of each pair for this action (the first one is
        # issued right before the action).
        self.toggle_recording_contact_data()
        self.toggle_record_ft_data()
        _, recorded_contact_info = self.save_recorded_contact_data()
        recorded_ft_data = self.save_record_ft_data()
        # The recording is reshaped to rows of 6 values and averaged per
        # column.
        ft_sensor_mean = np.mean(
            np.array(recorded_ft_data).reshape(-1, 6), axis=0)
        self.scene_info[a_idx]['ft_sensor_mean'] = ft_sensor_mean.tolist()
        self.step()

        logging.info(bcolors.c_cyan("Mean force-torque val: {}".format(
            _array_to_str(ft_sensor_mean))))

        # Now save vision and octree info after taking action.
        after_contact_info = self.get_contacts_info()
        voxels_after_dict = self.run_save_voxels_after(
            anchor_handle, other_handle)
        for _ in range(10):
            self.step()

        self.save_scene_data(
            save_dir, a_idx, self.scene_info,
            voxels_before_dict,
            voxels_after_dict,
            before_contact_info,
            after_contact_info,
            recorded_contact_info,
            recorded_ft_data,
        )
        self.debug_initial_final_position(second_obj_data_dict,
                                          third_obj_data_dict)
        before_contact_len = \
            0 if before_contact_info is None else before_contact_info.shape[0]
        after_contact_len = \
            0 if after_contact_info is None else after_contact_info.shape[0]
        logging.info("Contact len: before: {}, after: {}".format(
            before_contact_len, after_contact_len))
        logging.info(" ==== Scene {} action: {} Done ====".format(
            scene_i, a))

        if args.reset_every_action == 1:
            self.reset()
            has_created_scene = False

    return True
def run_scene(self, scene_i, save_dir):
    """Build one anchor/other-object scene and record each axis-aligned action.

    The method resets the simulator, spawns an anchor object and a second
    ("other") object, samples a position for the other object next to the
    anchor, and then, for every retained action, moves the other object to
    that position, executes the action and saves contact / force-torque /
    voxel data through ``self.save_scene_data``.

    Args:
        scene_i: Index of the scene; only used in log messages.
        save_dir: Directory handed to ``self.save_scene_data``.

    Returns:
        ``True`` when the loop over all retained actions completed, ``False``
        when no position near the anchor could be sampled, the other object
        did not reach the sampled position, or the objects were still moving
        before the action.

    Raises:
        ValueError: If the sample is not on an edge of the anchor
            (``sample_info['sample_on_edge']`` is falsy) or the sampled
            region name is not one of ``'low'``, ``'high'``, ``'middle'``.
    """
    self.scene_info = {}
    args = self.args

    logging.info("Start sim: {}".format(scene_i))
    self.reset()
    self.step()

    # Generate anchor handle
    anchor_args = self.get_anchor_object_args()
    anchor_handle = self.generate_object_with_args(anchor_args)
    self.handles_dict['anchor_obj'] = anchor_handle
    for _ in range(4):
        self.step()
    anchor_pos = self.get_object_position(anchor_handle)
    logging.info("Anchor pos: {}".format(_array_to_str(anchor_pos)))

    # Move the robot to nearby (above) the anchor object.
    # The queried joint position itself is not used; the call is kept in
    # case the query has side effects in the simulator -- TODO confirm.
    self.get_joint_position()
    orig_joint_position = [
        anchor_pos[0], anchor_pos[1], anchor_pos[2] + 0.2, 0, 0, 0
    ]
    self.set_all_joint_positions(orig_joint_position)

    # Generate the other object
    other_pos = [0.0, 0.0, 0.0]
    other_obj_args = self.get_other_object_args(other_pos, add_obj_z=False)
    other_handle = self.generate_object_with_args(other_obj_args)
    for _ in range(5):
        self.step()

    # Result unused; kept for the same reason as get_joint_position above.
    self.get_joint_offsets()

    # Move other obj to the right place
    anchor_abs_bb = self.get_absolute_bb(anchor_handle)
    anchor_lb, anchor_ub = anchor_abs_bb[0], anchor_abs_bb[1]
    new_other_obj_pos, sample_info = \
        self.get_other_obj_location_outside_anchor_cuboid(
            anchor_args, anchor_pos, anchor_lb, anchor_ub,
            other_obj_args['obj_size'])
    if new_other_obj_pos is None:
        logging.info("Cannot place other obj near anchor.")
        # Return False (not a bare return) like every other failure path.
        return False
    self.scene_info['new_other_obj_pos'] = new_other_obj_pos
    self.scene_info['sample_info'] = sample_info

    actions = [[-.1, 0, 0], [.1, 0, 0],
               [0, -.1, 0], [0, .1, 0],
               [0, 0, -.1], [0, 0, .1]]

    filter_actions_to_max_torque = True
    if filter_actions_to_max_torque and sample_info['sample_on_edge']:
        # Axis -> (negative action index, positive action index).
        axes_to_action_idx = {0: [0, 1], 1: [2, 3], 2: [4, 5]}
        valid_actions = copy.deepcopy(actions)
        # Test the per-axis flag, not the loop index: the previous
        # ``if not axes_i`` was only ever true for axis 0.
        for axes_i, axes_out in enumerate(sample_info['outside_xyz']):
            neg_idx, pos_idx = axes_to_action_idx[axes_i]
            if not axes_out:
                # Axis is not marked in outside_xyz: drop both directions.
                valid_actions[neg_idx] = None
                valid_actions[pos_idx] = None
            else:
                region = sample_info['sample_inside_around_edges_info'][
                    'region']
                if region == 'low':
                    valid_actions[neg_idx] = None
                elif region == 'high':
                    valid_actions[pos_idx] = None
                elif region == 'middle':
                    pass
                else:
                    raise ValueError("Invalid region: {}".format(region))
        # Count only the actions that are actually kept; dropped actions
        # stay in the list as None placeholders so indices remain stable.
        logging.info("Filtered actions from {} to {}".format(
            len(actions), sum(va is not None for va in valid_actions)))
        actions = valid_actions
    else:
        raise ValueError(
            "Cannot filter actions: sample_info['sample_on_edge'] is falsy")

    has_created_scene = True
    for a_idx, a in enumerate(actions):
        if a is None:
            continue

        self.scene_info[a_idx] = {'action': a}
        logging.info(bcolors.c_red(
            " ==== Scene {} action: {} ({}/{}) start ====".format(
                scene_i, a, a_idx, len(actions))))

        if args.reset_every_action == 1 and not has_created_scene:
            # The previous iteration reset the simulator: rebuild the scene.
            anchor_handle = self.generate_object_with_args(anchor_args)
            self.handles_dict['anchor_obj'] = anchor_handle
            for _ in range(5):
                self.step()
            self.set_all_joint_positions(orig_joint_position)
            for _ in range(5):
                self.step()
            other_handle = self.generate_object_with_args(other_obj_args)
            self.handles_dict['other_obj'] = other_handle
            for _ in range(5):
                self.step()
        else:
            # First go to rest position
            self.set_all_joint_positions(orig_joint_position)
            for _ in range(50):
                self.step()

        joint_pos = self.get_joint_position()
        logging.info("Current joint pos: {}".format(_array_to_str(joint_pos)))

        new_other_obj_waypoints, joint_offsets = \
            self.get_waypoints_for_other_obj_pos(
                new_other_obj_pos, joint_pos, sample_info)
        logging.info(bcolors.c_yellow(
            "Should move other obj\n"
            " \t to: {}\n"
            " \t First move to: {}\n"
            " \t Now move to: {}\n"
            " \t joint offset: {}\n".format(
                _array_to_str(new_other_obj_pos),
                _array_to_str(new_other_obj_waypoints[0]),
                _array_to_str(new_other_obj_waypoints[1]),
                _array_to_str(joint_offsets))))

        # Move to the first waypoint
        logging.info(bcolors.c_green("Move joints to: {}".format(
            _array_to_str(new_other_obj_waypoints[0]))))
        self.set_prismatic_joints(new_other_obj_waypoints[0])
        for _ in range(25):
            self.step()
        _, temp_obj_data = self.get_all_objects_info()
        logging.info(bcolors.c_cyan(
            " \t Curr other obj location: {}\n"
            " \t Curr joint pos: {}\n"
            " \t Curr joint target pos: {}\n".format(
                _array_to_str(temp_obj_data['other_pos']),
                _array_to_str(temp_obj_data['joint_pos']),
                _array_to_str(temp_obj_data['joint_target_pos']),
            )))

        # Move to the second waypoint
        logging.info(bcolors.c_green("Move joints to: {}".format(
            _array_to_str(new_other_obj_waypoints[1]))))
        self.set_prismatic_joints(new_other_obj_waypoints[1])
        for _ in range(25):
            self.step()

        _, first_obj_data_dict = self.get_all_objects_info()
        obj_at_desired_location = are_position_similar(
            first_obj_data_dict['other_pos'], new_other_obj_pos)
        if not obj_at_desired_location:
            logging.error(bcolors.c_cyan(
                "OBJECTS NOT at desired location!!\n"
                " \t curr: {}\n"
                " \t desired: {}\n"
                " \t joint_T curr: {}\n"
                " \t joint_T des: {}\n".format(
                    _array_to_str(first_obj_data_dict['other_pos']),
                    _array_to_str(new_other_obj_pos),
                    _array_to_str(first_obj_data_dict['joint_target_pos']),
                    _array_to_str(new_other_obj_waypoints[1]),
                )))
            return False

        # Let the scene settle, then make sure nothing is still moving.
        for _ in range(10):
            self.step()
        _, second_obj_data_dict = self.get_all_objects_info()
        obj_in_place = self.are_pos_and_orientation_similar(
            first_obj_data_dict, second_obj_data_dict)
        if not obj_in_place:
            logging.error(bcolors.c_cyan(
                "OBJECTS STILL MOVING!! Will sample again."))
            return False

        self.scene_info[a_idx]['before_dist'] = self.get_object_distance()[-1]

        # save vision and octree info before taking action
        before_contact_info = self.get_contacts_info()
        voxels_before_dict = self.run_save_voxels_before(
            anchor_handle, other_handle)

        self.toggle_recording_contact_data()
        self.toggle_record_ft_data()
        self.step()

        # Now perform the action
        after_action_joint_pos = [
            new_other_obj_waypoints[-1][i] + a[i] for i in range(3)
        ]
        self.set_all_joint_positions(after_action_joint_pos + [0, 0, 0])
        for _ in range(25):
            self.step()

        _, third_obj_data_dict = self.get_all_objects_info()
        # Compare the state right before the action with the state after it.
        # (This used to compare the two pre-action snapshots again, so the
        # log line below could never reflect the action.)
        obj_in_place = self.are_pos_and_orientation_similar(
            second_obj_data_dict, third_obj_data_dict)
        self.debug_initial_final_position(first_obj_data_dict,
                                          second_obj_data_dict)
        self.scene_info[a_idx]['before'] = second_obj_data_dict
        self.scene_info[a_idx]['after'] = third_obj_data_dict
        if not obj_in_place:
            logging.info("Objects changed position AFTER action.")
        self.scene_info[a_idx]['after_dist'] = self.get_object_distance()[-1]

        # Second toggle call of each pair for this action (the first one is
        # issued right before the action).
        self.toggle_recording_contact_data()
        self.toggle_record_ft_data()
        _, recorded_contact_info = self.save_recorded_contact_data()
        recorded_ft_data = self.save_record_ft_data()
        # The recording is reshaped to rows of 6 values and averaged per
        # column.
        ft_sensor_mean = np.mean(
            np.array(recorded_ft_data).reshape(-1, 6), axis=0)
        self.scene_info[a_idx]['ft_sensor_mean'] = ft_sensor_mean.tolist()
        self.step()

        logging.info("Mean force-torque val: {}".format(
            _array_to_str(ft_sensor_mean)))

        # Now save vision and octree info after taking action.
        after_contact_info = self.get_contacts_info()
        voxels_after_dict = self.run_save_voxels_after(
            anchor_handle, other_handle)
        for _ in range(10):
            self.step()

        self.save_scene_data(
            save_dir,
            a_idx,
            self.scene_info,
            voxels_before_dict,
            voxels_after_dict,
            before_contact_info,
            after_contact_info,
            recorded_contact_info,
            recorded_ft_data,
        )
        self.debug_initial_final_position(second_obj_data_dict,
                                          third_obj_data_dict)
        before_contact_len = \
            0 if before_contact_info is None else before_contact_info.shape[0]
        after_contact_len = \
            0 if after_contact_info is None else after_contact_info.shape[0]
        logging.info("Contact len: before: {}, after: {}".format(
            before_contact_len, after_contact_len))
        logging.info(" ==== Scene {} action: {} Done ====".format(
            scene_i, a))

        if args.reset_every_action == 1:
            self.reset()
            has_created_scene = False

    return True
def train(self, train=True, viz_images=False, save_embedding=True,
          use_emb_data=False, log_prefix=''):
    """Run the train (or evaluation) loop and collect predictions/embeddings.

    For every epoch the data is split into ``train_data_size // batch_size``
    batches (a trailing partial batch is not visited), each batch is passed
    through the model, and labels, predictions, embeddings and the confusion
    matrix are accumulated. When ``train`` is true the loop also logs to
    tensorboard, saves checkpoints and periodically runs a full pass over
    the test split.

    Args:
        train: When true, run ``args.num_epochs`` shuffled epochs with
            ``args.batch_size``; otherwise run one in-order epoch with a
            batch size of 32 and skip logging, checkpointing and testing.
        viz_images: Plot images to tensorboard every ``log_freq_iters``
            training steps.
        save_embedding: Currently unused by this method.
        use_emb_data: Read samples with the dataloader's h5 accessors and
            run ``self.run_emb_model_on_batch`` instead of
            ``self.run_model_on_batch``.
        log_prefix: Currently unused by this method.

    Returns:
        The ``result_dict`` with data from the last epoch / last test pass
        under ``'emb'`` and ``'output'`` plus the per-epoch confusion
        matrices under ``'conf'``.
    """
    print("Begin training")
    args = self.config.args
    log_freq_iters = args.log_freq_iters if train else 10
    dataloader = self.dataloader
    device = self.config.get_device()
    if use_emb_data:
        train_data_size = dataloader.get_h5_data_size(train)
    else:
        train_data_size = dataloader.get_data_size(train)
    train_data_idx_list = list(range(0, train_data_size))

    # Pick the per-sample accessor and the forward function once; both only
    # depend on use_emb_data.
    if use_emb_data:
        get_sample = dataloader.get_h5_train_data_at_idx
        model_fn = self.run_emb_model_on_batch
    else:
        get_sample = dataloader.get_train_data_at_idx
        model_fn = self.run_model_on_batch

    # Reset log counter
    train_step_count, test_step_count = 0, 0
    self.set_model_device(device)

    result_dict = {
        'data_info': {
            'path': [],
            'info': [],
        },
        'emb': {
            'train_img_emb': [],
            'test_img_emb': [],
            'train_gt': [],
            'train_pred': [],
            'test_gt': [],
            # Was misspelled 'tese_pred', which left a stray empty entry
            # next to the real 'test_pred' written by the test pass.
            'test_pred': [],
        },
        'output': {
            'gt': [],
            'pred': [],
            'test_f1_score': [],
            'test_wt_f1_score': [],
            'test_conf': [],
        },
        'conf': {
            'train': [],
            'test': [],
        }
    }
    num_epochs = args.num_epochs if train else 1

    for e in range(num_epochs):
        if train:
            iter_order = np.random.permutation(train_data_idx_list)
        else:
            iter_order = np.arange(train_data_size)

        batch_size = args.batch_size if train else 32
        num_batches = train_data_size // batch_size
        data_idx = 0

        n_classes = args.classif_num_classes
        result_dict['conf']['train'].append(
            np.zeros((n_classes, n_classes), dtype=np.int32))
        # Per-epoch accumulators start empty every epoch.
        for k in ['gt', 'pred']:
            result_dict['output'][k] = []
        for k in ['train_img_emb', 'train_gt', 'train_pred']:
            result_dict['emb'][k] = []

        for batch_idx in range(num_batches):
            # Get raw data from the dataloader.
            batch_data = []
            while len(batch_data) < batch_size and data_idx < len(iter_order):
                batch_data.append(
                    get_sample(iter_order[data_idx], train=train))
                data_idx = data_idx + 1

            # Process raw batch data
            x_dict = self.process_raw_batch_data(batch_data)
            # Now collate the batch data together
            x_tensor_dict = self.collate_batch_data_to_tensors(x_dict)

            batch_result_dict = model_fn(x_tensor_dict,
                                         batch_size,
                                         train=train,
                                         save_preds=True)

            if args.loss_type == 'classif':
                result_dict['conf']['train'][-1] += batch_result_dict['conf']
                result_dict['output']['gt'].append(
                    batch_result_dict['gt_label'])
                result_dict['output']['pred'].append(
                    batch_result_dict['pred_label'])
                for b in range(len(batch_data)):
                    result_dict['emb']['train_img_emb'].append(
                        to_numpy(batch_result_dict['img_emb'][b]))
                    result_dict['emb']['train_gt'].append(
                        batch_result_dict['gt_label'][b])
                    result_dict['emb']['train_pred'].append(
                        batch_result_dict['pred_label'][b])

            self.print_train_update_to_console(e, num_epochs, batch_idx,
                                               num_batches, train_step_count,
                                               batch_result_dict)

            plot_images = viz_images and train \
                and train_step_count % log_freq_iters == 0
            plot_loss = train \
                and train_step_count % args.print_freq_iters == 0

            if train:
                self.plot_train_update_to_tensorboard(
                    x_dict,
                    x_tensor_dict,
                    batch_result_dict,
                    train_step_count,
                    plot_loss=plot_loss,
                    plot_images=plot_images,
                )

            if train:
                if train_step_count % log_freq_iters == 0:
                    self.log_model_to_tensorboard(train_step_count)

                # Save trained models
                if train_step_count % args.save_freq_iters == 0:
                    self.save_checkpoint(train_step_count)

                # Run current model on val/test data.
                if train_step_count % args.test_freq_iters == 0:
                    # Remove old stuff from memory
                    x_dict = None
                    x_tensor_dict = None
                    batch_result_dict = None
                    torch.cuda.empty_cache()

                    for k in ['test_img_emb', 'test_gt', 'test_pred']:
                        result_dict['emb'][k] = []

                    test_batch_size = args.batch_size
                    if use_emb_data:
                        test_data_size = self.dataloader.get_h5_data_size(
                            train=False)
                    else:
                        test_data_size = self.dataloader.get_data_size(
                            train=False)
                    # Unlike the training loop, the test pass also visits
                    # the trailing partial batch.
                    num_batch_test = test_data_size // test_batch_size
                    if test_data_size % test_batch_size != 0:
                        num_batch_test += 1
                    # Do NOT sort the test data.
                    test_iter_order = np.arange(test_data_size)
                    test_data_idx, total_test_loss = 0, 0
                    all_gt_label_list, all_pred_label_list = [], []

                    self.set_model_to_eval()
                    result_dict['conf']['test'].append(
                        np.zeros((n_classes, n_classes), dtype=np.int32))

                    print(bcolors.c_yellow("==== Test begin ==== "))

                    for test_e in range(num_batch_test):
                        batch_data = []
                        while (len(batch_data) < test_batch_size
                               and test_data_idx < len(test_iter_order)):
                            batch_data.append(
                                get_sample(test_iter_order[test_data_idx],
                                           train=False))
                            test_data_idx = test_data_idx + 1

                        # Process raw batch data
                        x_dict = self.process_raw_batch_data(batch_data)
                        # Now collate the batch data together
                        x_tensor_dict = self.collate_batch_data_to_tensors(
                            x_dict)
                        with torch.no_grad():
                            # NOTE(review): the nominal test_batch_size is
                            # passed even for a shorter final batch --
                            # confirm the model function tolerates that.
                            batch_result_dict = model_fn(x_tensor_dict,
                                                         test_batch_size,
                                                         train=False,
                                                         save_preds=True)
                            total_test_loss += batch_result_dict['total_loss']

                        all_gt_label_list.append(
                            batch_result_dict['gt_label'])
                        all_pred_label_list.append(
                            batch_result_dict['pred_label'])
                        for b in range(len(batch_data)):
                            result_dict['emb']['test_img_emb'].append(
                                to_numpy(batch_result_dict['img_emb'][b]))

                        result_dict['conf']['test'][-1] += \
                            batch_result_dict['conf']

                        self.print_train_update_to_console(
                            e, num_epochs, test_e, num_batch_test,
                            train_step_count, batch_result_dict, train=False)

                        # Only the first test batch plots images.
                        plot_images = test_e == 0
                        plot_loss = True
                        self.plot_train_update_to_tensorboard(
                            x_dict,
                            x_tensor_dict,
                            batch_result_dict,
                            test_step_count,
                            plot_loss=plot_loss,
                            plot_images=plot_images,
                            log_prefix='/test/')

                        test_step_count += 1

                    # Calculate metrics
                    gt_label = np.concatenate(all_gt_label_list)
                    pred_label = np.concatenate(all_pred_label_list)
                    normal_f1 = f1_score(gt_label, pred_label)
                    wt_f1 = f1_score(gt_label, pred_label,
                                     average='weighted')
                    self.logger.summary_writer.add_scalar(
                        '/metrics/test/normal_f1', normal_f1,
                        test_step_count)
                    self.logger.summary_writer.add_scalar(
                        '/metrics/test/wt_f1', wt_f1, test_step_count)
                    result_dict['output']['test_f1_score'].append(normal_f1)
                    result_dict['output']['test_wt_f1_score'].append(wt_f1)
                    result_dict['output']['test_conf'].append(
                        result_dict['conf']['test'][-1])
                    result_dict['emb']['test_gt'] = np.copy(gt_label)
                    result_dict['emb']['test_pred'] = np.copy(pred_label)

                    # Plot the total loss on the entire dataset. Hopefully,
                    # this would decrease over time.
                    self.logger.summary_writer.add_scalar(
                        '/test/all_batch_loss/loss_avg',
                        total_test_loss / max(num_batch_test, 1),
                        test_step_count)
                    self.logger.summary_writer.add_scalar(
                        '/test/all_batch_loss/loss', total_test_loss,
                        test_step_count)

                    print(bcolors.c_yellow(
                        "Test: \t F1: {:.4f}\n"
                        " \t Wt-F1: {:.4f}\n"
                        " \t conf:\n{}".format(
                            normal_f1, wt_f1,
                            np.array_str(result_dict['conf']['test'][-1],
                                         precision=0))))
                    print(bcolors.c_yellow(' ==== Test Epoch conf end ===='))

                    x_dict = None
                    x_tensor_dict = None
                    batch_result_dict = None
                    torch.cuda.empty_cache()

                    if train:
                        self.set_model_to_train()

            train_step_count += 1
            torch.cuda.empty_cache()

        self.did_end_train_epoch()

        # Epoch metrics need at least one collected batch of labels:
        # np.concatenate raises on an empty list, which happened whenever
        # the epoch had zero batches or the loss type was not 'classif'.
        if result_dict['output']['gt']:
            for k in ['gt', 'pred']:
                result_dict['output'][k] = np.concatenate(
                    result_dict['output'][k]).astype(np.int32)

            normal_f1 = f1_score(result_dict['output']['gt'],
                                 result_dict['output']['pred'])
            wt_f1 = f1_score(result_dict['output']['gt'],
                             result_dict['output']['pred'],
                             average='weighted')
            self.logger.summary_writer.add_scalar('/metrics/train/normal_f1',
                                                  normal_f1, train_step_count)
            self.logger.summary_writer.add_scalar('/metrics/train/wt_f1',
                                                  wt_f1, train_step_count)

            if args.loss_type == 'classif':
                print(bcolors.c_red(
                    "Train: \t F1: {:.4f}\n"
                    " \t Wt-F1: {:.4f}\n"
                    " \t conf:\n{}".format(
                        normal_f1, wt_f1,
                        np.array_str(result_dict['conf']['train'][-1],
                                     precision=0))))

        # Report the test pass with the best weighted F1 seen so far.
        if len(result_dict['output']['test_wt_f1_score']) > 0:
            max_f1_idx = np.argmax(result_dict['output']['test_wt_f1_score'])
            print(bcolors.c_cyan(
                "Max test wt f1:\n"
                " \t F1: {:.4f}\n"
                " \t Wt-F1: {:.4f}\n"
                " \t conf:\n{}".format(
                    result_dict['output']['test_f1_score'][max_f1_idx],
                    result_dict['output']['test_wt_f1_score'][max_f1_idx],
                    np.array_str(result_dict['conf']['test'][max_f1_idx],
                                 precision=0))))

        save_emb_data_to_h5(args.result_dir, result_dict)
        print(' ==== Epoch done ====')

    for k in ['train_gt', 'train_pred']:
        result_dict['emb'][k] = np.array(result_dict['emb'][k])

    return result_dict