Example #1
def inference_atrous4(x,
                      reg_c=0.1,
                      keep_prob=0.5,
                      channels=1,
                      n_class=2,
                      features_root=1,
                      filter_size=3,
                      pool_size=2,
                      dilation_rates=(2,),
                      summaries=True,
                      resnet=True):
    '''
    Creates a new convolutional net for the given parametrization.

    :param x: input tensor, shape [?, nx, ny, nz, channels]
    :param reg_c: regularization constant passed to the convolution layers
    :param keep_prob: dropout probability tensor
    :param channels: number of channels in the input image
    :param n_class: number of output labels
    :param features_root: number of features in the first layer
    :param filter_size: size of the convolution filter
    :param pool_size: size of the max pooling operation
    :param dilation_rates: dilation rates of the parallel atrous branches at the bottleneck
    :param summaries: flag indicating whether summaries should be created
    :param resnet: flag indicating whether residual (skip) additions are used
    '''

    print(
        'Layers {layers}, features {features}, filter size {filter_size}x{filter_size}, pool size: {pool_size}x{pool_size}'
        .format(layers=5,
                features=features_root,
                filter_size=filter_size,
                pool_size=pool_size))
    if not resnet:
        add_res = partial(tflay.add_res, skip=True)
    else:
        add_res = partial(tflay.add_res, skip=False)
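    # Note: with skip=True the residual addition is presumably bypassed (plain conv path);
    # with skip=False the shortcut tensor is added, giving the ResNet-style blocks below.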

    # Placeholder for the input image
    nx, ny, nz, channels = x.get_shape()[-4:]
    x_image = tf.reshape(x, tf.stack([-1, nx, ny, nz, channels]))
    shape_u0a = [nx, ny, nz]
    shape_u1a = [(n + 1) // 2 for n in shape_u0a]
    shape_u2a = [(n + 1) // 2 for n in shape_u1a]
    shape_u3a = [(n + 1) // 2 for n in shape_u2a]
    shape_u4a = [(n + 1) // 2 for n in shape_u3a]
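    # Spatial extents after each SAME-padded stride-2 max-pool (ceiling division); these
    # are reused below as the output shapes of the corresponding transposed convolutions.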

    batch_size = tf.shape(x_image)[0]

    d0a = tflay.relu(
        'relu_d0a',
        tflay.conv3d('conv_d0a', x_image, features_root, reg_constant=reg_c))
    d0b = tflay.relu('relu_d0b',
                     add_res('res_d0b',
                             tflay.conv3d('conv_d0b',
                                          d0a,
                                          features_root,
                                          reg_constant=reg_c),
                             x_image,
                             conv=False))  # 128 * 128 * 48, n

    d1a = tflay.max_pool('pool_d1a', d0b)  # 64 * 64 * 24, n
    d1b = tflay.relu('relu_d1b',
                     tflay.conv3d('conv_d1b',
                                  d1a,
                                  2**1 * features_root,
                                  reg_constant=reg_c))  # 64 * 64 * 24, 2n
    d1c = tflay.relu('relu_d1c',
                     add_res(
                         'res_d1c',
                         tflay.conv3d('conv_d1b-c',
                                      d1b,
                                      2**1 * features_root,
                                      reg_constant=reg_c),
                         d1a))  # 64 * 64 * 24, 2n

    d2a = tflay.max_pool('pool_d2a', d1c)  # 32 * 32 * 12, 2n
    d2b = tflay.relu('relu_d2b',
                     tflay.conv3d('conv_d2b',
                                  d2a,
                                  2**2 * features_root,
                                  reg_constant=reg_c))  # 32 * 32 * 12, 4n
    d2c = tflay.relu('relu_d2c',
                     add_res(
                         'res_d2c',
                         tflay.conv3d('conv_d2b-c',
                                      d2b,
                                      2**2 * features_root,
                                      reg_constant=reg_c),
                         d2a))  # 32 * 32 * 12, 4n

    d3a = tflay.max_pool('pool_d3a', d2c)  # 16 * 16 * 6, 4n
    d3b = tflay.relu('relu_d3b',
                     tflay.conv3d('conv_d3b',
                                  d3a,
                                  2**3 * features_root,
                                  reg_constant=reg_c))  # 16 * 16 * 6, 8n
    d3c = tflay.relu('relu_d3c',
                     add_res(
                         'res_d3c',
                         tflay.conv3d('conv_d3b-c',
                                      d3b,
                                      2**3 * features_root,
                                      reg_constant=reg_c),
                         d3a))  # 16 * 16 * 6, 8n

    d4a = tflay.max_pool('pool_d4a', d3c)  # 8 * 8 * 3, 8n
    d4b = tflay.relu('relu_d4b',
                     tflay.conv3d('conv_d4b',
                                  d4a,
                                  2**4 * features_root,
                                  reg_constant=reg_c))  # 8 * 8 * 3, 16n
    d4c = tflay.relu('relu_d4c',
                     add_res(
                         'res_d4c',
                         tflay.conv3d('conv_d4b-c',
                                      d4b,
                                      2**4 * features_root,
                                      reg_constant=reg_c),
                         d4a))  # 8 * 8 * 3, 16n
    d4c = tflay.dropout('dropout_d4c', d4c, keep_prob)

    bs = [d4c]
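    # Parallel atrous branches: each dilation rate enlarges the in-plane receptive field
    # (rate 1 along z); the branch outputs are concatenated with d4c below, ASPP-style.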
    for i, dilation_rate in enumerate(dilation_rates):
        name = 'b' + str(i) + 'a'
        tmp = tflay.relu('relu_' + name,
                         tflay.atrousconv3d(
                             'atrsconv_' + name,
                             d4c,
                             2**4 * features_root,
                             dilation_rate=[dilation_rate, dilation_rate, 1],
                             reg_constant=reg_c))  # 8 * 8 * 3, 16n
        bs.append(tmp)
    bna = tflay.multiconcat('concat_bna', bs)
    bna = tflay.dropout('dropout_bna', bna, keep_prob)

    u3a = tflay.concat('concat_u3a',
                       tflay.relu(
                           'relu_u3a',
                           tflay.upconv3d('upconv_u3a',
                                          bna,
                                          2**3 * features_root,
                                          shape_u3a,
                                          reg_constant=reg_c)),
                       d3c)  # 16 * 16 * 6, 16n
    u3b = tflay.relu('relu_u3b',
                     tflay.conv3d('conv_u3a-b',
                                  u3a,
                                  2**3 * features_root,
                                  reg_constant=reg_c))  # 16 * 16 * 6, 8n
    u3c = tflay.relu('relu_u3c',
                     add_res(
                         'res_u3c',
                         tflay.conv3d('conv_u3b-c',
                                      u3b,
                                      2**3 * features_root,
                                      reg_constant=reg_c),
                         u3a))  # 16 * 16 * 6, 8n

    u2a = tflay.concat('concat_u2a',
                       tflay.relu(
                           'relu_u2a',
                           tflay.upconv3d('upconv_u2a',
                                          u3c,
                                          2**2 * features_root,
                                          shape_u2a,
                                          reg_constant=reg_c)),
                       d2c)  # 32 * 32 * 12, 8n
    u2b = tflay.relu('relu_u2b',
                     tflay.conv3d('conv_u2a-b',
                                  u2a,
                                  2**2 * features_root,
                                  reg_constant=reg_c))  # 32 * 32 * 12, 4n
    u2c = tflay.relu('relu_u2c',
                     add_res(
                         'res_u2c',
                         tflay.conv3d('conv_u2b-c',
                                      u2b,
                                      2**2 * features_root,
                                      reg_constant=reg_c),
                         u2a))  # 32 * 32 * 12, 4n

    u1a = tflay.concat('concat_u1a',
                       tflay.relu(
                           'relu_u1a',
                           tflay.upconv3d('upconv_u1a',
                                          u2c,
                                          2**1 * features_root,
                                          shape_u1a,
                                          reg_constant=reg_c)),
                       d1c)  # 64 * 64 * 24, 4n
    u1b = tflay.relu('relu_u1b',
                     tflay.conv3d('conv_u1a-b',
                                  u1a,
                                  2**1 * features_root,
                                  reg_constant=reg_c))  # 64 * 64 * 24, 2n
    u1c = tflay.relu('relu_u1c',
                     add_res(
                         'res_u1c',
                         tflay.conv3d('conv_u1b-c',
                                      u1b,
                                      2**1 * features_root,
                                      reg_constant=reg_c),
                         u1a))  # 64 * 64 * 24, 2n

    u0a = tflay.concat('concat_u0a',
                       tflay.relu(
                           'relu_u0a',
                           tflay.upconv3d('upconv_u0a',
                                          u1c,
                                          2**0 * features_root,
                                          shape_u0a,
                                          reg_constant=reg_c)),
                       d0b)  # 128 * 128 * 48, 2n
    u0b = tflay.relu('relu_u0b',
                     tflay.conv3d('conv_u0a-b',
                                  u0a,
                                  2**0 * features_root,
                                  reg_constant=reg_c))  # 128 * 128 * 48, n
    u0c = tflay.relu('relu_u0c',
                     add_res(
                         'res_u0c',
                         tflay.conv3d('conv_u0b-c',
                                      u0b,
                                      2**0 * features_root,
                                      reg_constant=reg_c,
                                      padding='VALID'),
                         u0a))  # 128 * 128 * 48, n

    score = tflay.relu(
        'relu_result',
        tflay.conv3d('conv_result',
                     u0c,
                     n_class,
                     kernel_size=[1, 1, 1],
                     reg_constant=reg_c))

    return score
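
A minimal usage sketch (an illustration, not part of the original snippet): it assumes TensorFlow 1.x graph mode, that the custom tflay layer module is importable, and a purely illustrative 128x128x48 single-channel patch size and features_root value.

import tensorflow as tf

# 5-D input: [batch, nx, ny, nz, channels]; the patch size here is only an example.
x = tf.placeholder(tf.float32, shape=[None, 128, 128, 48, 1], name='x')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')

score = inference_atrous4(x,
                          reg_c=0.1,
                          keep_prob=keep_prob,
                          channels=1,
                          n_class=2,
                          features_root=16,
                          dilation_rates=[2, 4])
probs = tf.nn.softmax(score)  # per-voxel class probabilities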
Example #2
    def body(features):
        with tf.variable_scope("SSDGraph"):
            with ipu.scopes.ipu_scope('/device:IPU:0'):
                # conv 1 block
                conv1_1 = layers.conv(features,
                                      ksize=3,
                                      stride=1,
                                      filters_out=64,
                                      name="conv1_1")
                conv1_1 = layers.relu(conv1_1)
                conv1_2 = layers.conv(conv1_1,
                                      ksize=3,
                                      stride=1,
                                      filters_out=64,
                                      name="conv1_2")
                conv1_2 = layers.relu(conv1_2)
                pool1 = layers.maxpool(conv1_2, size=2, stride=2)
                # conv 2 block
                conv2_1 = layers.conv(pool1,
                                      ksize=3,
                                      stride=1,
                                      filters_out=128,
                                      name="conv2_1")
                conv2_1 = layers.relu(conv2_1)
                conv2_2 = layers.conv(conv2_1,
                                      ksize=3,
                                      stride=1,
                                      filters_out=128,
                                      name="conv2_2")
                conv2_2 = layers.relu(conv2_2)
                pool2 = layers.maxpool(conv2_2, size=2, stride=2)
                # conv 3 block
                conv3_1 = layers.conv(pool2,
                                      ksize=3,
                                      stride=1,
                                      filters_out=256,
                                      name="conv3_1")
                conv3_1 = layers.relu(conv3_1)
                conv3_2 = layers.conv(conv3_1,
                                      ksize=3,
                                      stride=1,
                                      filters_out=256,
                                      name="conv3_2")
                conv3_2 = layers.relu(conv3_2)
                conv3_3 = layers.conv(conv3_2,
                                      ksize=3,
                                      stride=1,
                                      filters_out=256,
                                      name="conv3_3")
                conv3_3 = layers.relu(conv3_3)
                pool3 = layers.maxpool(conv3_3, size=2, stride=2)
                # conv 4 block
                conv4_1 = layers.conv(pool3,
                                      ksize=3,
                                      stride=1,
                                      filters_out=512,
                                      name="conv4_1")
                conv4_1 = layers.relu(conv4_1)
                conv4_2 = layers.conv(conv4_1,
                                      ksize=3,
                                      stride=1,
                                      filters_out=512,
                                      name="conv4_2")
                conv4_2 = layers.relu(conv4_2)
                conv4_3 = layers.conv(conv4_2,
                                      ksize=3,
                                      stride=1,
                                      filters_out=512,
                                      name="conv4_3")
                conv4_3 = layers.relu(
                    conv4_3
                )  # feature map to be used for object detection/classification
                pool4 = layers.maxpool(conv4_3, size=2, stride=2)
                # conv 5 block
                conv5_1 = layers.conv(pool4,
                                      ksize=3,
                                      stride=1,
                                      filters_out=512,
                                      name="conv5_1")
                conv5_1 = layers.relu(conv5_1)
                conv5_2 = layers.conv(conv5_1,
                                      ksize=3,
                                      stride=1,
                                      filters_out=512,
                                      name="conv5_2")
                conv5_2 = layers.relu(conv5_2)
                conv5_3 = layers.conv(conv5_2,
                                      ksize=3,
                                      stride=1,
                                      filters_out=512,
                                      name="conv5_3")
                conv5_3 = layers.relu(conv5_3)
                pool5 = layers.maxpool(conv5_3, size=3, stride=1)
                # END VGG

                # Extra feature layers
                # fc6
                fc6 = layers.conv(pool5,
                                  ksize=3,
                                  dilation_rate=(6, 6),
                                  stride=1,
                                  filters_out=1024,
                                  name="fc6")
                fc6 = layers.relu(fc6)
                # fc7
                fc7 = layers.conv(fc6,
                                  ksize=1,
                                  stride=1,
                                  filters_out=1024,
                                  name="fc7")
                fc7 = layers.relu(
                    fc7
                )  # feature map to be used for object detection/classification
                # conv 6 block
                conv6_1 = layers.conv(fc7,
                                      ksize=1,
                                      stride=1,
                                      filters_out=256,
                                      name="conv6_1")
                conv6_1 = layers.relu(conv6_1)
                conv6_1 = tf.pad(conv6_1,
                                 paddings=([[0, 0], [1, 1], [1, 1], [0, 0]]),
                                 name='conv6_padding')
                conv6_2 = layers.conv(conv6_1,
                                      ksize=3,
                                      stride=2,
                                      filters_out=512,
                                      padding='VALID',
                                      name="conv6_2")
                conv6_2 = layers.relu(
                    conv6_2
                )  # feature map to be used for object detection/classification
                # conv 7 block
                conv7_1 = layers.conv(conv6_2,
                                      ksize=1,
                                      stride=1,
                                      filters_out=128,
                                      name="conv7_1")
                conv7_1 = layers.relu(conv7_1)
                conv7_1 = tf.pad(conv7_1,
                                 paddings=([[0, 0], [1, 1], [1, 1], [0, 0]]),
                                 name='conv7_padding')
                conv7_2 = layers.conv(conv7_1,
                                      ksize=3,
                                      stride=2,
                                      filters_out=256,
                                      padding='VALID',
                                      name="conv7_2")
                conv7_2 = layers.relu(
                    conv7_2
                )  # feature map to be used for object detection/classification
                # conv 8 block
                conv8_1 = layers.conv(conv7_2,
                                      ksize=1,
                                      stride=1,
                                      filters_out=128,
                                      name="conv8_1")
                conv8_1 = layers.relu(conv8_1)
                conv8_2 = layers.conv(conv8_1,
                                      ksize=3,
                                      stride=1,
                                      filters_out=256,
                                      padding='VALID',
                                      name="conv8_2")
                conv8_2 = layers.relu(
                    conv8_2
                )  # feature map to be used for object detection/classification
                # conv 9 block
                conv9_1 = layers.conv(conv8_2,
                                      ksize=1,
                                      stride=1,
                                      filters_out=128,
                                      name="conv9_1")
                conv9_1 = layers.relu(conv9_1)
                conv9_2 = layers.conv(conv9_1,
                                      ksize=3,
                                      stride=1,
                                      filters_out=256,
                                      padding='VALID',
                                      name="conv9_2")
                conv9_2 = layers.relu(
                    conv9_2
                )  # feature map to be used for object detection/classification
                # Perform L2 normalization on conv4_3
                conv4_3_norm = tf.math.l2_normalize(conv4_3, axis=3)
                # Conv confidence predictors have output depth N_BOXES * N_CLASSES
                conv4_3_norm_mbox_conf = layers.conv(
                    conv4_3_norm,
                    ksize=3,
                    stride=1,
                    filters_out=N_BOXES[0] * N_CLASSES,
                    name='conv4_3_norm_mbox_conf')
                fc7_mbox_conf = layers.conv(fc7,
                                            ksize=3,
                                            stride=1,
                                            filters_out=N_BOXES[1] * N_CLASSES,
                                            name='fc7_mbox_conf')
                conv6_2_mbox_conf = layers.conv(conv6_2,
                                                ksize=3,
                                                stride=1,
                                                filters_out=N_BOXES[2] *
                                                N_CLASSES,
                                                name='conv6_2_mbox_conf')
                conv7_2_mbox_conf = layers.conv(conv7_2,
                                                ksize=3,
                                                stride=1,
                                                filters_out=N_BOXES[3] *
                                                N_CLASSES,
                                                name='conv7_2_mbox_conf')
                conv8_2_mbox_conf = layers.conv(conv8_2,
                                                ksize=3,
                                                stride=1,
                                                filters_out=N_BOXES[4] *
                                                N_CLASSES,
                                                name='conv8_2_mbox_conf')
                conv9_2_mbox_conf = layers.conv(conv9_2,
                                                ksize=3,
                                                stride=1,
                                                filters_out=N_BOXES[5] *
                                                N_CLASSES,
                                                name='conv9_2_mbox_conf')
                # Conv box location predictors have output depth N_BOXES * 4 (box coordinates)
                conv4_3_norm_mbox_loc = layers.conv(
                    conv4_3_norm,
                    ksize=3,
                    stride=1,
                    filters_out=N_BOXES[0] * 4,
                    name='conv4_3_norm_mbox_loc')
                fc7_mbox_loc = layers.conv(fc7,
                                           ksize=3,
                                           stride=1,
                                           filters_out=N_BOXES[1] * 4,
                                           name='fc7_mbox_loc')
                conv6_2_mbox_loc = layers.conv(conv6_2,
                                               ksize=3,
                                               stride=1,
                                               filters_out=N_BOXES[2] * 4,
                                               name='conv6_2_mbox_loc')
                conv7_2_mbox_loc = layers.conv(conv7_2,
                                               ksize=3,
                                               stride=1,
                                               filters_out=N_BOXES[3] * 4,
                                               name='conv7_2_mbox_loc')
                conv8_2_mbox_loc = layers.conv(conv8_2,
                                               ksize=3,
                                               stride=1,
                                               filters_out=N_BOXES[4] * 4,
                                               name='conv8_2_mbox_loc')
                conv9_2_mbox_loc = layers.conv(conv9_2,
                                               ksize=3,
                                               stride=1,
                                               filters_out=N_BOXES[5] * 4,
                                               name='conv9_2_mbox_loc')
                # Generate the anchor boxes
                conv4_3_norm_mbox_priorbox = AnchorBoxes(
                    HEIGHT,
                    WIDTH,
                    this_scale=SCALES[0],
                    next_scale=SCALES[1],
                    two_boxes_for_ar1=True,
                    this_steps=STEPS[0],
                    this_offsets=OFFSETS[0],
                    clip_boxes=False,
                    variances=VARIANCES,
                    aspect_ratios=ASPECT_RATIOS_PER_LAYER[0],
                    normalize_coords=True,
                    name='conv4_3_norm_mbox_priorbox')(conv4_3_norm_mbox_loc)
                fc7_mbox_priorbox = AnchorBoxes(
                    HEIGHT,
                    WIDTH,
                    this_scale=SCALES[1],
                    next_scale=SCALES[2],
                    two_boxes_for_ar1=True,
                    this_steps=STEPS[1],
                    this_offsets=OFFSETS[1],
                    clip_boxes=False,
                    variances=VARIANCES,
                    aspect_ratios=ASPECT_RATIOS_PER_LAYER[1],
                    normalize_coords=True,
                    name='fc7_mbox_priorbox')(fc7_mbox_loc)
                conv6_2_mbox_priorbox = AnchorBoxes(
                    HEIGHT,
                    WIDTH,
                    this_scale=SCALES[2],
                    next_scale=SCALES[3],
                    two_boxes_for_ar1=True,
                    this_steps=STEPS[2],
                    this_offsets=OFFSETS[2],
                    clip_boxes=False,
                    variances=VARIANCES,
                    aspect_ratios=ASPECT_RATIOS_PER_LAYER[2],
                    normalize_coords=True,
                    name='conv6_2_mbox_priorbox')(conv6_2_mbox_loc)
                conv7_2_mbox_priorbox = AnchorBoxes(
                    HEIGHT,
                    WIDTH,
                    this_scale=SCALES[3],
                    next_scale=SCALES[4],
                    two_boxes_for_ar1=True,
                    this_steps=STEPS[3],
                    this_offsets=OFFSETS[3],
                    clip_boxes=False,
                    variances=VARIANCES,
                    aspect_ratios=ASPECT_RATIOS_PER_LAYER[3],
                    normalize_coords=True,
                    name='conv7_2_mbox_priorbox')(conv7_2_mbox_loc)
                conv8_2_mbox_priorbox = AnchorBoxes(
                    HEIGHT,
                    WIDTH,
                    this_scale=SCALES[4],
                    next_scale=SCALES[5],
                    two_boxes_for_ar1=True,
                    this_steps=STEPS[4],
                    this_offsets=OFFSETS[4],
                    clip_boxes=False,
                    variances=VARIANCES,
                    aspect_ratios=ASPECT_RATIOS_PER_LAYER[4],
                    normalize_coords=True,
                    name='conv8_2_mbox_priorbox')(conv8_2_mbox_loc)
                conv9_2_mbox_priorbox = AnchorBoxes(
                    HEIGHT,
                    WIDTH,
                    this_scale=SCALES[5],
                    next_scale=SCALES[6],
                    two_boxes_for_ar1=True,
                    this_steps=STEPS[5],
                    this_offsets=OFFSETS[5],
                    clip_boxes=False,
                    variances=VARIANCES,
                    aspect_ratios=ASPECT_RATIOS_PER_LAYER[5],
                    normalize_coords=True,
                    name='conv9_2_mbox_priorbox')(conv9_2_mbox_loc)
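                # Each prior-box tensor encodes 8 values per box (the 4 anchor-box
                # coordinates plus the 4 variances), which is why the prior-box
                # reshapes below use a last dimension of 8.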
                # Reshape class predictions
                conv4_3_norm_mbox_conf_reshape = tf.reshape(
                    conv4_3_norm_mbox_conf,
                    shape=(-1, conv4_3_norm_mbox_conf.shape[1] *
                           conv4_3_norm_mbox_conf.shape[2] * N_BOXES[0],
                           N_CLASSES),
                    name='conv4_3_norm_mbox_conf_reshape')
                fc7_mbox_conf_reshape = tf.reshape(
                    fc7_mbox_conf,
                    shape=(-1, fc7_mbox_conf.shape[1] *
                           fc7_mbox_conf.shape[2] * N_BOXES[1], N_CLASSES),
                    name='fc7_mbox_conf_reshape')
                conv6_2_mbox_conf_reshape = tf.reshape(
                    conv6_2_mbox_conf,
                    shape=(-1, conv6_2_mbox_conf.shape[1] *
                           conv6_2_mbox_conf.shape[2] * N_BOXES[2], N_CLASSES),
                    name='conv6_2_mbox_conf_reshape')
                conv7_2_mbox_conf_reshape = tf.reshape(
                    conv7_2_mbox_conf,
                    shape=(-1, conv7_2_mbox_conf.shape[1] *
                           conv7_2_mbox_conf.shape[2] * N_BOXES[3], N_CLASSES),
                    name='conv7_2_mbox_conf_reshape')
                conv8_2_mbox_conf_reshape = tf.reshape(
                    conv8_2_mbox_conf,
                    shape=(-1, conv8_2_mbox_conf.shape[1] *
                           conv8_2_mbox_conf.shape[2] * N_BOXES[4], N_CLASSES),
                    name='conv8_2_mbox_conf_reshape')
                conv9_2_mbox_conf_reshape = tf.reshape(
                    conv9_2_mbox_conf,
                    shape=(-1, conv9_2_mbox_conf.shape[1] *
                           conv9_2_mbox_conf.shape[2] * N_BOXES[5], N_CLASSES),
                    name='conv9_2_mbox_conf_reshape')
                # Reshape box location predictions
                conv4_3_norm_mbox_loc_reshape = tf.reshape(
                    conv4_3_norm_mbox_loc,
                    shape=(-1, conv4_3_norm_mbox_loc.shape[1] *
                           conv4_3_norm_mbox_loc.shape[2] * N_BOXES[0], 4),
                    name='conv4_3_norm_mbox_loc_reshape')
                fc7_mbox_loc_reshape = tf.reshape(
                    fc7_mbox_loc,
                    shape=(-1, fc7_mbox_loc.shape[1] * fc7_mbox_loc.shape[2] *
                           N_BOXES[1], 4),
                    name='fc7_mbox_loc_reshape')
                conv6_2_mbox_loc_reshape = tf.reshape(
                    conv6_2_mbox_loc,
                    shape=(-1, conv6_2_mbox_loc.shape[1] *
                           conv6_2_mbox_loc.shape[2] * N_BOXES[2], 4),
                    name='conv6_2_mbox_loc_reshape')
                conv7_2_mbox_loc_reshape = tf.reshape(
                    conv7_2_mbox_loc,
                    shape=(-1, conv7_2_mbox_loc.shape[1] *
                           conv7_2_mbox_loc.shape[2] * N_BOXES[3], 4),
                    name='conv7_2_mbox_loc_reshape')
                conv8_2_mbox_loc_reshape = tf.reshape(
                    conv8_2_mbox_loc,
                    shape=(-1, conv8_2_mbox_loc.shape[1] *
                           conv8_2_mbox_loc.shape[2] * N_BOXES[4], 4),
                    name='conv8_2_mbox_loc_reshape')
                conv9_2_mbox_loc_reshape = tf.reshape(
                    conv9_2_mbox_loc,
                    shape=(-1, conv9_2_mbox_loc.shape[1] *
                           conv9_2_mbox_loc.shape[2] * N_BOXES[5], 4),
                    name='conv9_2_mbox_loc_reshape')
                # Reshape anchor box tensors
                conv4_3_norm_mbox_priorbox_reshape = tf.reshape(
                    conv4_3_norm_mbox_priorbox,
                    shape=(-1, conv4_3_norm_mbox_priorbox.shape[1] *
                           conv4_3_norm_mbox_priorbox.shape[2] * N_BOXES[0],
                           8),
                    name='conv4_3_norm_mbox_priorbox_reshape')
                fc7_mbox_priorbox_reshape = tf.reshape(
                    fc7_mbox_priorbox,
                    shape=(-1, fc7_mbox_priorbox.shape[1] *
                           fc7_mbox_priorbox.shape[2] * N_BOXES[1], 8),
                    name='fc7_mbox_priorbox_reshape')
                conv6_2_mbox_priorbox_reshape = tf.reshape(
                    conv6_2_mbox_priorbox,
                    shape=(-1, conv6_2_mbox_priorbox.shape[1] *
                           conv6_2_mbox_priorbox.shape[2] * N_BOXES[2], 8),
                    name='conv6_2_mbox_priorbox_reshape')
                conv7_2_mbox_priorbox_reshape = tf.reshape(
                    conv7_2_mbox_priorbox,
                    shape=(-1, conv7_2_mbox_priorbox.shape[1] *
                           conv7_2_mbox_priorbox.shape[2] * N_BOXES[3], 8),
                    name='conv7_2_mbox_priorbox_reshape')
                conv8_2_mbox_priorbox_reshape = tf.reshape(
                    conv8_2_mbox_priorbox,
                    shape=(-1, conv8_2_mbox_priorbox.shape[1] *
                           conv8_2_mbox_priorbox.shape[2] * N_BOXES[4], 8),
                    name='conv8_2_mbox_priorbox_reshape')
                conv9_2_mbox_priorbox_reshape = tf.reshape(
                    conv9_2_mbox_priorbox,
                    shape=(-1, conv9_2_mbox_priorbox.shape[1] *
                           conv9_2_mbox_priorbox.shape[2] * N_BOXES[5], 8),
                    name='conv9_2_mbox_priorbox_reshape')
                # Concatenate predictions from different layers
                mbox_conf = tf.concat([
                    conv4_3_norm_mbox_conf_reshape, fc7_mbox_conf_reshape,
                    conv6_2_mbox_conf_reshape, conv7_2_mbox_conf_reshape,
                    conv8_2_mbox_conf_reshape, conv9_2_mbox_conf_reshape
                ],
                                      axis=1,
                                      name='mbox_conf')
                mbox_loc = tf.concat([
                    conv4_3_norm_mbox_loc_reshape, fc7_mbox_loc_reshape,
                    conv6_2_mbox_loc_reshape, conv7_2_mbox_loc_reshape,
                    conv8_2_mbox_loc_reshape, conv9_2_mbox_loc_reshape
                ],
                                     axis=1,
                                     name='mbox_loc')
                mbox_priorbox = tf.concat([
                    conv4_3_norm_mbox_priorbox_reshape,
                    fc7_mbox_priorbox_reshape, conv6_2_mbox_priorbox_reshape,
                    conv7_2_mbox_priorbox_reshape,
                    conv8_2_mbox_priorbox_reshape,
                    conv9_2_mbox_priorbox_reshape
                ],
                                          axis=1,
                                          name='mbox_priorbox')

                # Softmax activation layer
                mbox_conf_softmax = tf.nn.softmax(mbox_conf,
                                                  name='mbox_conf_softmax')
                predictions = tf.concat(
                    [mbox_conf_softmax, mbox_loc, mbox_priorbox],
                    axis=2,
                    name='predictions')

                # Output
                outfeed = outfeed_queue.enqueue(predictions)
                return outfeed
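
For reference, the size of the predictions tensor along axis 1 equals the total number of default boxes, i.e. the sum over the six feature maps of height x width x boxes-per-cell. A small sketch of that arithmetic, assuming the standard SSD300 feature-map sizes and boxes-per-cell (these values are illustrative; the real ones follow from HEIGHT, WIDTH, STEPS and N_BOXES defined outside this snippet):

# Hypothetical SSD300-style configuration -- not taken from the snippet above.
feature_map_sizes = [(38, 38), (19, 19), (10, 10), (5, 5), (3, 3), (1, 1)]
boxes_per_cell = [4, 6, 6, 6, 4, 4]  # would correspond to N_BOXES

total_boxes = sum(h * w * n for (h, w), n in zip(feature_map_sizes, boxes_per_cell))
print(total_boxes)  # 8732 -> predictions would have shape (batch, 8732, N_CLASSES + 12)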
Example #3
def run_mnist_model(model_type, normalisation=None):
    ''' model_type = 'relu' or 'shunted_relu' '''
    # MNIST model Parameters
    n_input = 784  # MNIST data input (img shape: 28*28)
    n_classes = 10  # MNIST total classes (0-9 digits)
    learning_rate = 1e-4
    batch_size = 2**7  # 128
    training_epochs = 7
    n_batch_per_epoch = int(np.ceil(mnist.train.num_examples / batch_size))
    root_dir = "output/"
    n_run = get_run_index(root_dir)
    tb_log_string = root_dir + n_run + '_run_' + model_type + '_norm-' + str(
        normalisation)

    # tf Graph input
    #phase = tf.placeholder(tf.bool, name='phase')
    #with tf.name_scope('input'):
    X = tf.placeholder("float32", [None, n_input], name="X")
    y = tf.placeholder("float32", [None, n_classes], name="y")

    #x_image = tf.reshape(X, [-1, 28, 28, 1])
    #tf.summary.image('input', x_image, 3)
    training_bool = tf.placeholder(tf.bool, name='training_bool')

    # Build model
    n_neurons = 100
    n_inhib = 2
    if model_type == 'shunted_relu':
        h1 = nn.shunted_relu(X, n_neurons, n_inhib, 'h1')
        h2 = nn.shunted_relu(h1, n_neurons, n_inhib, 'h2')
        hf = nn.shunted_relu(h2, n_neurons, n_inhib, 'h_final')
        tb_log_string += '_' + str(n_neurons) + '_neurons' + '_inhib_' + str(
            n_inhib)

        #tb_log_string +=' episoln 1'
        clip_op = get_clip_op()

        if normalisation is not None:
            print('*****************************************')
            print('No normalisation coded for shunting relu!')
            print('*****************************************')

    elif model_type == 'relu':
        clip_op = None
        tb_log_string += '_' + str(n_neurons) + '_neurons'
        if normalisation is None:
            h1 = nn.relu(X, n_neurons, 'h1')
            h2 = nn.relu(h1, n_neurons, 'h2')
            hf = nn.relu(h2, n_neurons, 'h_final')

        elif normalisation == 'ln':
            h1 = nn.layer_norm_relu(X, n_neurons, 'h1')
            h2 = nn.layer_norm_relu(h1, n_neurons, 'h2')
            hf = nn.layer_norm_relu(h2, n_neurons, 'h_final')

        elif normalisation == 'bn':
            print('batch_norm engaged')
            h1 = nn.batch_relu(X, n_neurons, training_bool, 'h1')
            h2 = nn.batch_relu(h1, n_neurons, training_bool, 'h2')
            hf = nn.batch_relu(h2, n_neurons, training_bool, 'h_final')

        else:
            print('Unrecognised relu norm type!')
            return 0

    else:
        print('Unrecognised model type!')
        return 0

    logits = nn.logits_layer(hf, n_classes)

    with tf.name_scope("cross_ent"):
        xent = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
            logits=logits, labels=y),
                              name="cross_ent")
        tf.summary.scalar("cross_ent", xent)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(
            update_ops):  # add update batch norm params to training step
        #with tf.name_scope("train"):
        train_step = tf.train.AdamOptimizer(learning_rate).minimize(xent)

    # Todo: make this an op?
    with tf.name_scope("accuracy"):
        correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        train_summary = tf.summary.scalar("training_accuracy", accuracy)
        test_summary = tf.summary.scalar("testing_accuracy", accuracy)

    # Todo: rename to make clear this does not depend on X
    summaries_op = get_summary_op()
    activations_summaries_op = get_activations_summary_op()

    # run mnist code
    # Todo split into return the graph and running
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter(
        tb_log_string)  # can update this later for hyper params
    writer.add_graph(sess.graph)

    # log the starting parameters
    summary_buffer = sess.run(summaries_op)
    writer.add_summary(summary_buffer, global_step=0)

    print(n_batch_per_epoch,
          'batches per epoch - still need to check how the last (partial) batch is handled')
    for epoch in range(training_epochs):
        for i in range(n_batch_per_epoch):
            step = epoch * n_batch_per_epoch + i  # steps are batch presentations

            batch = mnist.train.next_batch(batch_size)
            sess.run(train_step,
                     feed_dict={
                         'X:0': batch[0],
                         'y:0': batch[1],
                         'training_bool:0': True
                     })
            if clip_op is not None:
                sess.run(clip_op)

            # Todo - your handling of the accuracy ops is a little clunky / unsure exactly what is going on
            if step % 50 == 0:
                train_accuracy, train_sum = sess.run(
                    [accuracy, train_summary],
                    feed_dict={
                        'X:0': mnist.train.images,
                        'y:0': mnist.train.labels,
                        'training_bool:0': True
                    })
                #feed_dict={'X:0': batch[0], 'y:0': batch[1], 'training_bool:0'=1})
                writer.add_summary(train_sum, step)

            if step % 50 == 0:
                #test_accuracy, test_sum = sess.run([accuracy, test_summary],
                test_accuracy, test_sum = sess.run(
                    [accuracy, test_summary],
                    feed_dict={
                        'X:0': mnist.test.images,
                        'y:0': mnist.test.labels,
                        'training_bool:0': False
                    })
                writer.add_summary(test_sum, step)

            if step % n_batch_per_epoch == 0:
                print(test_accuracy)

            if step % 100 == 0:
                output = sess.run(summaries_op)
                writer.add_summary(output, step)
                activations_output = sess.run(activations_summaries_op,
                                              feed_dict={
                                                  'X:0': batch[0],
                                                  'y:0': batch[1],
                                                  'training_bool:0': True
                                              })
                writer.add_summary(activations_output, step)
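
A minimal driver sketch (the combinations below are illustrative; it assumes the same imports and the nn/mnist objects used in the snippet above):

import tensorflow as tf

# Compare the plain ReLU network (optionally normalised) with the shunted-ReLU variant.
for model_type, norm in [('relu', None), ('relu', 'ln'), ('relu', 'bn'), ('shunted_relu', None)]:
    tf.reset_default_graph()  # placeholders use fixed names, so start each run from a fresh graph
    run_mnist_model(model_type, normalisation=norm)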