Ejemplo n.º 1
0
 def to_tf(cls, ctx, node, **kwargs):
     """Rewrite this node's TFLite-style attributes into TF-style ones, in place.

     The scalar ``stride_h``/``stride_w`` and ``filter_height``/``filter_width``
     integer attributes are folded into 4-element ``strides`` and ``ksize``
     lists laid out as [1, H, W, 1]. The scalar attributes are then removed
     and ``data_format`` is set to "NHWC". The ``padding`` attribute is left
     untouched. Nothing is returned.

     Args:
         ctx: graph context; here it is only forwarded to
             ``separate_fused_activation_function``.
         node: the node being converted; mutated in place.
         **kwargs: unused.
     """
     # Presumably splits any fused activation off into its own node before the
     # attributes are rewritten -- see the helper to confirm.
     separate_fused_activation_function(ctx, node)
     # 'padding' already has a compatible form, so it is not rewritten.
     scalar_keys = ("stride_h", "stride_w", "filter_height", "filter_width")
     s_h, s_w, k_h, k_w = [node.get_attr_int(key) for key in scalar_keys]
     node.set_attr("strides", [1, s_h, s_w, 1])
     node.set_attr("ksize", [1, k_h, k_w, 1])
     # Drop the scalar attributes now that they live in the list-valued ones.
     for key in scalar_keys:
         del node.attr[key]
     node.set_attr("data_format", "NHWC")
Ejemplo n.º 2
0
 def to_tf(cls, ctx, node, **kwargs):
     """Rewrite a depthwise-conv node's TFLite-style attributes into TF form, in place.

     Scalar stride and dilation-factor attributes are folded into 4-element
     ``strides`` and ``dilations`` lists laid out as [1, H, W, 1], and the
     scalar attributes are removed. A ``Transpose`` (perm [1, 2, 3, 0]) is
     inserted on the node's second input, and ``data_format`` is set to
     "NHWC". ``padding`` and ``depth_multiplier`` are left untouched. Nothing
     is returned.

     Args:
         ctx: graph context; used to split off the fused activation and to
             insert the Transpose node.
         node: the node being converted; mutated in place.
         **kwargs: unused.
     """
     # Presumably splits any fused activation off into its own node before the
     # attributes are rewritten -- see the helper to confirm.
     separate_fused_activation_function(ctx, node)
     # 'padding' and 'depth_multiplier' are already compatible and stay as-is.
     st_h, st_w, dil_w, dil_h = (
         node.get_attr_int(key)
         for key in ("stride_h", "stride_w", "dilation_w_factor", "dilation_h_factor")
     )
     node.set_attr("strides", [1, st_h, st_w, 1])
     node.set_attr("dilations", [1, dil_h, dil_w, 1])
     for key in ("stride_h", "stride_w", "dilation_h_factor", "dilation_w_factor"):
         del node.attr[key]
     # Presumably the filter/weights tensor -- TODO confirm input ordering.
     kernel_input = node.input[1]
     # NOTE(review): perm [1, 2, 3, 0] moves the tensor's leading axis to the
     # end. If the TFLite filter is laid out as [1, H, W, C*M], this yields
     # [H, W, C*M, 1], which matches an [H, W, C, M] depthwise filter only when
     # depth_multiplier == 1 -- confirm how multiplier > 1 is handled downstream.
     perm_node = ctx.insert_new_node_on_input(node, "Transpose", kernel_input, name=None, perm=[1, 2, 3, 0])
     # Presumably stops the converter from re-processing the inserted node -- confirm.
     perm_node.skip_conversion = True
     node.set_attr("data_format", "NHWC")
Ejemplo n.º 3
0
 def to_tf(cls, ctx, node, **kwargs):
     """Rewrite a 3-D conv node's TFLite-style attributes into TF form, in place.

     The scalar ``stride_{d,h,w}`` and ``dilation_{d,h,w}_factor`` integer
     attributes are folded into 5-element ``strides`` and ``dilations`` lists
     laid out as [1, D, H, W, 1]. The scalar attributes are then removed and
     ``data_format`` is set to "NDHWC". ``padding`` is left untouched.
     Nothing is returned.

     Args:
         ctx: graph context; here it is only forwarded to
             ``separate_fused_activation_function``.
         node: the node being converted; mutated in place.
         **kwargs: unused.
     """
     # Presumably splits any fused activation off into its own node before the
     # attributes are rewritten -- see the helper to confirm.
     separate_fused_activation_function(ctx, node)
     # 'padding' is already compatible and stays as-is.
     scalars = {
         key: node.get_attr_int(key)
         for key in (
             "stride_h", "stride_w", "stride_d",
             "dilation_w_factor", "dilation_h_factor", "dilation_d_factor",
         )
     }
     node.set_attr("strides", [1, scalars["stride_d"], scalars["stride_h"], scalars["stride_w"], 1])
     node.set_attr(
         "dilations",
         [1, scalars["dilation_d_factor"], scalars["dilation_h_factor"], scalars["dilation_w_factor"], 1],
     )
     # Remove the scalar attributes now that the list-valued ones replace them.
     for key in ("stride_h", "stride_w", "stride_d",
                 "dilation_h_factor", "dilation_w_factor", "dilation_d_factor"):
         del node.attr[key]
     node.set_attr("data_format", "NDHWC")