Example #1
def main(_):
  """Builds AssembleNet++ on video + object inputs and evaluates random data.

  Constructs placeholders for an RGB clip and per-pixel object-class maps,
  instantiates the full AssembleNet++ structure (which consumes both), and
  prints the resulting logits and argmax predictions for a random input.
  """
  batch_size = 2
  image_size = 256
  num_frames_total = batch_size * FLAGS.num_frames

  rgb_input = tf.placeholder(
      tf.float32, (num_frames_total, image_size, image_size, 3))
  object_input = tf.placeholder(
      tf.float32,
      (num_frames_total, image_size, image_size, FLAGS.num_object_classes))
  model_inputs = (rgb_input, object_input)

  # full_asnp50_structure is required here because the model is fed both the
  # video stream and the object segmentation stream.
  FLAGS.model_structure = json.dumps(model_structures.full_asnp50_structure)
  FLAGS.model_edge_weights = json.dumps(
      model_structures.full_asnp_structure_weights)

  network = assemblenet_plus.assemblenet_plus(
      assemblenet_depth=50,
      num_classes=FLAGS.num_classes,
      data_format='channels_last')

  # The model callable takes (inputs, is_training).
  outputs = network(model_inputs, False)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Random stand-ins; replace with a real video and its object maps.
    rgb = np.random.rand(*rgb_input.shape)
    obj = np.random.rand(*object_input.shape)
    logits = sess.run(outputs, feed_dict={model_inputs: (rgb, obj)})
    print(logits)
    print(np.argmax(logits, axis=1))
Example #2
def main(_):
    """Builds the AssembleNet variant selected by FLAGS.assemblenet_mode.

    Supports 'assemblenet_plus_lite', 'assemblenet_plus', and the default
    AssembleNet v1. Runs the chosen network on a random video tensor and
    prints the logits and argmax predictions.
    """
    batch_size = 2
    image_size = 256

    video = tf.placeholder(
        tf.float32,
        (batch_size, FLAGS.num_frames, image_size, image_size, 3))

    if FLAGS.assemblenet_mode == 'assemblenet_plus_lite':
        FLAGS.model_structure = json.dumps(model_structures.asnp_lite_structure)
        FLAGS.model_edge_weights = json.dumps(
            model_structures.asnp_lite_structure_weights)

        network = assemblenet_plus_lite.assemblenet_plus_lite(
            num_layers=[3, 5, 11, 7],
            num_classes=FLAGS.num_classes,
            data_format='channels_last')
    else:
        # Non-lite variants expect frames folded into the batch dimension.
        video = tf.reshape(
            video,
            [batch_size * FLAGS.num_frames, image_size, image_size, 3])

        # Both remaining variants use the same structure/weights: with
        # asn50_structure, AssembleNet++ effectively runs without object
        # inputs, only requiring RGB (optical flow is computed inside the
        # model).
        FLAGS.model_structure = json.dumps(model_structures.asn50_structure)
        FLAGS.model_edge_weights = json.dumps(
            model_structures.asn_structure_weights)

        if FLAGS.assemblenet_mode == 'assemblenet_plus':
            network = assemblenet_plus.assemblenet_plus(
                assemblenet_depth=50,
                num_classes=FLAGS.num_classes,
                data_format='channels_last')
        else:
            network = assemblenet.assemblenet_v1(
                assemblenet_depth=50,
                num_classes=FLAGS.num_classes,
                data_format='channels_last')

    # The model callable takes (inputs, is_training).
    outputs = network(video, False)

    with tf.Session() as sess:
        # Random stand-in; replace with a real video.
        vid = np.random.rand(*video.shape)
        sess.run(tf.global_variables_initializer())
        logits = sess.run(outputs, feed_dict={video: vid})
        print(logits)
        print(np.argmax(logits, axis=1))