Example #1
0
def _Pool3DShape(op):
  """Computes the output shape of a 3-D pooling op (Max/AvgPool3D).

  Args:
    op: The pooling operation. Its first input must have rank 5, and its
      "ksize" and "strides" attrs must each unpack into five values whose
      first and last entries equal 1.

  Returns:
    A one-element list holding the output `TensorShape`. The first and last
    input dimensions (treated as batch and channels) are carried over; the
    three middle dimensions come from `common_shapes.get_conv_output_size`.
  """
  in_shape = op.inputs[0].get_shape().with_rank(5)

  # The pooling window may only extend over the three middle dimensions.
  k_first, k_planes, k_rows, k_cols, k_last = op.get_attr("ksize")
  assert k_first == 1
  assert k_last == 1
  window = (k_planes, k_rows, k_cols)

  # Likewise, striding over the first or last dimension is not allowed.
  s_first, s_planes, s_rows, s_cols, s_last = op.get_attr("strides")
  assert s_first == 1
  assert s_last == 1
  steps = (s_planes, s_rows, s_cols)

  pad_mode = op.get_attr("padding")
  planes, rows, cols = common_shapes.get_conv_output_size(
      in_shape[1:4], window, steps, pad_mode)

  result_dims = [in_shape[0], planes, rows, cols, in_shape[4]]
  return [tensor_shape.TensorShape(result_dims)]
Example #2
0
def _Pool3DShape(op):
    """Derives the output shape for Max/AvgPool3D from the op's attrs.

    Args:
        op: The pooling operation. Its first input must have rank 5, and
            its "ksize" and "strides" attrs must each unpack into five
            values whose first and last entries equal 1.

    Returns:
        A one-element list holding the output `TensorShape`. The first and
        last input dimensions (treated as batch and channels) are carried
        over; the three middle dimensions are computed by
        `common_shapes.get_conv_output_size`.
    """
    shape_in = op.inputs[0].get_shape().with_rank(5)

    # Only the three middle dimensions may be covered by the pooling window.
    win_b, win_p, win_r, win_c, win_d = op.get_attr("ksize")
    assert win_b == 1
    assert win_d == 1

    # Only the three middle dimensions may be strided.
    step_b, step_p, step_r, step_c, step_d = op.get_attr("strides")
    assert step_b == 1
    assert step_d == 1

    n_batch = shape_in[0]
    n_channels = shape_in[4]

    pad_mode = op.get_attr("padding")
    spatial_in = shape_in[1:4]
    window = (win_p, win_r, win_c)
    steps = (step_p, step_r, step_c)
    planes, rows, cols = common_shapes.get_conv_output_size(
        spatial_in, window, steps, pad_mode)

    pooled = tensor_shape.TensorShape(
        [n_batch, planes, rows, cols, n_channels])
    return [pooled]
Example #3
0
def _Conv3DShape(op):
  """Computes the output shape of a Conv3D op.

  Args:
    op: The convolution operation. Both its first input (the data) and its
      second input (the filter) must have rank 5. The last data dimension
      must be compatible with filter dimension 3, and the "strides" attr
      must unpack into five values whose first and last entries equal 1.

  Returns:
    A one-element list holding the output `TensorShape`: the first data
    dimension, three sizes from `common_shapes.get_conv_output_size`, and
    the last filter dimension.
  """
  data_shape = op.inputs[0].get_shape().with_rank(5)
  kernel_shape = op.inputs[1].get_shape().with_rank(5)

  n_batch = data_shape[0]
  n_out = kernel_shape[4]
  # The data's channel count has to agree with the filter's input-channel
  # dimension.
  data_shape[4].assert_is_compatible_with(kernel_shape[3])

  # Striding is only permitted over the three middle dimensions.
  s_first, s_planes, s_rows, s_cols, s_last = op.get_attr("strides")
  assert s_first == 1
  assert s_last == 1
  steps = (s_planes, s_rows, s_cols)

  pad_mode = op.get_attr("padding")
  planes, rows, cols = common_shapes.get_conv_output_size(
      data_shape[1:4], kernel_shape[0:3], steps, pad_mode)

  result_dims = [n_batch, planes, rows, cols, n_out]
  return [tensor_shape.TensorShape(result_dims)]
Example #4
0
def _Conv3DShape(op):
    """Derives the output shape for Conv3D from its inputs and attrs.

    Args:
        op: The convolution operation. Both its first input (the data) and
            its second input (the filter) must have rank 5. The last data
            dimension must be compatible with filter dimension 3, and the
            "strides" attr must unpack into five values whose first and
            last entries equal 1.

    Returns:
        A one-element list holding the output `TensorShape`: the first
        data dimension, three sizes computed by
        `common_shapes.get_conv_output_size`, and the last filter
        dimension.
    """
    shape_in = op.inputs[0].get_shape().with_rank(5)
    shape_filter = op.inputs[1].get_shape().with_rank(5)

    n_batch = shape_in[0]
    n_out_channels = shape_filter[4]
    # Data channels and the filter's input-channel dimension must agree.
    shape_in[4].assert_is_compatible_with(shape_filter[3])

    # Only the three middle dimensions may be strided.
    step_b, step_p, step_r, step_c, step_d = op.get_attr("strides")
    assert step_b == 1
    assert step_d == 1

    pad_mode = op.get_attr("padding")
    spatial_in = shape_in[1:4]
    spatial_filter = shape_filter[0:3]
    steps = (step_p, step_r, step_c)
    planes, rows, cols = common_shapes.get_conv_output_size(
        spatial_in, spatial_filter, steps, pad_mode)

    convolved = tensor_shape.TensorShape(
        [n_batch, planes, rows, cols, n_out_channels])
    return [convolved]