コード例 #1
0
def calc_broadcast_shape_from_param(params):
    """
    Derive the broadcast shape implied by a distribution's parameter values.

    Each value in ``params`` is turned into a tensor and its shape is folded
    into a running result with ``utils.get_broadcast_shape``. A ``Parameter``
    contributes its ``default_input``; any other value goes through
    ``cast_to_tensor(value, mstype.float32)``.

    Values are visited in dict order and handled as follows:
      * a Bijector or Distribution: stop and return
        ``params['distribution'].broadcast_shape`` (note: this key is read
        even when the matched value is a Bijector);
      * a ``str``, or a value of the same type as ``params['dtype']``:
        skipped;
      * ``None``: stop and return ``None``;
      * anything else: converted and folded into the result.

    Args:
        params (dict): parameters used to initialize distribution.
            ``params['dtype']`` is read for every value that is not a
            Bijector/Distribution, ``params['name']`` for every value that
            is folded in, and ``params['distribution']`` as soon as a
            Bijector/Distribution is encountered.

    Returns:
        tuple: the folded broadcast shape (``()`` when ``params`` is empty),
        or ``params['distribution'].broadcast_shape``, or ``None``, per the
        short-circuits above.
    """
    folded = []
    for item in params.values():
        # A nested Bijector/Distribution short-circuits. The shape is taken
        # from params['distribution'] (presumably the wrapped base
        # distribution; confirm against callers).
        if isinstance(item, (msp.bijector.Bijector, msp.distribution.Distribution)):
            return params['distribution'].broadcast_shape

        # Strings and values sharing the dtype's type contribute no shape.
        ignored = (str, type(params['dtype']))
        if isinstance(item, ignored):
            continue

        # An unset parameter means no broadcast shape can be computed.
        if item is None:
            return None

        as_tensor = (item.default_input if isinstance(item, Parameter)
                     else cast_to_tensor(item, mstype.float32))
        folded = utils.get_broadcast_shape(folded, list(as_tensor.shape), params['name'])
    return tuple(folded)
コード例 #2
0
 def infer_shape(self, x_shape, y_shape):
     """
     Return the shape obtained by broadcasting ``x_shape`` against ``y_shape``.

     Args:
         x_shape: shape of the first operand.
         y_shape: shape of the second operand.

     Returns:
         Whatever ``get_broadcast_shape(x_shape, y_shape)`` produces.
     """
     # NOTE(review): elsewhere a helper of the same name is called with an
     # extra name argument (used for error reporting); confirm that this
     # module's get_broadcast_shape really takes only two shapes.
     out_shape = get_broadcast_shape(x_shape, y_shape)
     return out_shape