def calc_broadcast_shape_from_param(params):
    """
    Derive the common broadcast shape of the entries in `params`.

    The values of `params` are visited in the dict's iteration order, and the
    first rule that matches an entry decides what happens:

    - a Bijector or Distribution instance: the result is delegated to
      ``params['distribution'].broadcast_shape`` and returned immediately;
    - a ``str`` or an instance of ``type(params['dtype'])``: the entry is
      skipped;
    - ``None``: ``None`` is returned immediately;
    - a ``Parameter``: the shape of its ``default_input`` is folded in;
    - anything else: it is converted with ``cast_to_tensor(..., mstype.float32)``
      and the shape of the result is folded in.

    Args:
        params (dict): parameters used to initialize distribution. The keys
            'dtype' and 'name' (and 'distribution', on the delegation path)
            are looked up only when the corresponding rule is reached.

    Returns:
        tuple, the accumulated broadcast shape; or the `broadcast_shape`
        attribute of ``params['distribution']``; or None if an entry is None.
    """
    merged_shape = []
    for entry in params.values():
        # Delegation path: the wrapped distribution already knows its shape.
        if isinstance(entry, (msp.bijector.Bijector, msp.distribution.Distribution)):
            return params['distribution'].broadcast_shape

        # Evaluated per entry (not hoisted) so that a missing 'dtype' key is
        # only an error when this check is actually reached.
        shapeless_types = (str, type(params['dtype']))
        if isinstance(entry, shapeless_types):
            continue

        if entry is None:
            return None

        shaped = (
            entry.default_input
            if isinstance(entry, Parameter)
            else cast_to_tensor(entry, mstype.float32)
        )
        merged_shape = utils.get_broadcast_shape(
            merged_shape, list(shaped.shape), params['name'])

    return tuple(merged_shape)
def infer_shape(self, x_shape, y_shape):
    """
    Infer the output shape from the shapes of the two inputs.

    Args:
        x_shape: shape of the first input, forwarded unchanged.
        y_shape: shape of the second input, forwarded unchanged.

    Returns:
        The result of ``get_broadcast_shape(x_shape, y_shape)``.
    """
    # `self` is not consulted; the result depends only on the two shapes.
    output_shape = get_broadcast_shape(x_shape, y_shape)
    return output_shape