def _add_triaining_graph():
    """Build a fresh graph holding the audio model plus its training ops.

    The misspelled function name is kept so existing callers keep working.

    The returned ``tf.Graph`` contains:
      * the model built by ``define_audio_slim(training=True)``, with a
        ``'logits'`` histogram summary of its output;
      * under the ``'train'`` scope:
          - a float32 ``labels`` placeholder of shape
            ``[None, params.NUM_CLASSES]``;
          - the softmax cross-entropy between logits and labels
            (``cross_entropy``);
          - its mean (``loss_op``), with a ``'loss'`` scalar summary;
          - a non-trainable ``global_step`` variable;
          - an Adam ``train_op`` that minimizes the loss and advances
            ``global_step``.

    Returns:
        The newly created ``tf.Graph``.
    """
    graph = tf.Graph()
    with graph.as_default():
        model_logits = define_audio_slim(training=True)
        tf.summary.histogram('logits', model_logits)

        with tf.variable_scope('train'):
            label_input = tf.placeholder(
                dtype=tf.float32,
                shape=[None, params.NUM_CLASSES],
                name='labels')
            per_example_xent = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=label_input,
                logits=model_logits,
                name='cross_entropy')
            mean_loss = tf.reduce_mean(per_example_xent, name='loss_op')
            tf.summary.scalar('loss', mean_loss)

            # The step counter is placed in GLOBAL_STEP as well as
            # GLOBAL_VARIABLES so TF's global-step lookup can locate it.
            step_collections = [tf.GraphKeys.GLOBAL_VARIABLES,
                                tf.GraphKeys.GLOBAL_STEP]
            step_counter = tf.Variable(
                0,
                name='global_step',
                trainable=False,
                collections=step_collections)

            adam = tf.train.AdamOptimizer(
                learning_rate=params.LEARNING_RATE,
                epsilon=params.ADAM_EPSILON)
            adam.minimize(mean_loss, global_step=step_counter, name='train_op')
    return graph
# Example No. 2
# 0
def _restore_from_defined_and_ckpt(sess, ckpt):
    """Restore graph from define and variables from ckpt file.

    Rebuilds the inference model (``training=False``) inside ``sess.graph``
    via ``audio_model.define_audio_slim``. It then loads variable values
    into ``sess`` with ``audio_model.load_audio_slim_checkpoint``.

    Args:
        sess: Session whose graph receives the model definition.
        ckpt: Checkpoint file, forwarded unchanged to
            ``audio_model.load_audio_slim_checkpoint``.
    """
    with sess.graph.as_default():
        audio_model.define_audio_slim(training=False)
        audio_model.load_audio_slim_checkpoint(sess, ckpt)


def calculate_flops(graph):
    """Print to stdout an analysis of the number of floating point operations
    in the model broken down by individual operations.

    Args:
        graph: The ``tf.Graph`` to profile.

    Returns:
        The proto produced by ``tf.profiler.profile`` for ``cmd='scope'``.
        Callers may read the aggregated FLOP count from it. Earlier versions
        returned ``None``, so ignoring the result is still fine.
    """
    return tf.profiler.profile(
        graph=graph,
        options=tf.profiler.ProfileOptionBuilder.float_operation(),
        cmd='scope')


if __name__ == '__main__':
    import os
    import sys

    # Demo: split a tiny toy dataset into train/test/validation parts.
    X, y = np.arange(20).reshape((10, 2)), np.arange(10)
    print(X)
    print(y)
    tr, te, vl = train_test_val_split(X, y, shuffle=True)
    print(tr)
    print(te)
    print(vl)

    # Resolve the sibling source directories relative to this script rather
    # than the current working directory. Otherwise the imports below only
    # work when the script is launched from its own directory.
    _script_dir = os.path.dirname(os.path.abspath(__file__))
    sys.path.append(os.path.join(_script_dir, '..'))
    sys.path.append(os.path.join(_script_dir, '..', 'vggish'))

    from audio_model import define_audio_slim
    from vggish_slim import define_vggish_slim  # noqa: F401 (see note below)
    # NOTE(review): binding `tf` here also makes it visible to the
    # module-level functions that reference `tf`. Confirm whether `tf` is
    # also imported at the top of the file.
    import tensorflow as tf

    # Profile the FLOPs of the inference graph. To profile VGGish instead,
    # call define_vggish_slim(training=False) in place of define_audio_slim.
    with tf.Graph().as_default() as graph:
        define_audio_slim(training=False)
        calculate_flops(graph)