Exemplo n.º 1
0
def bn(x, c):
    """Batch-normalize ``x`` over every axis except the last, or just add a bias.

    Args:
        x: Input tensor. Moments are computed over all axes except the last,
            and every variable created here has shape ``x_shape[-1:]``
            (presumably the channel axis -- confirm with callers).
        c: Config dict. Reads ``c['use_bias']``; when that is falsy, also
            reads ``c['is_training']``, which is used as the predicate of
            ``control_flow_ops.cond``.

    Returns:
        If ``c['use_bias']`` is truthy: ``x + bias`` with no normalization.
        Otherwise: ``tf.nn.batch_normalization`` of ``x`` using the batch
        mean/variance when ``c['is_training']`` is true and the stored
        moving mean/variance otherwise, with ``beta`` as offset, ``gamma``
        as scale and ``BN_EPSILON`` as the variance epsilon.

    Side effects:
        Creates variables through ``resnet._get_variable``: 'bias' in the
        bias path; 'beta', 'gamma', 'moving_mean' and 'moving_variance'
        (the last two with ``trainable=False``) in the normalization path.
        The normalization path also adds two moving-average update ops
        (decay ``BN_DECAY``) to ``UPDATE_OPS_COLLECTION``.
    """
    x_shape = x.get_shape()
    params_shape = x_shape[-1:]

    if c['use_bias']:
        # Bias-only path: no normalization is applied. Use the same
        # variable factory as the rest of this function (previously a bare
        # `_get_variable`, inconsistent with the `resnet.` calls below).
        bias = resnet._get_variable('bias',
                                    params_shape,
                                    initializer=tf.zeros_initializer)
        return x + bias

    # Reduce over every axis except the last one.
    axis = list(range(len(x_shape) - 1))

    beta = resnet._get_variable('beta',
                                params_shape,
                                initializer=tf.zeros_initializer)
    gamma = resnet._get_variable('gamma',
                                 params_shape,
                                 initializer=tf.ones_initializer)

    moving_mean = resnet._get_variable('moving_mean',
                                       params_shape,
                                       initializer=tf.zeros_initializer,
                                       trainable=False)
    moving_variance = resnet._get_variable('moving_variance',
                                           params_shape,
                                           initializer=tf.ones_initializer,
                                           trainable=False)

    # Batch statistics plus the ops that fold them into the moving averages.
    # The update ops are built on every call but are not dependencies of
    # the returned tensor. They are only registered in UPDATE_OPS_COLLECTION,
    # so they run only if the caller explicitly runs that collection
    # (presumably only during training -- verify against the training loop).
    mean, variance = tf.nn.moments(x, axis)
    update_moving_mean = moving_averages.assign_moving_average(
        moving_mean, mean, BN_DECAY)
    update_moving_variance = moving_averages.assign_moving_average(
        moving_variance, variance, BN_DECAY)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_mean)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_variance)

    # Training: normalize with this batch's statistics.
    # Inference: normalize with the accumulated moving averages.
    mean, variance = control_flow_ops.cond(
        c['is_training'], lambda: (mean, variance), lambda:
        (moving_mean, moving_variance))

    x = tf.nn.batch_normalization(x, mean, variance, beta, gamma, BN_EPSILON)

    return x
Exemplo n.º 2
0
def conv(x, c):
    """Apply a 3-D convolution to ``x`` with a freshly created kernel.

    Args:
        x: Input tensor for ``tf.nn.conv3d``; its last dimension is taken as
            the number of input channels (presumably a 5-D NDHWC tensor --
            confirm with callers).
        c: Config dict. Reads ``c['ksize']`` (kernel edge length, used for all
            three kernel dimensions), ``c['stride']`` and
            ``c['conv_filters_out']`` (number of output channels).

    Returns:
        The ``tf.nn.conv3d`` result with ``'SAME'`` padding and strides
        ``[1, 1, stride, stride, 1]``. That is, stride 1 along axis 1 and
        ``c['stride']`` along axes 2 and 3.

    Side effects:
        Creates a 'weights' variable of shape
        ``[ksize, ksize, ksize, in_channels, conv_filters_out]`` via
        ``resnet._get_variable``. It is initialized from a truncated normal
        with stddev ``CONV_WEIGHT_STDDEV`` and uses weight decay
        ``CONV_WEIGHT_DECAY``.
    """
    edge = c['ksize']
    step = c['stride']
    out_channels = c['conv_filters_out']

    in_channels = x.get_shape()[-1]
    kernel_shape = [edge] * 3 + [in_channels, out_channels]

    weight_init = tf.truncated_normal_initializer(stddev=CONV_WEIGHT_STDDEV)
    kernel = resnet._get_variable(
        'weights',
        shape=kernel_shape,
        dtype='float',
        initializer=weight_init,
        weight_decay=CONV_WEIGHT_DECAY,
    )

    # NOTE(review): axis 1 is never strided even though the kernel is cubic;
    # confirm this is intentional (e.g. preserving temporal resolution).
    window_strides = [1, 1, step, step, 1]
    return tf.nn.conv3d(x, kernel, window_strides, padding='SAME')