def generate_symbolic_dataset(dataset_path: str, basepath: str) -> None:
    derived_data = 'bimacs_derived_data'
    rgbd_data_ground_truth = 'bimacs_rgbd_data_ground_truth'

    os.makedirs(os.path.join(basepath, 'dataset_caches'), exist_ok=True)
    cachefile = os.path.join(basepath, 'dataset_caches',
                             'symbolic_dataset.cache')
    recs = {}

    ac.print2('Generating symbolic dataset.')
    derived_data_paths, ground_truth_paths = get_raw_dataset_paths(
        os.path.join(dataset_path, derived_data),
        os.path.join(dataset_path, rgbd_data_ground_truth))

    for derived_data_path, ground_truth_path in zip(derived_data_paths,
                                                    ground_truth_paths):
        *_, subject, task, take = derived_data_path.split('/')
        ac.print1('Generating objects for {} | {} | {}.'.format(
            subject, task, take))
        if subject not in recs:
            recs[subject] = {}
        if task not in recs[subject]:
            recs[subject][task] = {}
        rec = Recording(derived_data_path, ground_truth_path)
        recs[subject][task][take] = rec

    # Write cache file.
    with open(cachefile, 'wb') as f:
        ac.print2('Writing symbolic dataset cachefile.')
        pickle.dump(recs, f)
        ac.print3('Done writing cachefile.')
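# `get_raw_dataset_paths` is not shown in this example.  A minimal sketch of what it
# could look like, assuming derived data and ground truth share the same
# <subject>/<task>/<take> directory layout (the glob pattern and sorting are
# assumptions, not the library's actual code):
import glob
import os
from typing import List, Tuple


def get_raw_dataset_paths(derived_root: str,
                          ground_truth_root: str) -> Tuple[List[str], List[str]]:
    # Collect every <subject>/<task>/<take> directory under the derived data root.
    derived = sorted(p for p in glob.glob(os.path.join(derived_root, '*', '*', '*'))
                     if os.path.isdir(p))
    # Mirror each derived path into the ground truth tree so both lists stay aligned.
    ground_truth = [
        os.path.join(ground_truth_root, os.path.relpath(p, derived_root))
        for p in derived
    ]
    return derived, ground_truth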
    def load(self):
        with open(self.path, 'rb') as f:
            try:
                return getattr(self,
                               'load_{}'.format(self.mode))(pickle.load(f))
            except pickle.UnpicklingError as e:
                ac.print1('Error while unpickling `{}`: {}.  Retrying.'.format(
                    self.path, e))
                time.sleep(0.1)
                f.seek(0)  # Rewind before retrying; the failed load advanced the file pointer.
                try:
                    return getattr(self,
                                   'load_{}'.format(self.mode))(pickle.load(f))
                except pickle.UnpicklingError:
                    ac.print0('Repeatedly failed to unpickle `{}`.'.format(
                        self.path))
                    ac.print0(traceback.format_exc())
                    raise Exception('Repeatedly failed to unpickle file.')
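# The `getattr(self, 'load_{}'.format(self.mode))` call above dispatches to a
# mode-specific post-processing method.  A stripped-down sketch of that pattern;
# the class name and the load_* variants below are hypothetical, not the library's API.
import pickle


class CacheProxy:

    def __init__(self, path: str, mode: str = 'normal'):
        self.path = path
        self.mode = mode  # Selects which load_* method handles the unpickled data.

    def load_normal(self, data):
        # Default mode: return the unpickled data unchanged.
        return data

    def load_contact(self, data):
        # Hypothetical variant; a real mode would post-process `data` here.
        return data

    def load(self):
        with open(self.path, 'rb') as f:
            return getattr(self, 'load_{}'.format(self.mode))(pickle.load(f))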
Example n. 3
def predict(args) -> int:
    assert os.path.exists(
        args.basepath), 'Basepath `{}` is not a valid path.'.format(
            args.basepath)
    assert os.path.exists(os.path.join(args.basepath, 'dataset_caches', args.dataset_config)), \
        'Dataset config `{}` does not exist in `{}/dataset_caches`'.format(args.dataset_config, args.basepath)
    assert args.evaluation_id in [1, 2, 3, 4, 5,
                                  6], 'Invalid subject ID: {}.'.format(
                                      args.evaluation_id)

    namespace_path = os.path.join(args.basepath, args.namespace)
    fold_path = os.path.join(namespace_path,
                             'leave_out_{}'.format(args.evaluation_id))

    assert os.path.exists(
        namespace_path), 'Namespace `{}` does not exist.'.format(
            args.namespace)
    assert os.path.exists(
        fold_path), 'Fold for subject {} was not trained yet.'.format(
            args.evaluation_id)

    ac.print1('Prepare data.')
    test_set = ac.dataset.load(
        args.basepath,
        args.dataset_config,
        evaluation_mode=args.evaluation_mode,
        filter_if=lambda x: x.subject != args.evaluation_id)
    test_set.sort(key=lambda x: (x.subject, x.task, x.take, x.side, x.frame))

    # Ensure that the dataset is valid.
    ac.print3('Test size: {}.'.format(len(test_set)))
    assert all(x.subject == args.evaluation_id for x in test_set)
    ac.print3('All records in test set are from subject #{}.'.format(
        args.evaluation_id))

    ac.print1('Importing model.')
    import actclass.model
    model = ac.model.ActionClassifierModel(
        fold_path,
        processing_steps_count=args.processing_steps_count,
        layer_count=args.layer_count,
        neuron_count=args.neuron_count)
    model.predict(test_set, args.restore)

    return STATUS_OK
Example n. 4
def main(argv: List[str]):
    ac.print0('Running on `{}`.'.format(socket.gethostname()))
    env = {
        'dataset_path_default': os.getenv('BIMACS_DATASET_PATH', None),
        'basepath_default': os.getenv('BIMACS_BASEPATH', None)
    }
    args = parse_args(argv, env)
    try:
        code = getattr(ac.exec, args.command)(args)
    except KeyboardInterrupt:
        ac.print0('Interrupted by user.')
        code = STATUS_INTERRUPTED
    except Exception:
        ac.print0('Exception occurred!', silent=True)
        ac.print0(traceback.format_exc(), silent=True)
        code = STATUS_UNHANDLED_EXCEPTION
    ac.print1('Exiting with code {}.'.format(code))
    ac.wait_for_logfile()
    return code
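# A typical entry point for `main`; the module layout is an assumption, only the call
# signature is taken from the code above.
import sys

if __name__ == '__main__':
    # Forward the process arguments (without the program name) and propagate the
    # status code returned by main() to the shell.
    sys.exit(main(sys.argv[1:]))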
Example n. 5
def dataset(args) -> int:
    assert os.path.exists(
        args.basepath), 'Basepath `{}` is not a valid path.'.format(
            args.basepath)

    dataset_config = 'h{}'.format(args.history_size)

    assert not os.path.exists(os.path.join(args.basepath, 'dataset_caches', dataset_config)), \
        'Dataset configuration `{}` already exists in `{}/dataset_caches`'.format(dataset_config, args.basepath)

    if not ac.dataset.symbolic_datset_exists(args.basepath):
        ac.print1(
            "Symbolic dataset does not exist.  Generating now from --raw-dataset-path='{}'."
            .format(args.raw_dataset_path))
        ac.dataset.generate_symbolic_dataset(args.raw_dataset_path,
                                             args.basepath)
    ac.print1('Generating dataset.')
    ac.dataset.generate_dataset(args.basepath, dataset_config,
                                args.history_size)

    return STATUS_OK
def generate_dataset(basepath: str, config: str, history_size: int) -> None:
    recs = load_symbolic(basepath)
    cachepath = os.path.join(basepath, 'dataset_caches')

    def write_frame(s, ts, tk, fr, graphs_to_write):
        # Write cache file
        os.makedirs(os.path.join(cachepath, config, s, ts, tk), exist_ok=True)
        for side in ['left', 'right']:
            cachefile = os.path.join(cachepath, config, s, ts, tk,
                                     'frame_{}_{}.cache'.format(fr, side))
            written = False
            while not written:
                try:
                    with open(cachefile, 'wb') as f:
                        pickle.dump(graphs_to_write[side], f)
                except IOError:
                    ac.print0('Error writing frame.  Retrying...')
                    time.sleep(0.25)
                else:
                    written = True

    ac.print2('Generating dataset.')
    for subject, task, take in crawl_dataset():
        ac.print1('Generating graphs data for {} | {} | {}.'.format(
            subject, task, take))
        recording: Recording = recs[subject][task][take]
        recording.check_integrity()
        for i in range(0, recording.frame_count, 1):
            sgl = recording.to_scene_graphs(i, history_size=history_size)
            scene_graph = flatten_scene_graphs(sgl)
            scene_graph.check_integrity()
            graphs = {
                'right': scene_graph.to_data_dict(mirrored=False),
                'left': scene_graph.to_data_dict(mirrored=True)
            }
            write_frame(subject, task, take, i, graphs)
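# `crawl_dataset` is not shown in these examples.  A plausible sketch, assuming six
# subjects, the nine task names used in the evaluation code further below, and ten
# takes per task (the take range is an assumption):
from typing import Iterator, Tuple

TASKS = [
    'task_1_k_cooking', 'task_2_k_cooking_with_bowls', 'task_3_k_pouring',
    'task_4_k_wiping', 'task_5_k_cereals', 'task_6_w_hard_drive',
    'task_7_w_free_hard_drive', 'task_8_w_hammering', 'task_9_w_sawing'
]


def crawl_dataset() -> Iterator[Tuple[str, str, str]]:
    # Yield every (subject, task, take) combination in a deterministic order.
    for subject_no in range(1, 7):
        for task in TASKS:
            for take_no in range(10):
                yield ('subject_{}'.format(subject_no), task,
                       'take_{}'.format(take_no))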
Example n. 7
def train(args) -> int:
    assert os.path.exists(
        args.basepath), 'Basepath `{}` is not a valid path.'.format(
            args.basepath)
    assert os.path.exists(os.path.join(args.basepath, 'dataset_caches', args.dataset_config)), \
        'Dataset config `{}` does not exist in `{}/dataset_caches`'.format(args.dataset_config, args.basepath)
    assert args.evaluation_id in [1, 2, 3, 4, 5,
                                  6], 'Invalid subject ID: {}.'.format(
                                      args.evaluation_id)

    namespace_path = os.path.join(args.basepath, args.namespace)
    fold_path = os.path.join(namespace_path,
                             'leave_out_{}'.format(args.evaluation_id))

    assert os.path.exists(
        namespace_path), 'Namespace `{}` does not exist.'.format(
            args.namespace)
    if args.restore is None:
        assert not os.path.exists(
            fold_path
        ), 'An evaluation was already started for subject #{}.'.format(
            args.evaluation_id)

    # Setup logging.
    os.makedirs(fold_path, exist_ok=True)
    ac.logfile_path = fold_path
    ac.write_logfile = True
    ac.print1('Persistent logging enabled.')

    def is_ev_sub(fs: ac.dataset.SceneGraphProxy):
        return fs.subject == args.evaluation_id

    def is_vld(fs: ac.dataset.SceneGraphProxy):
        return fs.take == args.validation_id

    ac.print1('Prepare data.')
    train_set = ac.dataset.load(args.basepath,
                                args.dataset_config,
                                evaluation_mode=args.evaluation_mode,
                                filter_if=lambda x: is_ev_sub(x) or is_vld(x))
    valid_set = ac.dataset.load(
        args.basepath,
        args.dataset_config,
        evaluation_mode=args.evaluation_mode,
        filter_if=lambda x: is_ev_sub(x) or not is_vld(x))
    random.shuffle(train_set)
    random.shuffle(valid_set)

    # Ensure that the datasets are valid.
    ac.print3('Train/Validation sizes: {} vs. {}.'.format(
        len(train_set), len(valid_set)))
    assert len(train_set) > len(valid_set)
    assert all(x.subject != args.evaluation_id
               for x in train_set), 'Test data in train set!'
    assert all(x.take != args.validation_id
               for x in train_set), 'Validation data in train set!'
    ac.print3(
        'All records in train set are not from subject #{} and are not take #{}.'
        .format(args.evaluation_id, args.validation_id))
    assert all(x.subject != args.evaluation_id
               for x in valid_set), 'Test data in validation set!'
    assert all(x.take == args.validation_id
               for x in valid_set), 'Non-validation data in validation set!'
    ac.print3(
        'All records in valid set are not from subject #{} and are take #{}.'.
        format(args.evaluation_id, args.validation_id))

    ac.print1('Importing model.')
    import actclass.model
    model = ac.model.ActionClassifierModel(
        fold_path,
        processing_steps_count=args.processing_steps_count,
        layer_count=args.layer_count,
        neuron_count=args.neuron_count)
    model.train(train_set,
                valid_set,
                restore=args.restore,
                max_iteration=args.max_iteration,
                log_interval=args.log_interval,
                save_interval=args.save_interval)

    return STATUS_OK
Example n. 8
def parse_args(argv: List[str], env: Dict[str, str]) -> argparse.Namespace:
    parser = argparse.ArgumentParser(prog='ac')
    subparsers = parser.add_subparsers(dest='command')
    mkevenv_parser = subparsers.add_parser('mkevenv')
    train_parser = subparsers.add_parser('train')
    predict_parser = subparsers.add_parser('predict')
    dataset_parser = subparsers.add_parser('dataset')
    evaluate_parser = subparsers.add_parser('evaluate')

    # Common arguments.
    # All parsers need the basepath.
    for p in [
            mkevenv_parser, train_parser, predict_parser, dataset_parser,
            evaluate_parser
    ]:
        p.add_argument('-b',
                       '--basepath',
                       type=str,
                       required=env['basepath_default'] is None,
                       default=env['basepath_default'],
                       help='Basepath to the namespace folders')
        p.add_argument('--verbose',
                       '-v',
                       action='count',
                       dest='verbosity',
                       default=0,
                       help='Sets the verbosity')
    # Namespace argument.
    for p in [mkevenv_parser, train_parser, predict_parser, evaluate_parser]:
        p.add_argument('-n',
                       '--namespace',
                       type=str,
                       required=True,
                       help='String identifier of the namespace')
    # Left out subject ID, restore point for model.
    for p in [train_parser, predict_parser]:
        p.add_argument(
            '-e',
            '--evaluation-id',
            metavar='[1,2,3,4,5,6]',
            type=int,
            choices=[1, 2, 3, 4, 5, 6],
            required=True,
            help='Numerical identifier of the left-out evaluation subject')
        p.add_argument('--restore',
                       type=int,
                       required=(p == predict_parser),
                       help='Iteration number of the state to restore')

    # Make evaluation environment.
    mkevenv_parser.add_argument(
        '--dataset-config',
        type=str,
        required=True,
        help='String identifier of the dataset configuration')
    mkevenv_parser.add_argument(
        '--processing-steps-count',
        type=int,
        default=10,
        help='Number of processing steps the graph network should perform.')
    mkevenv_parser.add_argument(
        '--layer-count',
        type=int,
        default=2,
        help='Number of layers used for the MLPs in the graph network.')
    mkevenv_parser.add_argument(
        '--neuron-count',
        type=int,
        default=32,
        help=
        'Number of neurons per layer used for the MLPs in the graph network.')
    mkevenv_parser.add_argument(
        '--validation-id',
        type=int,
        default=7,
        help='ID of the takes to use as validation set')
    mkevenv_parser.add_argument(
        '--evaluation-mode',
        type=str,
        choices=['normal', 'contact', 'centroids'],
        required=True,
        help='Evaluation mode (normal, contacts only, bounding-box centroids)')
    # Train.
    train_parser.add_argument(
        '--max-iteration',
        type=int,
        default=3000,
        help=
        'Maximum iteration number before interrupting.  Negative values = unlimited'
    )
    train_parser.add_argument(
        '--log-interval',
        type=int,
        default=120,
        help='Interval in seconds after which to perform a validation')
    train_parser.add_argument(
        '--save-interval',
        type=int,
        default=100,
        help='Number of iterations after which a model checkpoint is saved')
    # Dataset.
    dataset_parser.add_argument(
        '--history-size',
        type=int,
        default=10,
        required=False,
        help=
        'Number of scene graphs to include in the history for temporal edges'
    )
    dataset_parser.add_argument('--raw-dataset-path',
                                type=str,
                                required=False,
                                default=env['dataset_path_default'],
                                help='Path to the raw dataset')

    # Parse args.
    args = {}
    parsed_args = parser.parse_args(argv)

    # Find environment configuration if applicable.
    if parsed_args.command in ['train', 'predict']:
        with open(
                os.path.join(parsed_args.basepath, parsed_args.namespace,
                             'env_config.json')) as f:
            args = json.load(f)
    # Add/overwrite env config with current arguments.
    for key, value in parsed_args.__dict__.items():
        args[key] = value

    # Echo configuration.
    ac.verbosity = args['verbosity']
    for key, value in args.items():
        ac.print1('{k:<25s} {v:<10s}'.format(k='{}:'.format(key),
                                             v=str(value)))

    return argparse.Namespace(**args)
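# A minimal usage sketch for `parse_args`, e.g. from a test; the paths and values
# below are illustrative only.
env = {'dataset_path_default': None, 'basepath_default': None}
argv = [
    'dataset',
    '--basepath', '/tmp/experiments',  # Hypothetical basepath.
    '--history-size', '5',
    '-vv',
]
args = parse_args(argv, env)
assert args.command == 'dataset' and args.history_size == 5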
Example n. 9
def evaluate(args) -> int:
    assert os.path.exists(
        args.basepath), 'Basepath `{}` is not a valid path.'.format(
            args.basepath)
    assert os.path.exists(os.path.join(args.basepath, args.namespace)), \
        'Could not find namespace `{}`'.format(args.namespace)

    namespace = os.path.join(args.basepath, args.namespace)

    ac.plot.latexify(fig_height=2.8)

    for evaluation_id in [1, 2, 3, 4, 5, 6]:
        assert os.path.exists(os.path.join(namespace, 'leave_out_{}'.format(evaluation_id), 'predictions')), \
            'Predictions for subject #{} not complete!'.format(evaluation_id)

    ac.print1('Load symbolic dataset.')
    data = ac.dataset.load_symbolic(args.basepath)

    def to_task_str(task_id: str) -> str:
        tasks = [
            'task_1_k_cooking', 'task_2_k_cooking_with_bowls',
            'task_3_k_pouring', 'task_4_k_wiping', 'task_5_k_cereals',
            'task_6_w_hard_drive', 'task_7_w_free_hard_drive',
            'task_8_w_hammering', 'task_9_w_sawing'
        ]
        return tasks[int(task_id[2:3]) - 1]

    ground_truth: List[int] = []
    correct_top1: List[int] = []
    correct_top3: List[int] = []
    correct_top1_count: int = 0
    correct_top3_count: int = 0

    subtotals = []
    classes = np.array(ac.actions)
    os.makedirs(os.path.join(namespace, 'evaluation_results'), exist_ok=True)

    seg = {}

    ac.print1('Evaluating now...')
    for evaluation_id in [1, 2, 3, 4, 5, 6]:
        subject_ground_truth: List[int] = []
        subject_correct_top1: List[int] = []
        subject_correct_top3: List[int] = []
        subject_correct_top1_count: int = 0
        subject_correct_top3_count: int = 0

        path = os.path.join(args.basepath, args.namespace,
                            'leave_out_{}'.format(evaluation_id),
                            'predictions')
        for file_name in os.listdir(path):
            with open(os.path.join(path, file_name)) as f:
                predictions = json.load(f)

            _, subject_id, task_id, take_id = file_name[:-len('.json')].split(
                '_')
            subject_id = subject_id.replace('s', 'subject_')
            task_id = to_task_str(task_id)
            take_id = take_id.replace('tk', 'take_')

            if subject_id not in seg:
                seg[subject_id] = {}
            if task_id not in seg[subject_id]:
                seg[subject_id][task_id] = {}
            if take_id not in seg[subject_id][task_id]:
                seg[subject_id][task_id][take_id] = {
                    'top1': {
                        'left': [],
                        'right': []
                    },
                    'top3': {
                        'left': [],
                        'right': []
                    }
                }

            assert 'subject_{}'.format(evaluation_id) == subject_id

            assert len(predictions['right']) == len(predictions['left'])
            frame_count = len(predictions['right'])

            for side in ['left', 'right']:
                for frame_number in range(0, frame_count, 1):
                    f = np.exp(predictions[side][frame_number])
                    normalised_predictions = f / np.sum(f)
                    del f
                    top3_pred_indices = np.argsort(-normalised_predictions,
                                                   axis=-1)[0:3]

                    ground_truth_action = getattr(
                        data[subject_id][task_id][take_id],
                        'ground_truth_{}'.format(side))[frame_number]

                    # Segmentation
                    seg[subject_id][task_id][take_id]['top1'][side].append(
                        int(top3_pred_indices[0]))
                    seg[subject_id][task_id][take_id]['top3'][side].append(
                        int(ground_truth_action if ground_truth_action in
                            top3_pred_indices else top3_pred_indices[0]))

                    if ground_truth_action is None:
                        continue

                    # Ground truth.
                    ground_truth.append(ground_truth_action)
                    subject_ground_truth.append(ground_truth_action)

                    # Correct top 1.
                    correct_top1.append(top3_pred_indices[0])
                    subject_correct_top1.append(top3_pred_indices[0])
                    if top3_pred_indices[0] == ground_truth_action:
                        correct_top1_count += 1
                        subject_correct_top1_count += 1

                    # Correct top 3.
                    if ground_truth_action in top3_pred_indices:
                        correct_top3.append(ground_truth_action)
                        subject_correct_top3.append(ground_truth_action)
                        correct_top3_count += 1
                        subject_correct_top3_count += 1
                    else:
                        correct_top3.append(top3_pred_indices[0])
                        subject_correct_top3.append(top3_pred_indices[0])

        assert len(subject_ground_truth) == len(subject_correct_top1) == len(
            subject_correct_top3)

        ground_truth_np = np.array(subject_ground_truth, dtype=np.int64)
        correct_top1_np = np.array(subject_correct_top1, dtype=np.int64)
        correct_top3_np = np.array(subject_correct_top3, dtype=np.int64)

        top1_filename = os.path.join(
            namespace, 'evaluation_results',
            'subject{}_confusion_matrix_top1.pdf'.format(evaluation_id))
        top3_filename = os.path.join(
            namespace, 'evaluation_results',
            'subject{}_confusion_matrix_top3.pdf'.format(evaluation_id))
        ac.plot.confusion_matrix(ground_truth_np,
                                 correct_top1_np,
                                 classes,
                                 top1_filename,
                                 normalize=True,
                                 title='',
                                 axis_subject='action')
        ac.plot.confusion_matrix(ground_truth_np,
                                 correct_top3_np,
                                 classes,
                                 top3_filename,
                                 normalize=True,
                                 title='',
                                 axis_subject='action')

        total_count = len(subject_ground_truth)
        subtotal_stats = {
            'correct_top1_count': subject_correct_top1_count,
            'correct_top3_count': subject_correct_top3_count,
            'total_count': total_count,
            'percent_top1': subject_correct_top1_count / total_count,
            'percent_top3': subject_correct_top3_count / total_count
        }
        subtotals.append(subtotal_stats)

    assert len(ground_truth) == len(correct_top1) == len(correct_top3)

    ground_truth_np = np.array(ground_truth, dtype=np.int64)
    correct_top1_np = np.array(correct_top1, dtype=np.int64)
    correct_top3_np = np.array(correct_top3, dtype=np.int64)

    top1_filename = os.path.join(namespace, 'evaluation_results',
                                 'confusion_matrix_top1.pdf')
    top3_filename = os.path.join(namespace, 'evaluation_results',
                                 'confusion_matrix_top3.pdf')
    ac.plot.confusion_matrix(ground_truth_np,
                             correct_top1_np,
                             classes,
                             top1_filename,
                             normalize=True,
                             title='',
                             axis_subject='action')
    ac.plot.confusion_matrix(ground_truth_np,
                             correct_top3_np,
                             classes,
                             top3_filename,
                             normalize=True,
                             title='',
                             axis_subject='action')

    total_count = len(ground_truth)
    stats = {
        'correct_top1_count':
        correct_top1_count,
        'correct_top3_count':
        correct_top3_count,
        'total_count':
        len(ground_truth),
        'percent_top1':
        correct_top1_count / total_count,
        'percent_top3':
        correct_top3_count / total_count,
        'report_top1':
        sklearn.metrics.classification_report(ground_truth_np,
                                              correct_top1_np,
                                              target_names=ac.actions,
                                              output_dict=True),
        'report_top3':
        sklearn.metrics.classification_report(ground_truth_np,
                                              correct_top3_np,
                                              target_names=ac.actions,
                                              output_dict=True),
        'subtotals':
        subtotals,
    }

    with open(os.path.join(namespace, 'evaluation_results', 'stats.json'),
              'w') as f:
        json.dump(stats, f, indent=4)

    os.makedirs(os.path.join(namespace, 'evaluation_results', 'segmentations'),
                exist_ok=True)
    for subject_id in seg:
        for task_id in seg[subject_id]:
            for take_id in seg[subject_id][task_id]:
                for kind in ['top1', 'top3']:
                    segmentation = {'left_hand': [], 'right_hand': []}
                    seg_lr = seg[subject_id][task_id][take_id][kind]
                    assert 'left' in seg_lr and 'right' in seg_lr
                    assert len(seg_lr['left']) == len(seg_lr['right'])
                    for side, action_list in seg_lr.items():
                        segmentation['{}_hand'.format(
                            side)] = ac.util.to_segmentation_file(action_list)
                    filename = '{}_{}_{}_{}.json'.format(
                        subject_id, task_id, take_id, kind)
                    with open(
                            os.path.join(namespace, 'evaluation_results',
                                         'segmentations', filename), 'w') as f:
                        json.dump(segmentation, f)

    return STATUS_OK
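# `ac.util.to_segmentation_file` is not shown in these examples.  One plausible,
# purely illustrative implementation collapses the frame-wise action indices into
# contiguous segments:
from typing import Dict, List


def to_segmentation_file(action_list: List[int]) -> List[Dict[str, int]]:
    # Merge consecutive frames with the same action id into one segment described by
    # its start frame, end frame and action id.
    segments: List[Dict[str, int]] = []
    for frame, action in enumerate(action_list):
        if segments and segments[-1]['action'] == action:
            segments[-1]['end'] = frame
        else:
            segments.append({'start': frame, 'end': frame, 'action': action})
    return segments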
Example n. 10
    def predict(self, test_set: List[ac.dataset.SceneGraphProxy], restore: int) -> None:
        # Create output directories.
        os.makedirs(os.path.join(self.out_path, 'predictions'), exist_ok=False)

        # Model.
        placeholder = [test_set[0].load()]
        input_ph = gn.utils_tf.placeholders_from_data_dicts(placeholder)
        output_ops = self.model(input_ph, self.processing_steps_count)

        # Init and restore session.
        self._init_session()
        self._restore_session(restore)
        ac.print3('Model restored from `{}`.'.format(self.saver_path))

        times: List[float] = []  # To assess timings.
        current_take = {'right': [], 'left': []}
        global_frame_id: int = 0
        current_frame: ac.dataset.SceneGraphProxy = test_set[global_frame_id]

        # Helper function to check if the current frame is the last frame in this take.
        def is_last_frame_in_take():
            try:
                return test_set[global_frame_id + 1].take != test_set[global_frame_id].take
            except IndexError:
                return True

        ac.print2('Looping over test set now.')
        has_next_frame: bool = True
        while has_next_frame:
            start_time = time.time()
            # Get predictions and append them to list.
            graph = current_frame.load()

            # Empty graphs will cause the framework to crash.
            if len(graph['edges']) > 0:
                feed_dict = self._create_predict_feed_dict(graph, input_ph)
                run_params_test = {
                    'outputs': output_ops
                }
                values = self.session.run(run_params_test, feed_dict=feed_dict)
                output = gn.utils_np.graphs_tuple_to_data_dicts(values['outputs'][-1])
                assert len(output) == 1
                output = output[0]['globals']
                current_take[current_frame.side].append(output.tolist())
                times.append(time.time() - start_time)
            else:
                current_take[current_frame.side].append([0] * len(ac.actions))
                times.append(time.time() - start_time)

            # Persist predictions if this was the last frame.
            if is_last_frame_in_take():
                assert len(current_take['right']) == len(current_take['left'])
                identifier = 's{}_ts{}_tk{}'.format(current_frame.subject, current_frame.task, current_frame.take)
                ac.print2('Writing predictions for `{}` ({} frames).'.format(identifier, len(current_take['right'])))
                filename = 'predictions_{}.json'.format(identifier)
                with open(os.path.join(self.out_path, 'predictions', filename), 'w') as fp:
                    json.dump(current_take, fp)
                current_take['right'] = []
                current_take['left'] = []

            # Try to load next frame or exit if it fails.
            global_frame_id += 1
            try:
                current_frame = test_set[global_frame_id]
            except IndexError:
                ac.print1('Reached the end of the test set.  Exiting main loop now.')
                has_next_frame = False
        ac.print1('Predicting took {:.1f} ms on average.'.format(np.average(times) * 1000))
Example n. 11
    def train(self, train_set: List[ac.dataset.SceneGraphProxy], valid_set: List[ac.dataset.SceneGraphProxy],
              restore=None, batch_size_train: int = 512, batch_size_valid: int = 1024, learning_rate: float = 0.001,
              max_iteration: int = 3000, log_interval: int = 120, save_interval: int = 250) -> None:
        """
        Trains the model on train_set and validates it on valid_set.

        :param train_set: Training set.
        :param valid_set: Validation set.
        :param restore: Iteration number of the state to restore.
        :param batch_size_train: Train batch size.
        :param batch_size_valid: Validation batch size.
        :param learning_rate: Learning rate.
        :param max_iteration: Max iteration (aborting afterwards).
        :param log_interval: Interval when a log should be printed to stdout.
        :param save_interval: Interval when the model should be saved to disk.
        """
        print('')
        ac.print2('Starting training.')
        ac.print2('Configuration:')
        ac.print2(' - restore: {}'.format(restore))
        ac.print2(' - batch_size_train: {}'.format(batch_size_train))
        ac.print2(' - batch_size_valid: {}'.format(batch_size_valid))
        ac.print2(' - learning_rate: {}'.format(learning_rate))
        ac.print2(' - max_iteration: {}'.format(max_iteration))
        ac.print2(' - log_interval: {}'.format(log_interval))
        ac.print2(' - save_interval: {}'.format(save_interval))

        # Create output directories.
        os.makedirs(os.path.join(self.out_path, 'confusion_matrices'), exist_ok=True)
        os.makedirs(os.path.join(self.out_path, 'losses'), exist_ok=True)
        os.makedirs(os.path.join(self.out_path, 'models'), exist_ok=True)

        # Data.
        # Input and target placeholders.
        placeholder = [train_set[0].load()]
        input_ph = gn.utils_tf.placeholders_from_data_dicts(placeholder)
        target_ph = gn.utils_tf.placeholders_from_data_dicts(placeholder)

        # A list of outputs, one per processing step.
        output_ops_tr = self.model(input_ph, self.processing_steps_count)
        output_ops_ge = self.model(input_ph, self.processing_steps_count)

        # Training loss.
        loss_ops_tr = self._create_loss_ops(target_ph, output_ops_tr)
        # Loss across processing steps.
        loss_op_tr = sum(loss_ops_tr) / self.processing_steps_count
        # Generalization loss.
        loss_ops_ge = self._create_loss_ops(target_ph, output_ops_ge)
        loss_op_ge = loss_ops_ge[-1]  # Loss from final processing step.

        # Global step variable.
        global_step = tf.Variable(0, name='global_step', trainable=False)

        # Optimizer.
        optimizer = tf.train.AdamOptimizer(learning_rate)
        step_op = optimizer.minimize(loss_op_tr, global_step=global_step)

        # Lets an iterable of TF graphs be output from a session as NP graphs.
        input_ph, target_ph = self._make_all_runnable_in_session(input_ph, target_ph)

        # Initialise and restore session if desired.
        self._init_session()
        train_state = TrainState(self.saver_basepath)
        iteration_no: int = 0
        if restore:
            self._restore_session(restore)
            train_state.restore_state(restore)
            iteration_no = self.session.run(global_step)
            ac.print2('Model restored from `{}` at iteration #{}.'.format(self.saver_path, iteration_no))

        print('')
        ac.print1('Legend:')
        ac.print1('# (iteration number), T (elapsed seconds), '
                  'Ltr (training loss), Lge (generalisation loss), '
                  'Ctr (correct mean, train), '
                  'Cge (correct mean, generalisation), '
                  'C3tr (top3 guess mean, train), '
                  'C3ge (top3 guess mean, generalisation)')

        start_time = time.time()
        last_log_time = start_time
        # Iterate as long as max_iteration is not reached.  If max_iteration is set to a negative value, keep iterating
        # forever.
        while iteration_no <= max_iteration or max_iteration < 0:
            # Run session.
            iteration_no = self.session.run(global_step)
            feed_dict_tr = self._create_train_feed_dict(iteration_no, train_set, batch_size_train, input_ph, target_ph)
            run_params_tr = {
                'step': step_op,
                'targets': target_ph,
                'loss': loss_op_tr,
                'outputs': output_ops_tr,
            }
            train_values = self.session.run(run_params_tr, feed_dict=feed_dict_tr)

            # Save model (only if not the first iteration after starting).
            if iteration_no % save_interval == 0 and iteration_no not in [0, restore]:
                self._persist_session(iteration_no)
                train_state.persist_state()
                ac.print2('Model saved to `{}` for iteration #{}.'.format(self.saver_path, iteration_no))

            # Assess timings.
            the_time = time.time()
            elapsed_since_last_log = the_time - last_log_time

            # Validate.
            if elapsed_since_last_log > log_interval:
                last_log_time = the_time
                feed_dict_ge = self._create_train_feed_dict(-1, valid_set, batch_size_valid, input_ph, target_ph)
                run_params_ge = {
                    'targets': target_ph,
                    'loss': loss_op_ge,
                    'outputs': output_ops_ge,
                }
                test_values = self.session.run(run_params_ge, feed_dict=feed_dict_ge)
                cor_tr, cor3_tr = self._compute_accuracy(train_values['targets'], train_values['outputs'][-1])
                cor_ge, cor3_ge = self._compute_accuracy(test_values['targets'], test_values['outputs'][-1])
                elapsed = time.time() - start_time
                loss_tr = train_values['loss']
                loss_ge = test_values['loss']

                log_line = '# {:05d}, T {:.1f}, Ltr {:.4f}, Lge {:.4f}, ' \
                           'Ctr {:.4f}, Cge {:.4f}, C3tr {:.4f}, C3ge {:.4f}'
                ac.print1(log_line.format(iteration_no, elapsed, loss_tr, loss_ge, cor_tr, cor_ge, cor3_tr, cor3_ge))

                # Confusion matrix.
                cm_gt, cm_top_1, cm_top_3 = self._compute_confusions(test_values['targets'], test_values['outputs'][-1])
                train_state.log_validation(elapsed, int(iteration_no), float(loss_tr), float(loss_ge),
                                           float(cor_tr), float(cor_ge), float(cor3_tr), float(cor3_ge),
                                           list(cm_gt.tolist()), list(cm_top_1.tolist()), list(cm_top_3.tolist()))
                true, pred1, pred3 = train_state.numpy_confusions()
                classes = np.array(ac.actions)
                filename_conf1 = os.path.join(self.out_path, 'confusion_matrices',
                                              'conf_top1_{}.png'.format(iteration_no))
                filename_conf3 = os.path.join(self.out_path, 'confusion_matrices',
                                              'conf_top3_{}.png'.format(iteration_no))
                ac.plot.confusion_matrix(true, pred1, classes, filename_conf1, normalize=True)
                ac.plot.confusion_matrix(true, pred3, classes, filename_conf3, normalize=True)

                # Loss graph.
                filename = os.path.join(self.out_path, 'losses', 'loss_{}.png'.format(iteration_no))
                ac.plot.loss_graph(train_state.logged_iterations, train_state.losses_tr, train_state.losses_ge,
                                   filename)