def _Pool3DShape(op):
  """Shape function for Max/AvgPool3D.

  Computes the static output shape of a 3-D pooling op from the static shape
  of its first input and its "ksize", "strides" and "padding" attributes.

  Args:
    op: The pooling operation. Its first input is checked to have rank 5,
      laid out as [batch, planes, rows, cols, channels]. Its "ksize" and
      "strides" attributes must each hold 5 values in the same dimension
      order; the batch (first) and channel (last) entries must be 1.

  Returns:
    A single-element list holding the output `TensorShape`
    `[batch_size, out_planes, out_rows, out_cols, channels]`.

  Raises:
    ValueError: If "ksize" or "strides" does not have exactly 5 entries, or if
      a window size or stride other than 1 is requested over the batch or
      channel dimension.
    Errors raised by the rank check (`with_rank`) and by
    `common_shapes.get_conv_output_size` propagate unchanged.
  """
  input_shape = op.inputs[0].get_shape().with_rank(5)

  ksize = list(op.get_attr("ksize"))
  if len(ksize) != 5:
    raise ValueError("Pool3D ksize must have 5 elements, got %d: %s" %
                     (len(ksize), ksize))
  ksize_b, ksize_p, ksize_r, ksize_c, ksize_d = ksize
  # Explicit raises instead of `assert`: asserts are stripped under
  # `python -O`, which would let an unsupported configuration silently
  # produce a wrong output shape.
  if ksize_b != 1 or ksize_d != 1:
    raise ValueError(
        "Pool3D does not support pooling over the batch or channel "
        "dimensions; ksize[0] and ksize[4] must be 1, got %s" % (ksize,))

  strides = list(op.get_attr("strides"))
  if len(strides) != 5:
    raise ValueError("Pool3D strides must have 5 elements, got %d: %s" %
                     (len(strides), strides))
  stride_b, stride_p, stride_r, stride_c, stride_d = strides
  if stride_b != 1 or stride_d != 1:
    raise ValueError(
        "Pool3D does not support striding over the batch or channel "
        "dimensions; strides[0] and strides[4] must be 1, got %s" % (strides,))

  # Batch and channel sizes pass through pooling unchanged.
  batch_size = input_shape[0]
  channels = input_shape[4]

  padding = op.get_attr("padding")
  # Only the three spatial dimensions (planes, rows, cols) are resized.
  out_planes, out_rows, out_cols = common_shapes.get_conv_output_size(
      input_shape[1:4], (ksize_p, ksize_r, ksize_c),
      (stride_p, stride_r, stride_c), padding)
  return [
      tensor_shape.TensorShape(
          [batch_size, out_planes, out_rows, out_cols, channels])
  ]
def _Pool3DShape(op):
  """Shape function for Max/AvgPool3D.

  Computes the static output shape of a 3-D pooling op from the static shape
  of its first input and its "ksize", "strides" and "padding" attributes.

  Args:
    op: The pooling operation. Its first input is checked to have rank 5,
      laid out as [batch, planes, rows, cols, channels]. Its "ksize" and
      "strides" attributes must each hold 5 values in the same dimension
      order; the batch (first) and channel (last) entries must be 1.

  Returns:
    A single-element list holding the output `TensorShape`
    `[batch_size, out_planes, out_rows, out_cols, channels]`.

  Raises:
    ValueError: If "ksize" or "strides" does not have exactly 5 entries, or if
      a window size or stride other than 1 is requested over the batch or
      channel dimension.
    Errors raised by the rank check (`with_rank`) and by
    `common_shapes.get_conv_output_size` propagate unchanged.
  """
  input_shape = op.inputs[0].get_shape().with_rank(5)

  ksize = list(op.get_attr("ksize"))
  if len(ksize) != 5:
    raise ValueError("Pool3D ksize must have 5 elements, got %d: %s" %
                     (len(ksize), ksize))
  ksize_b, ksize_p, ksize_r, ksize_c, ksize_d = ksize
  # Explicit raises instead of `assert`: asserts are stripped under
  # `python -O`, which would let an unsupported configuration silently
  # produce a wrong output shape.
  if ksize_b != 1 or ksize_d != 1:
    raise ValueError(
        "Pool3D does not support pooling over the batch or channel "
        "dimensions; ksize[0] and ksize[4] must be 1, got %s" % (ksize,))

  strides = list(op.get_attr("strides"))
  if len(strides) != 5:
    raise ValueError("Pool3D strides must have 5 elements, got %d: %s" %
                     (len(strides), strides))
  stride_b, stride_p, stride_r, stride_c, stride_d = strides
  if stride_b != 1 or stride_d != 1:
    raise ValueError(
        "Pool3D does not support striding over the batch or channel "
        "dimensions; strides[0] and strides[4] must be 1, got %s" % (strides,))

  # Batch and channel sizes pass through pooling unchanged.
  batch_size = input_shape[0]
  channels = input_shape[4]

  padding = op.get_attr("padding")
  # Only the three spatial dimensions (planes, rows, cols) are resized.
  out_planes, out_rows, out_cols = common_shapes.get_conv_output_size(
      input_shape[1:4], (ksize_p, ksize_r, ksize_c),
      (stride_p, stride_r, stride_c), padding)
  return [
      tensor_shape.TensorShape(
          [batch_size, out_planes, out_rows, out_cols, channels])
  ]
