Exemplo n.º 1
0
 def extract_group_norm(self,
                        input_name,
                        groups,
                        output_name,
                        scope_id,
                        data_format="NCHW",
                        axis=1,
                        layer_names=None):
     """Emit a Caffe 'GroupNorm' layer and propagate reference data through it.

     A LayerParameter named ``output_name`` (bottom ``input_name``, top
     ``output_name``) is created. The gamma/beta weights fetched via
     ``self.get_weights`` are attached, the group count is set, and the layer
     is appended to ``self.caffe_model``. If reference data exists for
     ``input_name`` in ``self.data_dict``, the output is computed with
     ``Operators.group_norm`` and stored under ``output_name``. Otherwise
     ``None`` is stored there.

     Args:
         input_name: Name of the input blob; also the key read from
             ``self.data_dict``.
         groups: Number of groups, passed to ``group_norm_param`` and
             ``Operators.group_norm``.
         output_name: Name of the layer and of its output blob.
         scope_id: Scope identifier forwarded to ``self.get_weights``.
         data_format: Only ``"NCHW"`` is supported.
         axis: Forwarded to ``self.preprocess_nchwc8_nchw_input``
             (presumably the channel axis -- confirm against that helper).
         layer_names: Names forwarded to ``self.get_weights``; defaults to
             ``["GroupNorm", "gamma", "beta"]``.

     Returns:
         ``output_name``.

     Raises:
         ValueError: If ``data_format`` is not ``"NCHW"``.
     """
     # Explicit check instead of `assert`, which is removed under `python -O`
     # and would let an unsupported layout through silently.
     if data_format != "NCHW":
         raise ValueError(
             "extract_group_norm only supports data_format='NCHW', got %r"
             % (data_format,))
     if layer_names is None:
         # Build a fresh list per call rather than sharing one mutable
         # default list across every call.
         layer_names = ["GroupNorm", "gamma", "beta"]
     gamma, beta = self.get_weights(scope_id, layer_names)
     layer = caffe_net.LayerParameter(name=output_name,
                                      type='GroupNorm',
                                      bottom=[input_name],
                                      top=[output_name])
     layer.add_data(gamma, beta)
     layer.group_norm_param(groups)
     self.caffe_model.add_layer(layer)
     if self.data_dict[input_name] is not None:
         # Convert the stored input to the layout Operators.group_norm
         # expects, run it, then undo the conversion on the result
         # (presumably NCHWc8 <-> NCHW, per the helper names -- confirm).
         input_data, input_shape, inv_transpose_dims = self.preprocess_nchwc8_nchw_input(
             input_name, axis)
         output_data = Operators.group_norm(input_data, groups, gamma, beta,
                                            output_name)
         self.data_dict[output_name] = self.postprocess_nchwc8_nchw_output(
             output_data, input_shape, inv_transpose_dims)
     else:
         # No reference data for the input, so none for the output either.
         self.data_dict[output_name] = None
     return output_name