コード例 #1
0
def _weight_variable(scope_name,
                     name,
                     shape,
                     from_pretrain=False,
                     stddev=0.01):
    """Create (or fetch) a weight variable via ``tf.get_variable``.

    If ``from_pretrain`` is true, a pretrained value is requested from
    ``get_pretrained_weights(scope_name, name, shape)``. When that returns a
    value, the variable is initialized from it. Otherwise the variable is
    freshly initialized.

    Fresh initialization uses the Xavier initializer when
    ``FLAGS.xavier_init`` is set, and a truncated normal with ``stddev``
    otherwise.

    Args:
        scope_name: Scope key passed to ``get_pretrained_weights``.
        name: Variable name passed to ``tf.get_variable``.
        shape: Variable shape for fresh initialization; also passed to
            ``get_pretrained_weights``.
        from_pretrain: Whether to try loading pretrained weights first.
        stddev: Standard deviation of the truncated-normal initializer.
            Unused when Xavier init or pretrained weights are used.

    Returns:
        The variable returned by ``tf.get_variable``.
    """
    if from_pretrain:
        weights = get_pretrained_weights(scope_name, name, shape)
        # None means no pretrained value was found; fall through to fresh init.
        if weights is not None:
            # NOTE(review): shape and dtype here come from the loaded value,
            # not from `shape`/`DTYPE` — confirm the stored arrays match.
            return tf.get_variable(name, initializer=tf.constant(weights))

    # Single shared fresh-initialization path (previously duplicated).
    if FLAGS.xavier_init:
        initializer = tf.contrib.layers.xavier_initializer()
    else:
        initializer = tf.truncated_normal_initializer(stddev=stddev)
    return tf.get_variable(name, shape, DTYPE, initializer=initializer)
コード例 #2
0
def _bias_variable(scope_name, name, shape, from_pretrain=False, constant_value=0.01):
    """Create (or fetch) a bias variable via ``tf.get_variable``.

    If ``from_pretrain`` is true and
    ``get_pretrained_weights(scope_name, name, shape)`` returns a value, the
    variable is initialized from that value. In every other case it is
    filled with ``constant_value``.

    Args:
        scope_name: Scope key passed to ``get_pretrained_weights``.
        name: Variable name passed to ``tf.get_variable``.
        shape: Variable shape for constant initialization; also passed to
            ``get_pretrained_weights``.
        from_pretrain: Whether to try loading a pretrained bias first.
        constant_value: Fill value used when no pretrained bias is used.

    Returns:
        The variable returned by ``tf.get_variable``.
    """
    # Only consult the pretrained store when asked to.
    loaded = get_pretrained_weights(scope_name, name, shape) if from_pretrain else None

    if loaded is not None:
        # NOTE(review): shape/dtype are inferred from the loaded value here.
        return tf.get_variable(name, initializer=tf.constant(loaded))

    return tf.get_variable(name, shape, DTYPE, tf.constant_initializer(constant_value))