def _Conv3DShape(op):
  """Shape function for Conv3D.

  Computes the static output shape of a 3-D convolution from the static
  shapes of its data and filter inputs and its "strides" and "padding"
  attributes.

  Args:
    op: The Conv3D operation. Input 0 (data) is checked to have rank 5, laid
      out as [batch, planes, rows, cols, in_channels]. Input 1 (filter) is
      checked to have rank 5, laid out as
      [planes, rows, cols, in_channels, out_channels]. Its "strides" attribute
      must hold 5 values in the data's dimension order; the batch (first) and
      channel (last) strides must be 1.

  Returns:
    A single-element list holding the output `TensorShape`
    `[batch_size, out_planes, out_rows, out_cols, out_channels]`.

  Raises:
    ValueError: If "strides" does not have exactly 5 entries, or if a stride
      other than 1 is requested over the batch or channel dimension.
    Errors raised by the rank checks (`with_rank`), the channel compatibility
    check (`assert_is_compatible_with`) and
    `common_shapes.get_conv_output_size` propagate unchanged.
  """
  input_shape = op.inputs[0].get_shape().with_rank(5)
  filter_shape = op.inputs[1].get_shape().with_rank(5)

  batch_size = input_shape[0]
  out_channels = filter_shape[4]
  # Check that the input number of channels is compatible between
  # input data and filter size.
  input_shape[4].assert_is_compatible_with(filter_shape[3])

  strides = list(op.get_attr("strides"))
  if len(strides) != 5:
    raise ValueError("Conv3D strides must have 5 elements, got %d: %s" %
                     (len(strides), strides))
  stride_b, stride_p, stride_r, stride_c, stride_d = strides
  # Explicit raise instead of `assert`: asserts are stripped under
  # `python -O`, which would let an unsupported stride silently produce a
  # wrong output shape.
  if stride_b != 1 or stride_d != 1:
    raise ValueError(
        "Conv3D does not support striding over the batch or channel "
        "dimensions; strides[0] and strides[4] must be 1, got %s" % (strides,))

  padding_type = op.get_attr("padding")
  # Only the three spatial dimensions (planes, rows, cols) are resized; the
  # filter's first three dimensions are the spatial window sizes.
  out_planes, out_rows, out_cols = common_shapes.get_conv_output_size(
      input_shape[1:4], filter_shape[0:3], (stride_p, stride_r, stride_c),
      padding_type)
  return [
      tensor_shape.TensorShape(
          [batch_size, out_planes, out_rows, out_cols, out_channels])
  ]
def _Conv3DShape(op):
  """Shape function for Conv3D.

  Computes the static output shape of a 3-D convolution from the static
  shapes of its data and filter inputs and its "strides" and "padding"
  attributes.

  Args:
    op: The Conv3D operation. Input 0 (data) is checked to have rank 5, laid
      out as [batch, planes, rows, cols, in_channels]. Input 1 (filter) is
      checked to have rank 5, laid out as
      [planes, rows, cols, in_channels, out_channels]. Its "strides" attribute
      must hold 5 values in the data's dimension order; the batch (first) and
      channel (last) strides must be 1.

  Returns:
    A single-element list holding the output `TensorShape`
    `[batch_size, out_planes, out_rows, out_cols, out_channels]`.

  Raises:
    ValueError: If "strides" does not have exactly 5 entries, or if a stride
      other than 1 is requested over the batch or channel dimension.
    Errors raised by the rank checks (`with_rank`), the channel compatibility
    check (`assert_is_compatible_with`) and
    `common_shapes.get_conv_output_size` propagate unchanged.
  """
  input_shape = op.inputs[0].get_shape().with_rank(5)
  filter_shape = op.inputs[1].get_shape().with_rank(5)

  batch_size = input_shape[0]
  out_channels = filter_shape[4]
  # Check that the input number of channels is compatible between
  # input data and filter size.
  input_shape[4].assert_is_compatible_with(filter_shape[3])

  strides = list(op.get_attr("strides"))
  if len(strides) != 5:
    raise ValueError("Conv3D strides must have 5 elements, got %d: %s" %
                     (len(strides), strides))
  stride_b, stride_p, stride_r, stride_c, stride_d = strides
  # Explicit raise instead of `assert`: asserts are stripped under
  # `python -O`, which would let an unsupported stride silently produce a
  # wrong output shape.
  if stride_b != 1 or stride_d != 1:
    raise ValueError(
        "Conv3D does not support striding over the batch or channel "
        "dimensions; strides[0] and strides[4] must be 1, got %s" % (strides,))

  padding_type = op.get_attr("padding")
  # Only the three spatial dimensions (planes, rows, cols) are resized; the
  # filter's first three dimensions are the spatial window sizes.
  out_planes, out_rows, out_cols = common_shapes.get_conv_output_size(
      input_shape[1:4], filter_shape[0:3], (stride_p, stride_r, stride_c),
      padding_type)
  return [
      tensor_shape.TensorShape(
          [batch_size, out_planes, out_rows, out_cols, out_channels])
  ]