Code Example #1
File: unet.py  Project: serre-lab/mgh-video
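This example builds a 2D U-Net encoder-decoder in TF1 (tf.compat.v1) graph style. The helpers conv_batchnorm_relu, maxpool, and upconv_2D are defined elsewhere in the project; the sketch below is a plausible reconstruction inferred from the call sites in the listing, not the project's actual code. In particular, the num_cores and use_cross_replica_batch_norm arguments only matter for TPU-style synchronized batch norm and are accepted but ignored here.

import tensorflow as tf  # TF1 / tf.compat.v1 graph-mode API


def conv_batchnorm_relu(inputs, name, filters, kernel_size, stride,
                        padding='SAME', is_training=True, num_cores=1,
                        use_batch_norm=True,
                        use_cross_replica_batch_norm=False,
                        activation=tf.nn.relu):
    # Conv -> (optional) batch norm -> activation.
    with tf.variable_scope(name):
        net = tf.layers.conv2d(inputs, filters, kernel_size, strides=stride,
                               padding=padding, use_bias=not use_batch_norm)
        if use_batch_norm:
            net = tf.layers.batch_normalization(net, training=is_training)
        if activation is not None:
            net = activation(net)
        return net


def maxpool(inputs, name, ksize, strides, padding='SAME'):
    # Thin wrapper over tf.nn.max_pool; ksize/strides are NHWC lists.
    return tf.nn.max_pool(inputs, ksize=ksize, strides=strides,
                          padding=padding, name=name)


def upconv_2D(inputs, name, filters, kernel_size=(2, 2), strides=(2, 2),
              use_bias=True, padding='valid'):
    # Learned 2x upsampling via transposed convolution.
    return tf.layers.conv2d_transpose(inputs, filters, kernel_size,
                                      strides=strides, padding=padding,
                                      use_bias=use_bias, name=name)

With those helpers assumed, the project's listing follows.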
    def model(inputs, is_training):

        net = inputs
        layer_outputs = {}

        print('Input: {}'.format(net.get_shape().as_list()))

        with tf.variable_scope(model_scope_name):
            # Encode_1
            end_point = 'Encode_1'
            with tf.variable_scope(end_point):
                # 3x3 Conv, padding='same'
                conv2d_1a = conv_batchnorm_relu(
                    net,
                    'Conv2d_1a',
                    num_classes,
                    kernel_size=3,
                    stride=1,
                    padding='SAME',
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                get_shape = conv2d_1a.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'Conv2d_1a'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = conv2d_1a

                # 3x3 Conv, padding='same'
                conv2d_1b = conv_batchnorm_relu(
                    conv2d_1a,
                    'Conv2d_1b',
                    num_classes,
                    kernel_size=3,
                    stride=1,
                    padding='SAME',
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                get_shape = conv2d_1b.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'Conv2d_1b'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = conv2d_1b

                # 2x2 MaxPool
                maxpool_1a = maxpool(conv2d_1b,
                                     'MaxPool_1a',
                                     ksize=[1, 2, 2, 1],
                                     strides=[1, 2, 2, 1],
                                     padding='SAME')
                get_shape = maxpool_1a.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'MaxPool_1a'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = maxpool_1a

            layer_outputs[end_point] = maxpool_1a
            if final_endpoint == end_point: return maxpool_1a, layer_outputs

            # Encode_2
            end_point = 'Encode_2'
            with tf.variable_scope(end_point):
                # 3x3 Conv, padding='same'
                conv2d_2a = conv_batchnorm_relu(
                    maxpool_1a,
                    'Conv2d_2a',
                    num_classes * 2,
                    kernel_size=3,
                    stride=1,
                    padding='SAME',
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                get_shape = conv2d_2a.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'Conv2d_2a'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = conv2d_2a

                # 3x3 Conv, padding='same'
                conv2d_2b = conv_batchnorm_relu(
                    conv2d_2a,
                    'Conv2d_2b',
                    num_classes * 2,
                    kernel_size=3,
                    stride=1,
                    padding='SAME',
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                get_shape = conv2d_2b.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'Conv2d_2b'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = conv2d_2b

                # 2x2 MaxPool
                maxpool_2a = maxpool(conv2d_2b,
                                     'MaxPool_2a',
                                     ksize=[1, 2, 2, 1],
                                     strides=[1, 2, 2, 1],
                                     padding='SAME')
                get_shape = maxpool_2a.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'MaxPool_2a'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = maxpool_2a

            layer_outputs[end_point] = maxpool_2a
            if final_endpoint == end_point: return maxpool_2a, layer_outputs

            # Encode_3
            end_point = 'Encode_3'
            with tf.variable_scope(end_point):
                # 3x3 Conv, padding='same'
                conv2d_3a = conv_batchnorm_relu(
                    maxpool_2a,
                    'Conv2d_3a',
                    num_classes * 4,
                    kernel_size=3,
                    stride=1,
                    padding='SAME',
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                get_shape = conv2d_3a.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'Conv2d_3a'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = conv2d_3a

                # 3x3 Conv, padding='same'
                conv2d_3b = conv_batchnorm_relu(
                    conv2d_3a,
                    'Conv2d_3b',
                    num_classes * 4,
                    kernel_size=3,
                    stride=1,
                    padding='SAME',
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                get_shape = conv2d_3b.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'Conv2d_3b'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = conv2d_3b

                # 2x2 MaxPool
                maxpool_3a = maxpool(conv2d_3b,
                                     'MaxPool_3a',
                                     ksize=[1, 2, 2, 1],
                                     strides=[1, 2, 2, 1],
                                     padding='SAME')
                get_shape = maxpool_3a.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'MaxPool_3a'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = maxpool_3a

            layer_outputs[end_point] = maxpool_3a
            if final_endpoint == end_point: return maxpool_3a, layer_outputs

            # Encode_4
            end_point = 'Encode_4'
            with tf.variable_scope(end_point):
                # 3x3 Conv, padding='same'
                conv2d_4a = conv_batchnorm_relu(
                    maxpool_3a,
                    'Conv2d_4a',
                    num_classes * 8,
                    kernel_size=3,
                    stride=1,
                    padding='SAME',
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                get_shape = conv2d_4a.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'Conv2d_4a'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = conv2d_4a

                # 3x3 Conv, padding='same'
                conv2d_4b = conv_batchnorm_relu(
                    conv2d_4a,
                    'Conv2d_4b',
                    num_classes * 8,
                    kernel_size=3,
                    stride=1,
                    padding='SAME',
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                get_shape = conv2d_4b.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'Conv2d_4b'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = conv2d_4b

                # 2x2 MaxPool
                maxpool_4a = maxpool(conv2d_4b,
                                     'MaxPool_4a',
                                     ksize=[1, 2, 2, 1],
                                     strides=[1, 2, 2, 1],
                                     padding='SAME')
                get_shape = maxpool_4a.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'MaxPool_4a'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = maxpool_4a

            layer_outputs[end_point] = maxpool_4a
            if final_endpoint == end_point: return maxpool_4a, layer_outputs

            # Encode_5
            end_point = 'Encode_5'
            with tf.variable_scope(end_point):
                # 3x3 Conv, padding='same'
                conv2d_5a = conv_batchnorm_relu(
                    maxpool_4a,
                    'Conv2d_5a',
                    num_classes * 16,
                    kernel_size=3,
                    stride=1,
                    padding='SAME',
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                get_shape = conv2d_5a.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'Conv2d_5a'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = conv2d_5a

                # 3x3 Conv, padding='same'
                conv2d_5b = conv_batchnorm_relu(
                    conv2d_5a,
                    'Conv2d_5b',
                    num_classes * 16,
                    kernel_size=3,
                    stride=1,
                    padding='SAME',
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                get_shape = conv2d_5b.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'Conv2d_5b'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = conv2d_5b

            layer_outputs[end_point] = conv2d_5b
            if final_endpoint == end_point: return conv2d_5b, layer_outputs

            # Decode_1
            end_point = 'Decode_1'
            with tf.variable_scope(end_point):
                # Up-convolution
                upconv2d_1a = upconv_2D(conv2d_5b,
                                        'UpConv2d_1a',
                                        num_classes * 8,
                                        kernel_size=(2, 2),
                                        strides=(2, 2),
                                        use_bias=True,
                                        padding='valid')
                get_shape = upconv2d_1a.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'UpConv2d_1a'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = upconv2d_1a

                # Merge: concatenate the encoder skip connection with
                # the upsampled features
                merge_1a = tf.concat([conv2d_4b, upconv2d_1a],
                                     axis=-1,
                                     name='merge_1a')
                get_shape = merge_1a.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'merge_1a'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = merge_1a

                # 3x3 Conv, padding='same'
                conv2d_d_1a = conv_batchnorm_relu(
                    merge_1a,
                    'Conv2d_d_1a',
                    num_classes * 8,
                    kernel_size=3,
                    stride=1,
                    padding='SAME',
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                get_shape = conv2d_d_1a.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'Conv2d_d_1a'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = conv2d_d_1a

                # 3x3 Conv, padding='same'
                conv2d_d_1b = conv_batchnorm_relu(
                    conv2d_d_1a,
                    'Conv2d_d_1b',
                    num_classes * 8,
                    kernel_size=3,
                    stride=1,
                    padding='SAME',
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                get_shape = conv2d_d_1b.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'Conv2d_d_1b'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = conv2d_d_1b

            layer_outputs[end_point] = conv2d_d_1b
            if final_endpoint == end_point: return conv2d_d_1b, layer_outputs

            # Decode_2
            end_point = 'Decode_2'
            with tf.variable_scope(end_point):
                # Up-convolution
                upconv2d_2a = upconv_2D(conv2d_d_1b,
                                        'UpConv2d_2a',
                                        num_classes * 4,
                                        kernel_size=(2, 2),
                                        strides=(2, 2),
                                        use_bias=True,
                                        padding='valid')
                get_shape = upconv2d_2a.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'UpConv2d_2a'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = upconv2d_2a

                # Merge
                merge_2a = tf.concat([conv2d_3b, upconv2d_2a],
                                     axis=-1,
                                     name='merge_2a')
                get_shape = merge_2a.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'merge_2a'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = merge_2a

                # 3x3 Conv, padding='same'
                conv2d_d_2a = conv_batchnorm_relu(
                    merge_2a,
                    'Conv2d_d_2a',
                    num_classes * 4,
                    kernel_size=3,
                    stride=1,
                    padding='SAME',
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                get_shape = conv2d_d_2a.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'Conv2d_d_2a'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = conv2d_d_2a

                # 3x3 Conv, padding='same'
                conv2d_d_2b = conv_batchnorm_relu(
                    conv2d_d_2a,
                    'Conv2d_d_2b',
                    num_classes * 4,
                    kernel_size=3,
                    stride=1,
                    padding='SAME',
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                get_shape = conv2d_d_2b.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'Conv2d_d_2b'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = conv2d_d_2b

            layer_outputs[end_point] = conv2d_d_2b
            if final_endpoint == end_point: return conv2d_d_2b, layer_outputs

            # Decode_3
            end_point = 'Decode_3'
            with tf.variable_scope(end_point):
                # Up-convolution
                upconv2d_3a = upconv_2D(conv2d_d_2b,
                                        'UpConv2d_3a',
                                        num_classes * 2,
                                        kernel_size=(2, 2),
                                        strides=(2, 2),
                                        use_bias=True,
                                        padding='valid')
                get_shape = upconv2d_3a.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'UpConv2d_3a'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = upconv2d_3a

                # Merge
                merge_3a = tf.concat([conv2d_2b, upconv2d_3a],
                                     axis=-1,
                                     name='merge_3a')
                get_shape = merge_3a.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'merge_3a'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = merge_3a

                # 3x3 Conv, padding='same'
                conv2d_d_3a = conv_batchnorm_relu(
                    merge_3a,
                    'Conv2d_d_3a',
                    num_classes * 2,
                    kernel_size=3,
                    stride=1,
                    padding='SAME',
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                get_shape = conv2d_d_3a.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'Conv2d_d_3a'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = conv2d_d_3a

                # 3x3 Conv, padding='same'
                conv2d_d_3b = conv_batchnorm_relu(
                    conv2d_d_3a,
                    'Conv2d_d_3b',
                    num_classes * 2,
                    kernel_size=3,
                    stride=1,
                    padding='SAME',
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                get_shape = conv2d_d_3b.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'Conv2d_d_3b'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = conv2d_d_3b

            layer_outputs[end_point] = conv2d_d_3b
            if final_endpoint == end_point: return conv2d_d_3b, layer_outputs

            # Decode_4
            end_point = 'Decode_4'
            with tf.variable_scope(end_point):
                # Up-convolution
                upconv2d_4a = upconv_2D(conv2d_d_3b,
                                        'UpConv2d_4a',
                                        num_classes,
                                        kernel_size=(2, 2),
                                        strides=(2, 2),
                                        use_bias=True,
                                        padding='valid')
                get_shape = upconv2d_4a.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'UpConv2d_4a'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = upconv2d_4a

                # Merge
                merge_4a = tf.concat([conv2d_1b, upconv2d_4a],
                                     axis=-1,
                                     name='merge_4a')
                get_shape = merge_4a.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'merge_4a'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = merge_4a

                # 3x3 Conv, padding='same'
                conv2d_d_4a = conv_batchnorm_relu(
                    merge_4a,
                    'Conv2d_d_4a',
                    num_classes,
                    kernel_size=3,
                    stride=1,
                    padding='SAME',
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                get_shape = conv2d_d_4a.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'Conv2d_d_4a'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = conv2d_d_4a

                # 3x3 Conv, padding='same'
                conv2d_d_4b = conv_batchnorm_relu(
                    conv2d_d_4a,
                    'Conv2d_d_4b',
                    num_classes,
                    kernel_size=3,
                    stride=1,
                    padding='SAME',
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                get_shape = conv2d_d_4b.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'Conv2d_d_4b'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = conv2d_d_4b

                # 1x1 Conv, padding='same'
                if single_channel:
                    num_output_channels = 1
                else:
                    num_output_channels = num_classes

                # NOTE: the final 1x1 conv keeps batch norm + ReLU here;
                # passing activation=None would emit raw logits instead.
                conv2d_d_4c = conv_batchnorm_relu(
                    conv2d_d_4b,
                    'Conv2d_d_4c',
                    num_output_channels,
                    kernel_size=1,
                    stride=1,
                    padding='SAME',
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                get_shape = conv2d_d_4c.get_shape().as_list()
                full_layer_name = model_scope_name + '/' + end_point + '/' + 'Conv2d_d_4c'
                print('{}: {}'.format(full_layer_name, get_shape))
                layer_outputs[full_layer_name] = conv2d_d_4c

            layer_outputs[end_point] = conv2d_d_4c
            if final_endpoint == end_point: return conv2d_d_4c, layer_outputs

        raise ValueError('Unknown final endpoint: {}'.format(final_endpoint))
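Note that model refers to several names bound in an enclosing scope (model_scope_name, num_classes, final_endpoint, single_channel, num_cores, use_batch_norm, use_cross_replica_batch_norm), so it is presumably a closure returned by a factory function in unet.py. Note also that with four 2x2 max-pools and four stride-2 up-convolutions, the skip concatenations only line up when the input height and width are divisible by 16. A hypothetical wrapper and usage, for illustration only:

def unet(num_classes, final_endpoint='Decode_4', model_scope_name='UNet',
         single_channel=False, num_cores=1, use_batch_norm=True,
         use_cross_replica_batch_norm=False):
    # Hypothetical factory: binds the hyper-parameters and returns the
    # model(inputs, is_training) closure from the listing above.

    def model(inputs, is_training):
        raise NotImplementedError  # body exactly as in the listing above

    return model


# NHWC image batch whose height and width are divisible by 16.
model_fn = unet(num_classes=2)
images = tf.placeholder(tf.float32, [None, 256, 256, 3])
# logits, layer_outputs = model_fn(images, is_training=True)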
Code Example #2
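This example assembles an I3D-style Inception 3D network (Inception-v1 with convolutions and pools inflated to 3D) over NDHWC video tensors (batch, time, height, width, channels). It reuses the same helper names as the U-Net, but with 3D ops; a minimal sketch of plausible 3D variants, again an assumption rather than the project's code:

import tensorflow as tf  # TF1 / tf.compat.v1 graph-mode API


def conv_batchnorm_relu(inputs, name, filters, kernel_size, stride,
                        padding='SAME', is_training=True, num_cores=1,
                        use_batch_norm=True,
                        use_cross_replica_batch_norm=False,
                        activation=tf.nn.relu):
    # 3D variant: conv3d -> (optional) batch norm -> activation; a scalar
    # kernel_size/stride broadcasts over (time, height, width).
    with tf.variable_scope(name):
        net = tf.layers.conv3d(inputs, filters, kernel_size, strides=stride,
                               padding=padding, use_bias=not use_batch_norm)
        if use_batch_norm:
            net = tf.layers.batch_normalization(net, training=is_training)
        if activation is not None:
            net = activation(net)
        return net


def maxpool(inputs, name, ksize, strides, padding='SAME'):
    # 5-element ksize/strides follow the NDHWC layout.
    return tf.nn.max_pool3d(inputs, ksize=ksize, strides=strides,
                            padding=padding, name=name)


def avgpool(inputs, ksize, strides, padding='VALID'):
    return tf.nn.avg_pool3d(inputs, ksize=ksize, strides=strides,
                            padding=padding)

With those assumed, the listing follows.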
    def model(inputs, is_training):

        net = inputs
        end_points = {}

        print('Inputs: {}'.format(net.get_shape().as_list()))

        # 7x7x7 Conv, stride 2
        end_point = 'Conv3d_1a_7x7'
        net = conv_batchnorm_relu(
            net,
            end_point,
            64,
            kernel_size=7,
            stride=2,
            is_training=is_training,
            num_cores=num_cores,
            use_batch_norm=use_batch_norm,
            use_cross_replica_batch_norm=use_cross_replica_batch_norm)
        get_shape = net.get_shape().as_list()
        print('{} : {}'.format(end_point, get_shape))

        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        # 1x3x3 Max-pool, stride 1, 2, 2
        end_point = 'MaxPool3d_2a_3x3'
        net = maxpool(net,
                      end_point,
                      ksize=[1, 1, 3, 3, 1],
                      strides=[1, 1, 2, 2, 1],
                      padding='SAME')
        get_shape = net.get_shape().as_list()
        print('{} : {}'.format(end_point, get_shape))

        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        # 1x1x1 Conv, stride 1
        end_point = 'Conv3d_2b_1x1'
        net = conv_batchnorm_relu(
            net,
            end_point,
            64,
            kernel_size=1,
            stride=1,
            is_training=is_training,
            num_cores=num_cores,
            use_batch_norm=use_batch_norm,
            use_cross_replica_batch_norm=use_cross_replica_batch_norm)
        get_shape = net.get_shape().as_list()
        print('{} : {}'.format(end_point, get_shape))

        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        # 3x3x3 Conv, stride 1
        end_point = 'Conv3d_2c_3x3'
        net = conv_batchnorm_relu(
            net,
            end_point,
            192,
            kernel_size=3,
            stride=1,
            is_training=is_training,
            num_cores=num_cores,
            use_batch_norm=use_batch_norm,
            use_cross_replica_batch_norm=use_cross_replica_batch_norm)
        get_shape = net.get_shape().as_list()
        print('{} : {}'.format(end_point, get_shape))

        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        # 1x3x3 Max-pool, stride 1, 2, 2
        end_point = 'MaxPool3d_3a_3x3'
        net = maxpool(net,
                      end_point,
                      ksize=[1, 1, 3, 3, 1],
                      strides=[1, 1, 2, 2, 1],
                      padding='SAME')
        get_shape = net.get_shape().as_list()
        print('{} : {}'.format(end_point, get_shape))

        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        # Mixed 3b: Inception block
        end_point = 'Mixed_3b'
        with tf.variable_scope(end_point):
            with tf.variable_scope('Branch_0'):
                # 1x1x1 Conv, stride 1
                branch_0 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    64,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_1'):
                # 1x1x1 Conv, stride 1
                branch_1 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    96,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                # 3x3x3 Conv, stride 1
                branch_1 = conv_batchnorm_relu(
                    branch_1,
                    'Conv3d_0b_3x3',
                    128,
                    kernel_size=3,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_2'):
                # 1x1x1 Conv, stride 1
                branch_2 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    16,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                # 3x3x3 Conv, stride 1
                branch_2 = conv_batchnorm_relu(
                    branch_2,
                    'Conv3d_0b_3x3',
                    32,
                    kernel_size=3,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_3'):
                # 3x3x3 Max-pool, stride 1, 1, 1
                branch_3 = maxpool(net,
                                   'MaxPool3d_0a_3x3',
                                   ksize=[1, 3, 3, 3, 1],
                                   strides=[1, 1, 1, 1, 1],
                                   padding='SAME')
                # 1x1x1 Conv, stride 1
                branch_3 = conv_batchnorm_relu(
                    branch_3,
                    'Conv3d_0b_1x1',
                    32,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            # Concat branch_[0-3] along the channel axis (axis 4 in NDHWC)
            net = tf.concat([branch_0, branch_1, branch_2, branch_3], 4)
        get_shape = net.get_shape().as_list()
        print('{} : {}'.format(end_point, get_shape))

        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        # Mixed 3c: Inception block
        end_point = 'Mixed_3c'
        with tf.variable_scope(end_point):
            with tf.variable_scope('Branch_0'):
                # 1x1x1 Conv, stride 1
                branch_0 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    128,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_1'):
                # 1x1x1 Conv, stride 1
                branch_1 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0b_1x1',
                    128,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                # 3x3x3 Conv, stride 1
                branch_1 = conv_batchnorm_relu(
                    branch_1,
                    'Conv3d_0b_3x3',
                    192,
                    kernel_size=3,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_2'):
                # 1x1x1 Conv, stride 1
                branch_2 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    32,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                # 3x3x3 Conv, stride 1
                branch_2 = conv_batchnorm_relu(
                    branch_2,
                    'Conv3d_0b_3x3',
                    96,
                    kernel_size=3,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_3'):
                # 3x3x3 Max-Pool, stride 1, 1, 1
                branch_3 = maxpool(net,
                                   'MaxPool3d_0a_3x3',
                                   ksize=[1, 3, 3, 3, 1],
                                   strides=[1, 1, 1, 1, 1],
                                   padding='SAME')
                # 1x1x1 Conv, stride 1
                branch_3 = conv_batchnorm_relu(
                    branch_3,
                    'Conv3d_0b_1x1',
                    64,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            # Concat branch_[0-3]
            net = tf.concat([branch_0, branch_1, branch_2, branch_3], 4)
        get_shape = net.get_shape().as_list()
        print('{} : {}'.format(end_point, get_shape))

        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        # 3x3x3 Max-pool, stride 2, 2, 2
        end_point = 'MaxPool3d_4a_3x3'
        net = maxpool(net,
                      end_point,
                      ksize=[1, 3, 3, 3, 1],
                      strides=[1, 2, 2, 2, 1],
                      padding='SAME')
        get_shape = net.get_shape().as_list()
        print('{} : {}'.format(end_point, get_shape))

        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        # Mixed 4b: Inception block
        end_point = 'Mixed_4b'
        with tf.variable_scope(end_point):
            with tf.variable_scope('Branch_0'):
                # 1x1x1 Conv, stride 1
                branch_0 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    192,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_1'):
                # 1x1x1 Conv, stride 1
                branch_1 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    96,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                # 3x3x3 Conv, stride 1
                branch_1 = conv_batchnorm_relu(
                    branch_1,
                    'Conv3d_0b_3x3',
                    208,
                    kernel_size=3,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_2'):
                # 1x1x1 Conv, stride 1
                branch_2 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    16,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                # 3x3x3 Conv, stride 1
                branch_2 = conv_batchnorm_relu(
                    branch_2,
                    'Conv3d_0b_3x3',
                    48,
                    kernel_size=3,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_3'):
                # 3x3x3 Max-pool, stride 1, 1, 1
                branch_3 = maxpool(net,
                                   'MaxPool3d_0a_3x3',
                                   ksize=[1, 3, 3, 3, 1],
                                   strides=[1, 1, 1, 1, 1],
                                   padding='SAME')
                # 1x1x1 Conv, stride 1
                branch_3 = conv_batchnorm_relu(
                    branch_3,
                    'Conv3d_0b_1x1',
                    64,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            # Concat branch_[0-3]
            net = tf.concat([branch_0, branch_1, branch_2, branch_3], 4)
        get_shape = net.get_shape().as_list()
        print('{} : {}'.format(end_point, get_shape))

        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        # Mixed 4c: Inception block
        end_point = 'Mixed_4c'
        with tf.variable_scope(end_point):
            with tf.variable_scope('Branch_0'):
                # 1x1x1 Conv, stride 1
                branch_0 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    160,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_1'):
                # 1x1x1 Conv, stride 1
                branch_1 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    112,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                # 3x3x3 Conv, stride 1
                branch_1 = conv_batchnorm_relu(
                    branch_1,
                    'Conv3d_0b_3x3',
                    224,
                    kernel_size=3,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_2'):
                # 1x1x1 Conv, stride 1
                branch_2 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    24,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                # 3x3x3 Conv, stride 1
                branch_2 = conv_batchnorm_relu(
                    branch_2,
                    'Conv3d_0b_3x3',
                    64,
                    kernel_size=3,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_3'):
                # 3x3x3 Max-pool, stride 1, 1, 1
                branch_3 = maxpool(net,
                                   'MaxPool3d_0a_3x3',
                                   ksize=[1, 3, 3, 3, 1],
                                   strides=[1, 1, 1, 1, 1],
                                   padding='SAME')
                # 1x1x1 Conv, stride 1
                branch_3 = conv_batchnorm_relu(
                    branch_3,
                    'Conv3d_0b_1x1',
                    64,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            # Concat branch_[0-3]
            net = tf.concat([branch_0, branch_1, branch_2, branch_3], 4)
        get_shape = net.get_shape().as_list()
        print('{} : {}'.format(end_point, get_shape))

        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        # Mixed 4d: Inception block
        end_point = 'Mixed_4d'
        with tf.variable_scope(end_point):
            with tf.variable_scope('Branch_0'):
                # 1x1x1 Conv, stride 1
                branch_0 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    128,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_1'):
                # 1x1x1 Conv, stride 1
                branch_1 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    128,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                # 3x3x3 Conv, stride 1
                branch_1 = conv_batchnorm_relu(
                    branch_1,
                    'Conv3d_0b_3x3',
                    256,
                    kernel_size=3,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_2'):
                # 1x1x1 Conv, stride 1
                branch_2 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    24,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                # 3x3x3 Conv, stride 1
                branch_2 = conv_batchnorm_relu(
                    branch_2,
                    'Conv3d_0b_3x3',
                    64,
                    kernel_size=3,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_3'):
                # 3x3x3 Max-pool, stride 1, 1, 1
                branch_3 = maxpool(net,
                                   'MaxPool3d_0a_3x3',
                                   ksize=[1, 3, 3, 3, 1],
                                   strides=[1, 1, 1, 1, 1],
                                   padding='SAME')
                # 1x1x1 Conv, stride 1
                branch_3 = conv_batchnorm_relu(
                    branch_3,
                    'Conv3d_0b_1x1',
                    64,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            # Concat branch_[0-3]
            net = tf.concat([branch_0, branch_1, branch_2, branch_3], 4)
        get_shape = net.get_shape().as_list()
        print('{} : {}'.format(end_point, get_shape))

        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        # Mixed 4e: Inception block
        end_point = 'Mixed_4e'
        with tf.variable_scope(end_point):
            with tf.variable_scope('Branch_0'):
                # 1x1x1 Conv, stride 1
                branch_0 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    112,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_1'):
                # 1x1x1 Conv, stride 1
                branch_1 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    144,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                # 3x3x3 Conv, stride 1
                branch_1 = conv_batchnorm_relu(
                    branch_1,
                    'Conv3d_0b_3x3',
                    288,
                    kernel_size=3,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_2'):
                # 1x1x1 Conv, stride 1
                branch_2 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    32,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                # 3x3x3 Conv, stride 1
                branch_2 = conv_batchnorm_relu(
                    branch_2,
                    'Conv3d_0b_3x3',
                    64,
                    kernel_size=3,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_3'):
                # 3x3x3 Max-pool, stride 1, 1, 1
                branch_3 = maxpool(net,
                                   'MaxPool3d_0a_3x3',
                                   ksize=[1, 3, 3, 3, 1],
                                   strides=[1, 1, 1, 1, 1],
                                   padding='SAME')
                # 1x1x1 Conv, stride 1
                branch_3 = conv_batchnorm_relu(
                    branch_3,
                    'Conv3d_0b_1x1',
                    64,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            # Concat branch_[0-3]
            net = tf.concat([branch_0, branch_1, branch_2, branch_3], 4)
        get_shape = net.get_shape().as_list()
        print('{} : {}'.format(end_point, get_shape))

        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        # Mixed 4f: Inception block
        end_point = 'Mixed_4f'
        with tf.variable_scope(end_point):
            with tf.variable_scope('Branch_0'):
                # 1x1x1 Conv, stride 1
                branch_0 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    256,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_1'):
                # 1x1x1 Conv, stride 1
                branch_1 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    160,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                # 3x3x3 Conv, stride 1
                branch_1 = conv_batchnorm_relu(
                    branch_1,
                    'Conv3d_0b_3x3',
                    320,
                    kernel_size=3,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_2'):
                # 1x1x1 Conv, stride 1
                branch_2 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    32,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                # 3x3x3 Conv, stride 1
                branch_2 = conv_batchnorm_relu(
                    branch_2,
                    'Conv3d_0b_3x3',
                    128,
                    kernel_size=3,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_3'):
                # 3x3x3 Max-pool, stride 1, 1, 1
                branch_3 = maxpool(net,
                                   'MaxPool3d_0a_3x3',
                                   ksize=[1, 3, 3, 3, 1],
                                   strides=[1, 1, 1, 1, 1],
                                   padding='SAME')
                # 1x1x1 Conv, stride 1
                branch_3 = conv_batchnorm_relu(
                    branch_3,
                    'Conv3d_0b_1x1',
                    128,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            # Concat branch_[0-3]
            net = tf.concat([branch_0, branch_1, branch_2, branch_3], 4)
        get_shape = net.get_shape().as_list()
        print('{} : {}'.format(end_point, get_shape))

        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        # 2x2x2 Max-pool, stride 2, 2, 2
        end_point = 'MaxPool3d_5a_2x2'
        net = maxpool(net,
                      end_point,
                      ksize=[1, 2, 2, 2, 1],
                      strides=[1, 2, 2, 2, 1],
                      padding='SAME')
        get_shape = net.get_shape().as_list()
        print('{} : {}'.format(end_point, get_shape))

        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        # Mixed 5b: Inception block
        end_point = 'Mixed_5b'
        with tf.variable_scope(end_point):
            with tf.variable_scope('Branch_0'):
                # 1x1x1 Conv, stride 1
                branch_0 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    256,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_1'):
                # 1x1x1 Conv, stride 1
                branch_1 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    160,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                # 3x3x3 Conv, stride 1
                branch_1 = conv_batchnorm_relu(
                    branch_1,
                    'Conv3d_0b_3x3',
                    320,
                    kernel_size=3,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_2'):
                # 1x1x1 Conv, stride 1
                branch_2 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    32,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                # 3x3x3 Conv, stride 1
                branch_2 = conv_batchnorm_relu(
                    branch_2,
                    'Conv3d_0b_3x3',
                    128,
                    kernel_size=3,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_3'):
                # 3x3x3 Max-pool, stride 1, 1, 1
                branch_3 = maxpool(net,
                                   'MaxPool3d_0a_3x3',
                                   ksize=[1, 3, 3, 3, 1],
                                   strides=[1, 1, 1, 1, 1],
                                   padding='SAME')
                # 1x1x1 Conv, stride 1
                branch_3 = conv_batchnorm_relu(
                    branch_3,
                    'Conv3d_0b_1x1',
                    128,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            # Concat branch_[0-3]
            net = tf.concat([branch_0, branch_1, branch_2, branch_3], 4)
        get_shape = net.get_shape().as_list()
        print('{} : {}'.format(end_point, get_shape))

        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        # Mixed 5c: Inception block
        end_point = 'Mixed_5c'
        with tf.variable_scope(end_point):
            with tf.variable_scope('Branch_0'):
                # 1x1x1 Conv, stride 1
                branch_0 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    384,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_1'):
                # 1x1x1 Conv, stride 1
                branch_1 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    192,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                # 3x3x3 Conv, stride 1
                branch_1 = conv_batchnorm_relu(
                    branch_1,
                    'Conv3d_0b_3x3',
                    384,
                    kernel_size=3,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_2'):
                # 1x1x1 Conv, stride 1
                branch_2 = conv_batchnorm_relu(
                    net,
                    'Conv3d_0a_1x1',
                    48,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)
                # 3x3x3 Conv, stride 1
                branch_2 = conv_batchnorm_relu(
                    branch_2,
                    'Conv3d_0b_3x3',
                    128,
                    kernel_size=3,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            with tf.variable_scope('Branch_3'):
                # 3x3x3 Max-pool, stride 1, 1, 1
                branch_3 = maxpool(net,
                                   'MaxPool3d_0a_3x3',
                                   ksize=[1, 3, 3, 3, 1],
                                   strides=[1, 1, 1, 1, 1],
                                   padding='SAME')
                # 1x1x1 Conv, stride 1
                branch_3 = conv_batchnorm_relu(
                    branch_3,
                    'Conv3d_0b_1x1',
                    128,
                    kernel_size=1,
                    stride=1,
                    is_training=is_training,
                    num_cores=num_cores,
                    use_batch_norm=use_batch_norm,
                    use_cross_replica_batch_norm=use_cross_replica_batch_norm)

            # Concat branch_[0-3]
            net = tf.concat([branch_0, branch_1, branch_2, branch_3], 4)
        get_shape = net.get_shape().as_list()
        print('{} : {}'.format(end_point, get_shape))

        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        # Logits
        end_point = 'Logits'
        with tf.variable_scope(end_point):
            # 2x7x7 Average-pool, stride 1, 1, 1
            net = avgpool(net,
                          ksize=[1, 2, 7, 7, 1],
                          strides=[1, 1, 1, 1, 1],
                          padding='VALID')
            get_shape = net.get_shape().as_list()
            print('{} / Average-pool3D: {}'.format(end_point, get_shape))

            # Dropout
            net = tf.nn.dropout(net, dropout_keep_prob)

            # 1x1x1 Conv, stride 1
            logits = conv_batchnorm_relu(
                net,
                'Conv3d_0c_1x1',
                num_classes,
                kernel_size=1,
                stride=1,
                activation=None,
                use_batch_norm=use_batch_norm,
                use_cross_replica_batch_norm=use_cross_replica_batch_norm,
                is_training=is_training,
                num_cores=num_cores)
            get_shape = logits.get_shape().as_list()
            print('{} / Conv3d_0c_1x1 : {}'.format(end_point, get_shape))

            if spatial_squeeze:
                # Remove the singleton spatial dimensions (axes 2 and 3)
                # left by the 2x7x7 VALID average pool
                logits = tf.squeeze(logits, [2, 3], name='SpatialSqueeze')
                get_shape = logits.get_shape().as_list()
                print('{} / Spatial Squeeze : {}'.format(end_point, get_shape))

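        # Average the per-frame logits over the temporal axis (axis 1).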
        averaged_logits = tf.reduce_mean(logits, axis=1)
        get_shape = averaged_logits.get_shape().as_list()
        print('{} / Averaged Logits : {}'.format(end_point, get_shape))

        end_points[end_point] = averaged_logits
        if final_endpoint == end_point: return averaged_logits, end_points

        # Predictions
        end_point = 'Predictions'
        predictions = tf.nn.softmax(averaged_logits)
        end_points[end_point] = predictions
        return predictions, end_points
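As with the U-Net, model closes over num_classes, final_endpoint, dropout_keep_prob, spatial_squeeze and the batch-norm flags, so it is presumably returned by an outer factory. A hypothetical invocation with the canonical I3D input size:

# Usage sketch (TF1 graph mode): 64 RGB frames at 224x224. After the stem
# conv and the four max-pools the feature map is [N, 8, 7, 7, C], which the
# 2x7x7 VALID average pool reduces to [N, 7, 1, 1, C] before the logits conv.
videos = tf.placeholder(tf.float32, [None, 64, 224, 224, 3])
# predictions, end_points = model(videos, is_training=False)
# end_points['Mixed_4f'] is a common choice of feature endpoint.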