Ejemplo n.º 1
0
    def build_model(self):
        """Find the model and build the training graph.

        Picks an input reader from the feature flags and resolves the model,
        label loss, optimizer, feature transformer and data augmenter classes
        by name. It then hands them, together with the remaining training
        flags, to ``build_graph``.

        Returns:
          A ``tf.train.Saver`` configured with ``max_to_keep=3`` and
          ``keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours``.
        """

        # Convert feature_names and feature_sizes to lists of values.
        feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
            FLAGS.feature_names, FLAGS.feature_sizes)

        # The reader class depends on two independent flags:
        # distillation vs. plain features, and frame-level vs. aggregated input.
        if FLAGS.distillation_features:
            # Report through logging (as done for "Built graph." below) rather
            # than a Python-2-only print statement.
            logging.info("distillation readers")
            if FLAGS.frame_features:
                reader_class = readers.YT8MFrameDistillationFeatureReader
            else:
                reader_class = readers.YT8MAggregatedDistillationFeatureReader
        elif FLAGS.frame_features:
            reader_class = readers.YT8MFrameFeatureReader
        else:
            reader_class = readers.YT8MAggregatedFeatureReader
        reader = reader_class(
            feature_names=feature_names, feature_sizes=feature_sizes)

        # Find the model and the pluggable training components by name.
        model = find_class_by_name(FLAGS.model,
                                   [frame_level_models, video_level_models])()
        label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()
        optimizer_class = find_class_by_name(FLAGS.optimizer, [tf.train])
        transformer_class = find_class_by_name(FLAGS.feature_transformer,
                                               [feature_transform])
        augmenter_class = find_class_by_name(FLAGS.data_augmenter,
                                             [data_augmentation])

        build_graph(
            reader=reader,
            model=model,
            optimizer_class=optimizer_class,
            augmenter_class=augmenter_class,
            transformer_class=transformer_class,
            clip_gradient_norm=FLAGS.clip_gradient_norm,
            train_data_pattern=FLAGS.train_data_pattern,
            label_loss_fn=label_loss_fn,
            base_learning_rate=FLAGS.base_learning_rate,
            learning_rate_decay=FLAGS.learning_rate_decay,
            learning_rate_decay_examples=FLAGS.learning_rate_decay_examples,
            regularization_penalty=FLAGS.regularization_penalty,
            num_readers=FLAGS.num_readers,
            batch_size=FLAGS.batch_size,
            num_epochs=FLAGS.num_epochs)

        logging.info("%s: Built graph.", task_as_string(self.task))

        return tf.train.Saver(
            max_to_keep=3,
            keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)
def get_reader():
    """Construct the input reader selected by the command-line flags.

    Feature names and sizes are parsed from ``FLAGS.feature_names`` and
    ``FLAGS.feature_sizes`` and passed to the chosen reader's constructor.

    Returns:
      A ``readers.YT8MFrameDistillationFeatureReader`` when
      ``FLAGS.frame_features`` is truthy, otherwise a
      ``readers.YT8MAggregatedFeatureReader``.
    """
    names, sizes = utils.GetListOfFeatureNamesAndSizes(
        FLAGS.feature_names, FLAGS.feature_sizes)

    # NOTE(review): unlike build_model, this does not consult
    # FLAGS.distillation_features. Frame-level input always gets the
    # distillation reader, while aggregated input gets the plain reader.
    # Confirm this asymmetry is intended.
    if FLAGS.frame_features:
        reader_class = readers.YT8MFrameDistillationFeatureReader
    else:
        reader_class = readers.YT8MAggregatedFeatureReader

    return reader_class(feature_names=names, feature_sizes=sizes)