def generate_symbolic_dataset(dataset_path, basepath) -> None:
    derived_data = 'bimacs_derived_data'
    rgbd_data_ground_truth = 'bimacs_rgbd_data_ground_truth'
    os.makedirs(os.path.join(basepath, 'dataset_caches'), exist_ok=True)
    cachefile = os.path.join(basepath, 'dataset_caches', 'symbolic_dataset.cache')
    recs = {}
    ac.print2('Generating symbolic dataset.')
    derived_data_paths, ground_truth_paths = get_raw_dataset_paths(
        os.path.join(dataset_path, derived_data),
        os.path.join(dataset_path, rgbd_data_ground_truth))
    for derived_data_path, ground_truth_path in zip(derived_data_paths, ground_truth_paths):
        *_, subject, task, take = derived_data_path.split('/')
        ac.print1('Generating objects for {} | {} | {}.'.format(subject, task, take))
        if subject not in recs:
            recs[subject] = {}
        if task not in recs[subject]:
            recs[subject][task] = {}
        rec = Recording(derived_data_path, ground_truth_path)
        recs[subject][task][take] = rec

    # Write cache file.
    with open(cachefile, 'wb') as f:
        ac.print2('Writing symbolic dataset cachefile.')
        pickle.dump(recs, f)
    ac.print3('Done writing cachefile.')
def load(self):
    with open(self.path, 'rb') as f:
        try:
            return getattr(self, 'load_{}'.format(self.mode))(pickle.load(f))
        except pickle.UnpicklingError as e:
            ac.print1('Error while unpickling `{}`: {}. Retrying.'.format(self.path, e))
            time.sleep(0.1)
            # Rewind: the failed load may have advanced the file position, so
            # retrying without seeking back to the start would fail again.
            f.seek(0)
            try:
                return getattr(self, 'load_{}'.format(self.mode))(pickle.load(f))
            except pickle.UnpicklingError:
                ac.print0('Repeatedly failed to unpickle `{}`.'.format(self.path))
                ac.print0(traceback.format_exc())
                raise Exception('Repeatedly failed to unpickle file.')
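# Note on the dispatch in load() above: the post-processing method is resolved
# from self.mode at runtime via getattr. A minimal sketch of what such methods
# could look like, assuming the modes mirror the CLI's --evaluation-mode
# choices (the method names and bodies below are assumptions for illustration,
# not confirmed by this file):
#
#     def load_normal(self, data):     # hypothetical: full scene graph as-is
#         return data
#
#     def load_contact(self, data):    # hypothetical: contact relations only
#         ...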
def predict(args) -> int:
    assert os.path.exists(args.basepath), \
        'Basepath `{}` is not a valid path.'.format(args.basepath)
    assert os.path.exists(os.path.join(args.basepath, 'dataset_caches', args.dataset_config)), \
        'Dataset config `{}` does not exist in `{}/dataset_caches`.'.format(
            args.dataset_config, args.basepath)
    assert args.evaluation_id in [1, 2, 3, 4, 5, 6], \
        'Invalid subject ID: {}.'.format(args.evaluation_id)
    namespace_path = os.path.join(args.basepath, args.namespace)
    fold_path = os.path.join(namespace_path, 'leave_out_{}'.format(args.evaluation_id))
    assert os.path.exists(namespace_path), \
        'Namespace `{}` does not exist.'.format(args.namespace)
    assert os.path.exists(fold_path), \
        'Fold for subject {} was not trained yet.'.format(args.evaluation_id)

    ac.print1('Prepare data.')
    # filter_if discards matching records, so this keeps only the left-out subject.
    test_set = ac.dataset.load(args.basepath,
                               args.dataset_config,
                               evaluation_mode=args.evaluation_mode,
                               filter_if=lambda x: x.subject != args.evaluation_id)
    test_set.sort(key=lambda x: (x.subject, x.task, x.take, x.side, x.frame))

    # Ensure that the dataset is valid.
    ac.print3('Test size: {}.'.format(len(test_set)))
    assert all(x.subject == args.evaluation_id for x in test_set)
    ac.print3('All records in test set are from subject #{}.'.format(args.evaluation_id))

    ac.print1('Importing model.')
    import actclass.model
    model = ac.model.ActionClassifierModel(
        fold_path,
        processing_steps_count=args.processing_steps_count,
        layer_count=args.layer_count,
        neuron_count=args.neuron_count)
    model.predict(test_set, args.restore)
    return STATUS_OK
def main(argv: List[str]):
    ac.print0('Running on `{}`.'.format(socket.gethostname()))
    env = {
        'dataset_path_default': os.getenv('BIMACS_DATASET_PATH', None),
        'basepath_default': os.getenv('BIMACS_BASEPATH', None)
    }
    args = parse_args(argv, env)
    try:
        code = getattr(ac.exec, args.command)(args)
    except KeyboardInterrupt:
        ac.print0('Interrupted by user.')
        code = STATUS_INTERRUPTED
    except Exception:
        ac.print0('Exception occurred!', silent=True)
        ac.print0(traceback.format_exc(), silent=True)
        code = STATUS_UNHANDLED_EXCEPTION
    ac.print1('Exiting with code {}.'.format(code))
    ac.wait_for_logfile()
    return code
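# Typical wiring for main() (a sketch; the actual __main__ guard may live in a
# separate entry-point module of the package):
#
#     if __name__ == '__main__':
#         sys.exit(main(sys.argv[1:]))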
def dataset(args) -> int:
    assert os.path.exists(args.basepath), \
        'Basepath `{}` is not a valid path.'.format(args.basepath)
    dataset_config = 'h{}'.format(args.history_size)
    assert not os.path.exists(os.path.join(args.basepath, 'dataset_caches', dataset_config)), \
        'Dataset configuration `{}` already exists in `{}/dataset_caches`.'.format(
            dataset_config, args.basepath)
    if not ac.dataset.symbolic_datset_exists(args.basepath):
        ac.print1("Symbolic dataset does not exist. Generating now from "
                  "--raw-dataset-path='{}'.".format(args.raw_dataset_path))
        ac.dataset.generate_symbolic_dataset(args.raw_dataset_path, args.basepath)
    ac.print1('Generating dataset.')
    ac.dataset.generate_dataset(args.basepath, dataset_config, args.history_size)
    return STATUS_OK
def generate_dataset(basepath: str, config: str, history_size: int) -> None:
    recs = load_symbolic(basepath)
    cachepath = os.path.join(basepath, 'dataset_caches')

    def write_frame(s, ts, tk, fr, graphs_to_write):
        # Write one cache file per hand, retrying on transient I/O errors.
        os.makedirs(os.path.join(cachepath, config, s, ts, tk), exist_ok=True)
        for side in ['left', 'right']:
            cachefile = os.path.join(cachepath, config, s, ts, tk,
                                     'frame_{}_{}.cache'.format(fr, side))
            written = False
            while not written:
                try:
                    with open(cachefile, 'wb') as f:
                        pickle.dump(graphs_to_write[side], f)
                except IOError:
                    ac.print0('Error writing frame. Retrying...')
                    time.sleep(0.25)
                else:
                    written = True

    ac.print2('Generating dataset.')
    for subject, task, take in crawl_dataset():
        ac.print1('Generating graphs data for {} | {} | {}.'.format(subject, task, take))
        recording: Recording = recs[subject][task][take]
        recording.check_integrity()
        for i in range(recording.frame_count):
            sgl = recording.to_scene_graphs(i, history_size=history_size)
            scene_graph = flatten_scene_graphs(sgl)
            scene_graph.check_integrity()
            graphs = {
                'right': scene_graph.to_data_dict(mirrored=False),
                'left': scene_graph.to_data_dict(mirrored=True)
            }
            write_frame(subject, task, take, i, graphs)
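# For illustration only: a minimal sketch (hypothetical helper, not used by the
# pipeline) showing how a frame cache written by write_frame() above is read
# back. It returns the graph data dict produced by to_data_dict().
def _load_cached_frame_example(basepath: str, config: str, s: str, ts: str,
                               tk: str, fr: int, side: str):
    cachefile = os.path.join(basepath, 'dataset_caches', config, s, ts, tk,
                             'frame_{}_{}.cache'.format(fr, side))
    with open(cachefile, 'rb') as f:
        return pickle.load(f)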
def train(args) -> int:
    assert os.path.exists(args.basepath), \
        'Basepath `{}` is not a valid path.'.format(args.basepath)
    assert os.path.exists(os.path.join(args.basepath, 'dataset_caches', args.dataset_config)), \
        'Dataset config `{}` does not exist in `{}/dataset_caches`.'.format(
            args.dataset_config, args.basepath)
    assert args.evaluation_id in [1, 2, 3, 4, 5, 6], \
        'Invalid subject ID: {}.'.format(args.evaluation_id)
    namespace_path = os.path.join(args.basepath, args.namespace)
    fold_path = os.path.join(namespace_path, 'leave_out_{}'.format(args.evaluation_id))
    assert os.path.exists(namespace_path), \
        'Namespace `{}` does not exist.'.format(args.namespace)
    if args.restore is None:
        assert not os.path.exists(fold_path), \
            'An evaluation was already started for subject #{}.'.format(args.evaluation_id)

    # Set up logging.
    os.makedirs(fold_path, exist_ok=True)
    ac.logfile_path = fold_path
    ac.write_logfile = True
    ac.print1('Persistent logging enabled.')

    def is_ev_sub(fs: ac.dataset.SceneGraphProxy):
        return fs.subject == args.evaluation_id

    def is_vld(fs: ac.dataset.SceneGraphProxy):
        return fs.take == args.validation_id

    ac.print1('Prepare data.')
    # filter_if discards matching records: the train set excludes the left-out
    # subject and the validation take, while the validation set keeps only the
    # validation take (again excluding the left-out subject).
    train_set = ac.dataset.load(args.basepath,
                                args.dataset_config,
                                evaluation_mode=args.evaluation_mode,
                                filter_if=lambda x: is_ev_sub(x) or is_vld(x))
    valid_set = ac.dataset.load(args.basepath,
                                args.dataset_config,
                                evaluation_mode=args.evaluation_mode,
                                filter_if=lambda x: is_ev_sub(x) or not is_vld(x))
    random.shuffle(train_set)
    random.shuffle(valid_set)

    # Ensure that the datasets are valid.
    ac.print3('Train/Validation sizes: {} vs. {}.'.format(len(train_set), len(valid_set)))
    assert len(train_set) > len(valid_set)
    assert all(x.subject != args.evaluation_id for x in train_set), 'Test data in train set!'
    assert all(x.take != args.validation_id for x in train_set), 'Validation data in train set!'
    ac.print3('All records in train set are not from subject #{} and are not take #{}.'.format(
        args.evaluation_id, args.validation_id))
    assert all(x.subject != args.evaluation_id for x in valid_set), 'Test data in validation set!'
    assert all(x.take == args.validation_id for x in valid_set), 'Non-validation data in validation set!'
    ac.print3('All records in valid set are not from subject #{} and are take #{}.'.format(
        args.evaluation_id, args.validation_id))

    ac.print1('Importing model.')
    import actclass.model
    model = ac.model.ActionClassifierModel(
        fold_path,
        processing_steps_count=args.processing_steps_count,
        layer_count=args.layer_count,
        neuron_count=args.neuron_count)
    model.train(train_set,
                valid_set,
                restore=args.restore,
                max_iteration=args.max_iteration,
                log_interval=args.log_interval,
                save_interval=args.save_interval)
    return STATUS_OK
def parse_args(argv: List[str], env: Dict[str, str]) -> argparse.Namespace:
    parser = argparse.ArgumentParser(prog='ac')
    subparsers = parser.add_subparsers(dest='command')
    mkevenv_parser = subparsers.add_parser('mkevenv')
    train_parser = subparsers.add_parser('train')
    predict_parser = subparsers.add_parser('predict')
    dataset_parser = subparsers.add_parser('dataset')
    evaluate_parser = subparsers.add_parser('evaluate')

    # Common arguments.
    # All parsers need the basepath.
    for p in [mkevenv_parser, train_parser, predict_parser, dataset_parser,
              evaluate_parser]:
        p.add_argument('-b', '--basepath', type=str,
                       required=env['basepath_default'] is None,
                       default=env['basepath_default'],
                       help='Basepath to the namespace folders')
        p.add_argument('--verbose', '-v', action='count', dest='verbosity',
                       default=0, help='Sets the verbosity')

    # Namespace argument.
    for p in [mkevenv_parser, train_parser, predict_parser, evaluate_parser]:
        p.add_argument('-n', '--namespace', type=str, required=True,
                       help='String identifier of the namespace')

    # Left-out subject ID, restore point for the model.
    for p in [train_parser, predict_parser]:
        p.add_argument('-e', '--evaluation-id', metavar='[1,2,3,4,5,6]',
                       type=int, choices=[1, 2, 3, 4, 5, 6], required=True,
                       help='Numerical identifier of the left-out evaluation subject')
        p.add_argument('--restore', type=int, required=(p == predict_parser),
                       help='Iteration number of the state to restore')

    # Make evaluation environment.
    mkevenv_parser.add_argument(
        '--dataset-config', type=str, required=True,
        help='String identifier of the dataset configuration')
    mkevenv_parser.add_argument(
        '--processing-steps-count', type=int, default=10,
        help='Number of processing steps the graph network should perform.')
    mkevenv_parser.add_argument(
        '--layer-count', type=int, default=2,
        help='Number of layers used for the MLPs in the graph network.')
    mkevenv_parser.add_argument(
        '--neuron-count', type=int, default=32,
        help='Number of neurons per layer used for the MLPs in the graph network.')
    mkevenv_parser.add_argument(
        '--validation-id', type=int, default=7,
        help='ID of the takes to use as validation set')
    mkevenv_parser.add_argument(
        '--evaluation-mode', type=str,
        choices=['normal', 'contact', 'centroids'], required=True,
        help='Evaluation mode (normal, contacts only, bounding-box centroids)')

    # Train.
    train_parser.add_argument(
        '--max-iteration', type=int, default=3000,
        help='Maximum iteration number before interrupting. '
             'Negative values = unlimited')
    train_parser.add_argument(
        '--log-interval', type=int, default=120,
        help='Interval in seconds after which to perform a validation')
    train_parser.add_argument(
        '--save-interval', type=int, default=100,
        help='Number of iterations after which a model checkpoint is saved')

    # Dataset.
    dataset_parser.add_argument(
        '--history-size', type=int, default=10, required=False,
        help='Number of scene graphs to be considered in the history '
             'for temporal edges')
    dataset_parser.add_argument(
        '--raw-dataset-path', type=str, required=False,
        default=env['dataset_path_default'],
        help='Path to the raw dataset')

    # Parse args.
    args = {}
    parsed_args = parser.parse_args(argv)

    # Load the environment configuration if applicable.
    if parsed_args.command in ['train', 'predict']:
        with open(os.path.join(parsed_args.basepath, parsed_args.namespace,
                               'env_config.json')) as f:
            args = json.load(f)

    # Add/overwrite env config with current arguments.
    for key, value in parsed_args.__dict__.items():
        args[key] = value

    # Echo configuration.
    ac.verbosity = args['verbosity']
    for key, value in args.items():
        ac.print1('{k:<25s} {v:<10s}'.format(k='{}:'.format(key), v=str(value)))
    return argparse.Namespace(**args)
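# Example invocations (paths and the namespace name are hypothetical; -b may be
# omitted when the BIMACS_BASEPATH environment variable is set, since the
# per-subcommand -b argument then defaults to it):
#
#     ac dataset -b /data/bimacs -vv --history-size 10
#     ac mkevenv -b /data/bimacs -n my_namespace --dataset-config h10 --evaluation-mode normal
#     ac train -b /data/bimacs -n my_namespace -e 1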
def evaluate(args) -> int:
    assert os.path.exists(args.basepath), \
        'Basepath `{}` is not a valid path.'.format(args.basepath)
    assert os.path.exists(os.path.join(args.basepath, args.namespace)), \
        'Could not find namespace `{}`.'.format(args.namespace)
    namespace = os.path.join(args.basepath, args.namespace)
    ac.plot.latexify(fig_height=2.8)
    for evaluation_id in [1, 2, 3, 4, 5, 6]:
        assert os.path.exists(os.path.join(namespace,
                                           'leave_out_{}'.format(evaluation_id),
                                           'predictions')), \
            'Predictions not complete!'

    ac.print1('Load symbolic dataset.')
    data = ac.dataset.load_symbolic(args.basepath)

    def to_task_str(task_id: str) -> str:
        # Maps a short task identifier like 'ts3' to its full name.
        tasks = [
            'task_1_k_cooking', 'task_2_k_cooking_with_bowls', 'task_3_k_pouring',
            'task_4_k_wiping', 'task_5_k_cereals', 'task_6_w_hard_drive',
            'task_7_w_free_hard_drive', 'task_8_w_hammering', 'task_9_w_sawing'
        ]
        return tasks[int(task_id[2:3]) - 1]

    ground_truth: List[int] = []
    correct_top1: List[int] = []
    correct_top3: List[int] = []
    correct_top1_count: int = 0
    correct_top3_count: int = 0
    subtotals = []
    classes = np.array(ac.actions)
    os.makedirs(os.path.join(namespace, 'evaluation_results'), exist_ok=True)
    seg = {}

    ac.print1('Evaluating now...')
    for evaluation_id in [1, 2, 3, 4, 5, 6]:
        subject_ground_truth: List[int] = []
        subject_correct_top1: List[int] = []
        subject_correct_top3: List[int] = []
        subject_correct_top1_count: int = 0
        subject_correct_top3_count: int = 0
        path = os.path.join(args.basepath, args.namespace,
                            'leave_out_{}'.format(evaluation_id), 'predictions')
        for file_name in os.listdir(path):
            with open(os.path.join(path, file_name)) as f:
                predictions = json.load(f)
            _, subject_id, task_id, take_id = file_name[:-len('.json')].split('_')
            subject_id = subject_id.replace('s', 'subject_')
            task_id = to_task_str(task_id)
            take_id = take_id.replace('tk', 'take_')
            if subject_id not in seg:
                seg[subject_id] = {}
            if task_id not in seg[subject_id]:
                seg[subject_id][task_id] = {}
            if take_id not in seg[subject_id][task_id]:
                seg[subject_id][task_id][take_id] = {
                    'top1': {'left': [], 'right': []},
                    'top3': {'left': [], 'right': []}
                }
            assert 'subject_{}'.format(evaluation_id) == subject_id
            assert len(predictions['right']) == len(predictions['left'])
            frame_count = len(predictions['right'])
            for side in ['left', 'right']:
                for frame_number in range(frame_count):
                    # The stored predictions are logits; exponentiate and
                    # normalise to obtain a softmax distribution.
                    f = np.exp(predictions[side][frame_number])
                    normalised_predictions = f / np.sum(f)
                    del f
                    top3_pred_indices = np.argsort(-normalised_predictions,
                                                   axis=-1)[0:3]
                    ground_truth_action = getattr(
                        data[subject_id][task_id][take_id],
                        'ground_truth_{}'.format(side))[frame_number]

                    # Segmentation.
                    seg[subject_id][task_id][take_id]['top1'][side].append(
                        int(top3_pred_indices[0]))
                    seg[subject_id][task_id][take_id]['top3'][side].append(
                        int(ground_truth_action
                            if ground_truth_action in top3_pred_indices
                            else top3_pred_indices[0]))

                    if ground_truth_action is None:
                        continue

                    # Ground truth.
                    ground_truth.append(ground_truth_action)
                    subject_ground_truth.append(ground_truth_action)

                    # Correct top 1.
                    correct_top1.append(top3_pred_indices[0])
                    subject_correct_top1.append(top3_pred_indices[0])
                    if top3_pred_indices[0] == ground_truth_action:
                        correct_top1_count += 1
                        subject_correct_top1_count += 1

                    # Correct top 3.
                    if ground_truth_action in top3_pred_indices:
                        correct_top3.append(ground_truth_action)
                        subject_correct_top3.append(ground_truth_action)
                        correct_top3_count += 1
                        subject_correct_top3_count += 1
                    else:
                        correct_top3.append(top3_pred_indices[0])
                        subject_correct_top3.append(top3_pred_indices[0])

        assert len(subject_ground_truth) == len(subject_correct_top1) \
            == len(subject_correct_top3)
        ground_truth_np = np.array(subject_ground_truth, dtype=np.int64)
        correct_top1_np = np.array(subject_correct_top1, dtype=np.int64)
        correct_top3_np = np.array(subject_correct_top3, dtype=np.int64)
        top1_filename = os.path.join(
            namespace, 'evaluation_results',
            'subject{}_confusion_matrix_top1.pdf'.format(evaluation_id))
        top3_filename = os.path.join(
            namespace, 'evaluation_results',
            'subject{}_confusion_matrix_top3.pdf'.format(evaluation_id))
        ac.plot.confusion_matrix(ground_truth_np, correct_top1_np, classes,
                                 top1_filename, normalize=True, title='',
                                 axis_subject='action')
        ac.plot.confusion_matrix(ground_truth_np, correct_top3_np, classes,
                                 top3_filename, normalize=True, title='',
                                 axis_subject='action')
        total_count = len(subject_ground_truth)
        subtotals.append({
            'correct_top1_count': subject_correct_top1_count,
            'correct_top3_count': subject_correct_top3_count,
            'total_count': total_count,
            'percent_top1': subject_correct_top1_count / total_count,
            'percent_top3': subject_correct_top3_count / total_count
        })

    assert len(ground_truth) == len(correct_top1) == len(correct_top3)
    ground_truth_np = np.array(ground_truth, dtype=np.int64)
    correct_top1_np = np.array(correct_top1, dtype=np.int64)
    correct_top3_np = np.array(correct_top3, dtype=np.int64)
    top1_filename = os.path.join(namespace, 'evaluation_results',
                                 'confusion_matrix_top1.pdf')
    top3_filename = os.path.join(namespace, 'evaluation_results',
                                 'confusion_matrix_top3.pdf')
    ac.plot.confusion_matrix(ground_truth_np, correct_top1_np, classes,
                             top1_filename, normalize=True, title='',
                             axis_subject='action')
    ac.plot.confusion_matrix(ground_truth_np, correct_top3_np, classes,
                             top3_filename, normalize=True, title='',
                             axis_subject='action')
    total_count = len(ground_truth)
    stats = {
        'correct_top1_count': correct_top1_count,
        'correct_top3_count': correct_top3_count,
        'total_count': total_count,
        'percent_top1': correct_top1_count / total_count,
        'percent_top3': correct_top3_count / total_count,
        'report_top1': sklearn.metrics.classification_report(
            ground_truth_np, correct_top1_np, target_names=ac.actions,
            output_dict=True),
        'report_top3': sklearn.metrics.classification_report(
            ground_truth_np, correct_top3_np, target_names=ac.actions,
            output_dict=True),
        'subtotals': subtotals,
    }
    with open(os.path.join(namespace, 'evaluation_results', 'stats.json'), 'w') as f:
        json.dump(stats, f, indent=4)

    os.makedirs(os.path.join(namespace, 'evaluation_results', 'segmentations'),
                exist_ok=True)
    for subject_id in seg:
        for task_id in seg[subject_id]:
            for take_id in seg[subject_id][task_id]:
                for kind in ['top1', 'top3']:
                    segmentation = {'left_hand': [], 'right_hand': []}
                    seg_lr = seg[subject_id][task_id][take_id][kind]
                    assert 'left' in seg_lr and 'right' in seg_lr
                    assert len(seg_lr['left']) == len(seg_lr['right'])
                    for side, action_list in seg_lr.items():
                        segmentation['{}_hand'.format(side)] = \
                            ac.util.to_segmentation_file(action_list)
                    filename = '{}_{}_{}_{}.json'.format(subject_id, task_id,
                                                         take_id, kind)
                    with open(os.path.join(namespace, 'evaluation_results',
                                           'segmentations', filename), 'w') as f:
                        json.dump(segmentation, f)
    return STATUS_OK
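# The per-frame normalisation inside evaluate() turns the stored logits into a
# softmax distribution before ranking. Standalone illustration (a sketch, not
# part of the pipeline):
#
#     logits = np.array([2.0, 0.5, 1.0, -1.0])
#     probs = np.exp(logits) / np.sum(np.exp(logits))
#     top3 = np.argsort(-probs)[:3]  # indices of the three most likely actions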
def predict(self, test_set: List[ac.dataset.SceneGraphProxy], restore) -> None:
    # Create output directories.
    os.makedirs(os.path.join(self.out_path, 'predictions'), exist_ok=False)

    # Model.
    placeholder = [test_set[0].load()]
    input_ph = gn.utils_tf.placeholders_from_data_dicts(placeholder)
    output_ops = self.model(input_ph, self.processing_steps_count)

    # Init and restore session.
    self._init_session()
    self._restore_session(restore)
    ac.print3('Model restored from `{}`.'.format(self.saver_path))

    times: List[float] = []  # To assess timings.
    current_take = {'right': [], 'left': []}
    global_frame_id: int = 0
    current_frame: ac.dataset.SceneGraphProxy = test_set[global_frame_id]

    # Helper function to check if the current frame is the last frame in this take.
    def is_last_frame_in_take():
        try:
            return test_set[global_frame_id + 1].take != test_set[global_frame_id].take
        except IndexError:
            return True

    ac.print2('Looping over test set now.')
    has_next_frame: bool = True
    while has_next_frame:
        start_time = time.time()

        # Get predictions and append them to the list.
        graph = current_frame.load()
        # Empty graphs will cause the framework to crash.
        if len(graph['edges']) > 0:
            feed_dict = self._create_predict_feed_dict(graph, input_ph)
            run_params_test = {'outputs': output_ops}
            values = self.session.run(run_params_test, feed_dict=feed_dict)
            output = gn.utils_np.graphs_tuple_to_data_dicts(values['outputs'][-1])
            assert len(output) == 1
            output = output[0]['globals']
            current_take[current_frame.side].append(output.tolist())
            times.append(time.time() - start_time)
        else:
            current_take[current_frame.side].append([0] * len(ac.actions))
            times.append(time.time() - start_time)

        # Persist predictions if this was the last frame.
        if is_last_frame_in_take():
            assert len(current_take['right']) == len(current_take['left'])
            identifier = 's{}_ts{}_tk{}'.format(current_frame.subject,
                                                current_frame.task,
                                                current_frame.take)
            ac.print2('Writing predictions for `{}` ({} frames).'.format(
                identifier, len(current_take['right'])))
            filename = 'predictions_{}.json'.format(identifier)
            with open(os.path.join(self.out_path, 'predictions', filename), 'w') as fp:
                json.dump(current_take, fp)
            current_take['right'] = []
            current_take['left'] = []

        # Try to load the next frame or exit if it fails.
        global_frame_id += 1
        try:
            current_frame = test_set[global_frame_id]
        except IndexError:
            ac.print1('Reached the end of the test set. Exiting main loop now.')
            has_next_frame = False

    ac.print1('Predicting took {} ms on average.'.format(np.average(times) * 1000))
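# For reference, the predictions_<identifier>.json files written above have the
# layout consumed by evaluate() in actclass.exec:
#
#     {"right": [[logit per action] per frame],
#      "left":  [[logit per action] per frame]}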
def train(self,
          train_set: List[ac.dataset.SceneGraphProxy],
          valid_set: List[ac.dataset.SceneGraphProxy],
          restore=None,
          batch_size_train: int = 512,
          batch_size_valid: int = 1024,
          learning_rate: float = 0.001,
          max_iteration: int = 3000,
          log_interval: int = 120,
          save_interval: int = 250) -> None:
    """
    Trains the model on train_set and validates it on valid_set.

    :param train_set: Training set.
    :param valid_set: Validation set.
    :param restore: Iteration number of the state to restore.
    :param batch_size_train: Train batch size.
    :param batch_size_valid: Validation batch size.
    :param learning_rate: Learning rate.
    :param max_iteration: Max iteration (aborting afterwards).
    :param log_interval: Interval in seconds after which a validation is
        performed and logged to stdout.
    :param save_interval: Interval in iterations after which the model is
        saved to disk.
    """
    print('')
    ac.print2('Starting training.')
    ac.print2('Configuration:')
    ac.print2(' - restore: {}'.format(restore))
    ac.print2(' - batch_size_train: {}'.format(batch_size_train))
    ac.print2(' - batch_size_valid: {}'.format(batch_size_valid))
    ac.print2(' - learning_rate: {}'.format(learning_rate))
    ac.print2(' - max_iteration: {}'.format(max_iteration))
    ac.print2(' - log_interval: {}'.format(log_interval))
    ac.print2(' - save_interval: {}'.format(save_interval))

    # Create output directories.
    os.makedirs(os.path.join(self.out_path, 'confusion_matrices'), exist_ok=True)
    os.makedirs(os.path.join(self.out_path, 'losses'), exist_ok=True)
    os.makedirs(os.path.join(self.out_path, 'models'), exist_ok=True)

    # Data.
    # Input and target placeholders.
    placeholder = [train_set[0].load()]
    input_ph = gn.utils_tf.placeholders_from_data_dicts(placeholder)
    target_ph = gn.utils_tf.placeholders_from_data_dicts(placeholder)

    # A list of outputs, one per processing step.
    output_ops_tr = self.model(input_ph, self.processing_steps_count)
    output_ops_ge = self.model(input_ph, self.processing_steps_count)

    # Training loss: averaged across processing steps, so intermediate steps
    # are also trained to produce good predictions.
    loss_ops_tr = self._create_loss_ops(target_ph, output_ops_tr)
    loss_op_tr = sum(loss_ops_tr) / self.processing_steps_count
    # Generalization loss: only the final processing step counts.
    loss_ops_ge = self._create_loss_ops(target_ph, output_ops_ge)
    loss_op_ge = loss_ops_ge[-1]

    # Global step variable.
    global_step = tf.Variable(0, name='global_step', trainable=False)

    # Optimizer.
    optimizer = tf.train.AdamOptimizer(learning_rate)
    step_op = optimizer.minimize(loss_op_tr, global_step=global_step)

    # Lets an iterable of TF graphs be output from a session as NP graphs.
    input_ph, target_ph = self._make_all_runnable_in_session(input_ph, target_ph)

    # Initialise and restore session if desired.
    self._init_session()
    train_state = TrainState(self.saver_basepath)
    iteration_no: int = 0
    if restore:
        self._restore_session(restore)
        train_state.restore_state(restore)
        iteration_no = self.session.run(global_step)
        ac.print2('Model restored from `{}` at iteration #{}.'.format(
            self.saver_path, iteration_no))

    print('')
    ac.print1('Legend:')
    ac.print1('# (iteration number), T (elapsed seconds), '
              'Ltr (training loss), Lge (generalisation loss), '
              'Ctr (correct mean, train), '
              'Cge (correct mean, generalisation), '
              'C3tr (top3 guess mean, train), '
              'C3ge (top3 guess mean, generalisation)')
    start_time = time.time()
    last_log_time = start_time

    # Iterate as long as max_iteration is not reached. If max_iteration is set
    # to a negative value, keep iterating forever.
    while iteration_no <= max_iteration or max_iteration < 0:
        # Run session.
        iteration_no = self.session.run(global_step)
        feed_dict_tr = self._create_train_feed_dict(iteration_no, train_set,
                                                    batch_size_train, input_ph,
                                                    target_ph)
        run_params_tr = {
            'step': step_op,
            'targets': target_ph,
            'loss': loss_op_tr,
            'outputs': output_ops_tr,
        }
        train_values = self.session.run(run_params_tr, feed_dict=feed_dict_tr)

        # Save model (only if not the first iteration after starting).
        if iteration_no % save_interval == 0 and iteration_no not in [0, restore]:
            self._persist_session(iteration_no)
            train_state.persist_state()
            ac.print2('Model saved to `{}` for iteration #{}.'.format(
                self.saver_path, iteration_no))

        # Assess timings.
        the_time = time.time()
        elapsed_since_last_log = the_time - last_log_time

        # Validate.
        if elapsed_since_last_log > log_interval:
            last_log_time = the_time
            feed_dict_ge = self._create_train_feed_dict(-1, valid_set,
                                                        batch_size_valid,
                                                        input_ph, target_ph)
            run_params_ge = {
                'targets': target_ph,
                'loss': loss_op_ge,
                'outputs': output_ops_ge,
            }
            test_values = self.session.run(run_params_ge, feed_dict=feed_dict_ge)
            cor_tr, cor3_tr = self._compute_accuracy(train_values['targets'],
                                                     train_values['outputs'][-1])
            cor_ge, cor3_ge = self._compute_accuracy(test_values['targets'],
                                                     test_values['outputs'][-1])
            elapsed = time.time() - start_time
            loss_tr = train_values['loss']
            loss_ge = test_values['loss']
            log_line = '# {:05d}, T {:.1f}, Ltr {:.4f}, Lge {:.4f}, ' \
                       'Ctr {:.4f}, Cge {:.4f}, C3tr {:.4f}, C3ge {:.4f}'
            ac.print1(log_line.format(iteration_no, elapsed, loss_tr, loss_ge,
                                      cor_tr, cor_ge, cor3_tr, cor3_ge))

            # Confusion matrix.
            cm_gt, cm_top_1, cm_top_3 = self._compute_confusions(
                test_values['targets'], test_values['outputs'][-1])
            train_state.log_validation(elapsed, int(iteration_no),
                                       float(loss_tr), float(loss_ge),
                                       float(cor_tr), float(cor_ge),
                                       float(cor3_tr), float(cor3_ge),
                                       list(cm_gt.tolist()),
                                       list(cm_top_1.tolist()),
                                       list(cm_top_3.tolist()))
            true, pred1, pred3 = train_state.numpy_confusions()
            classes = np.array(ac.actions)
            filename_conf1 = os.path.join(self.out_path, 'confusion_matrices',
                                          'conf_top1_{}.png'.format(iteration_no))
            filename_conf3 = os.path.join(self.out_path, 'confusion_matrices',
                                          'conf_top3_{}.png'.format(iteration_no))
            ac.plot.confusion_matrix(true, pred1, classes, filename_conf1,
                                     normalize=True)
            ac.plot.confusion_matrix(true, pred3, classes, filename_conf3,
                                     normalize=True)

            # Loss graph.
            filename = os.path.join(self.out_path, 'losses',
                                    'loss_{}.png'.format(iteration_no))
            ac.plot.loss_graph(train_state.logged_iterations,
                               train_state.losses_tr, train_state.losses_ge,
                               filename)