Ejemplo n.º 1
0
def where(
    condition: remote_blob_util.BlobDef,
    x: Optional[remote_blob_util.BlobDef] = None,
    y: Optional[remote_blob_util.BlobDef] = None,
    name: Optional[str] = None,
) -> remote_blob_util.BlobDef:
    """Build a "where" op over ``condition``, ``x`` and ``y``.

    If neither ``x`` nor ``y`` is given, the call is forwarded to
    ``argwhere(condition, name=name)``.

    If both are given, the three inputs are used unchanged when ``x`` and ``y``
    already have ``condition``'s shape. Otherwise each input is passed through
    ``flow.broadcast_to_compatible_with`` against the other two. The results
    feed a "where" user op, and that op's first output blob is returned. The
    op presumably takes ``x`` where ``condition`` holds and ``y`` elsewhere.
    Its kernel is defined outside this file, so confirm against the op
    definition.

    Args:
        condition: Selector blob.
        x: First value blob; must be given together with ``y``.
        y: Second value blob; must be given together with ``x``.
        name: Op name. If None and both ``x`` and ``y`` are given, a unique
            name with the "Where_" prefix is generated.

    Returns:
        The blob produced by ``argwhere`` or by the "where" op.

    Raises:
        ValueError: If exactly one of ``x`` / ``y`` is None.
    """
    x_missing = x is None
    y_missing = y is None

    if x_missing and y_missing:
        return argwhere(condition, name=name)
    if x_missing != y_missing:
        raise ValueError("it is not supported when exactly one of x or y is non-None")

    op_name = id_util.UniqueStr("Where_") if name is None else name

    if x.shape == condition.shape and y.shape == condition.shape:
        operands = (condition, x, y)
    else:
        # Each input is broadcast against the other two; tuple elements are
        # evaluated left to right, so the broadcast ops are built in order.
        operands = (
            flow.broadcast_to_compatible_with(condition, [x, y]),
            flow.broadcast_to_compatible_with(x, [condition, y]),
            flow.broadcast_to_compatible_with(y, [condition, x]),
        )
    cond_blob, x_blob, y_blob = operands

    where_op = (
        flow.user_op_builder(op_name)
        .Op("where")
        .Input("condition", [cond_blob])
        .Input("x", [x_blob])
        .Input("y", [y_blob])
        .Output("out")
        .Build()
    )
    return where_op.InferAndTryRun().RemoteBlobList()[0]
Ejemplo n.º 2
0
    def broadcast_to_compatible_with_fn(
        x_def: oft.Numpy.Placeholder(x.shape, dtype=flow.float),
    ):
        """Broadcast a trainable copy of the input and minimise it with SGD.

        Steps:
        1. Create a zero-initialised trainable variable "x_var" shaped like
           ``x``.
        2. Add the placeholder input to it.
        3. Broadcast the sum against one non-trainable, randomly initialised
           variable per shape in ``compatible_shape``.
        4. Minimise the broadcast result with SGD (constant lr 1e-3,
           momentum 0).
        5. Register ``dx_watcher`` on the summed input via
           ``flow.watch_diff``; presumably it receives that input's gradient.

        ``x``, ``compatible_shape`` and ``dx_watcher`` come from the enclosing
        scope, which is not visible here.
        """
        trainable_x = flow.get_variable(
            "x_var",
            shape=x.shape,
            dtype=flow.float,
            initializer=flow.constant_initializer(0),
            trainable=True,
        )

        # One frozen random variable per target shape; broadcasting is done
        # against these.
        like_vars = []
        for idx, target_shape in enumerate(compatible_shape):
            like_vars.append(
                flow.get_variable(
                    "compatible_var_{}".format(idx),
                    shape=target_shape,
                    dtype=flow.float,
                    initializer=flow.random_normal_initializer(),
                    trainable=False,
                )
            )

        shifted_x = trainable_x + x_def
        out = flow.broadcast_to_compatible_with(shifted_x, like_vars)

        lr_schedule = flow.optimizer.PiecewiseConstantScheduler([], [1e-3])
        flow.optimizer.SGD(lr_schedule, momentum=0).minimize(out)

        flow.watch_diff(shifted_x, dx_watcher)
        return out
Ejemplo n.º 3
0
 def broadcast_to_compatible_with_fn(
         x_def: oft.ListNumpy.Placeholder(x_shape, dtype=flow.float),
         a_def: oft.ListNumpy.Placeholder(a_shape, dtype=flow.float),
         b_def: oft.ListNumpy.Placeholder(b_shape, dtype=flow.float),
 ):
     """Broadcast ``x_def`` against ``a_def`` and ``b_def``.

     ``a_def`` and ``b_def`` are each passed through ``flow.identity`` first.
     The result of ``flow.broadcast_to_compatible_with`` is returned.
     ``x_shape``, ``a_shape`` and ``b_shape`` come from the enclosing scope.
     """
     targets = [flow.identity(a_def), flow.identity(b_def)]
     return flow.broadcast_to_compatible_with(x_def, targets)
Ejemplo n.º 4
0
 def broadcast_to_compatible_with_fn(
     x_def: oft.ListNumpy.Placeholder(shape=x_shape, dtype=flow.float),
 ):
     """Broadcast ``x_def`` against frozen random variables.

     One non-trainable, randomly initialised variable is created per shape in
     ``compatible_shape``; ``x_def`` is then broadcast against all of them.
     ``x_shape`` and ``compatible_shape`` come from the enclosing scope.
     """
     targets = []
     for idx, target_shape in enumerate(compatible_shape):
         var_name = "compatible_var_{}".format(idx)
         targets.append(
             flow.get_variable(
                 var_name,
                 shape=target_shape,
                 dtype=flow.float,
                 initializer=flow.random_normal_initializer(),
                 trainable=False,
             )
         )
     return flow.broadcast_to_compatible_with(x_def, targets)
    def broadcast_to_compatible_with_fn(
        x_def: oft.Numpy.Placeholder(x.shape, dtype=flow.float),
    ):
        """Broadcast a trainable copy of the input and register it as a loss.

        Steps:
        1. Create a zero-initialised trainable variable "x_var" shaped like
           ``x``.
        2. Add the placeholder input to it.
        3. Broadcast the sum against one non-trainable, randomly initialised
           variable per shape in ``compatible_shape``.
        4. Pass the result to ``flow.losses.add_loss``.
        5. Register ``dx_watcher`` on the summed input via
           ``flow.watch_diff``; presumably it receives that input's gradient.

        ``x``, ``compatible_shape`` and ``dx_watcher`` come from the enclosing
        scope, which is not visible here.
        """
        trainable_x = flow.get_variable(
            "x_var",
            shape=x.shape,
            dtype=flow.float,
            initializer=flow.constant_initializer(0),
            trainable=True,
        )

        like_vars = []
        for idx, target_shape in enumerate(compatible_shape):
            like_vars.append(
                flow.get_variable(
                    "compatible_var_{}".format(idx),
                    shape=target_shape,
                    dtype=flow.float,
                    initializer=flow.random_normal_initializer(),
                    trainable=False,
                )
            )

        shifted_x = trainable_x + x_def
        out = flow.broadcast_to_compatible_with(shifted_x, like_vars)
        flow.losses.add_loss(out)

        flow.watch_diff(shifted_x, dx_watcher)
        return out