def tensor_scatter_nd_add_fn(
    params_def: oft.ListNumpy.Placeholder(params.shape, dtype=flow.float),
    indices_def: oft.ListNumpy.Placeholder(indices_static_shape, dtype=flow.int32),
    updates_def: oft.ListNumpy.Placeholder(updates_static_shape, dtype=flow.float),
):
    with flow.scope.placement("gpu", "0:0"):
        return flow.tensor_scatter_nd_add(params_def, indices_def, updates_def)
def do_tensor_scatter_nd_add(params_blob, indices_blob, updates_blob):
    with flow.scope.placement(device_type, "0:0"):
        # Zero-initialized trainable variables; adding the input blobs to them
        # routes gradients back to both params and updates.
        params_var = flow.get_variable(
            "params",
            shape=params_blob.shape,
            dtype=flow.float32,
            initializer=flow.constant_initializer(0),
        )
        updates_var = flow.get_variable(
            "updates",
            shape=updates_blob.shape,
            dtype=flow.float32,
            initializer=flow.constant_initializer(0),
        )
        params_var = flow.cast_to_current_logical_view(params_var)
        params_blob = flow.cast_to_current_logical_view(params_blob)
        updates_blob = flow.cast_to_current_logical_view(updates_blob)
        updates_var = flow.cast_to_current_logical_view(updates_var)
        params_var = params_var + params_blob
        updates_var = updates_var + updates_blob
        out = flow.tensor_scatter_nd_add(params_var, indices_blob, updates_var)
        # Register the output as a loss so the backward pass is built.
        flow.losses.add_loss(out)

    # Capture the gradients flowing into params and updates for checking.
    flow.watch_diff(params_var, params_grad_watcher)
    flow.watch_diff(updates_var, updates_grad_watcher)
    return out
def do_tensor_scatter_nd_add(params_blob, indices_blob, updates_blob):
    with flow.scope.placement(device_type, "0:0"):
        # Zero-initialized trainable variables; adding the input blobs to them
        # routes gradients back to both params and updates.
        params_var = flow.get_variable(
            "params",
            shape=params_blob.shape,
            dtype=flow.float32,
            initializer=flow.constant_initializer(0),
        )
        updates_var = flow.get_variable(
            "updates",
            shape=updates_blob.shape,
            dtype=flow.float32,
            initializer=flow.constant_initializer(0),
        )
        params_var = flow.cast_to_current_logical_view(params_var)
        params_blob = flow.cast_to_current_logical_view(params_blob)
        updates_blob = flow.cast_to_current_logical_view(updates_blob)
        updates_var = flow.cast_to_current_logical_view(updates_var)
        params_var = params_var + params_blob
        updates_var = updates_var + updates_blob
        out = flow.tensor_scatter_nd_add(params_var, indices_blob, updates_var)
        # Build the backward pass by minimizing the output with SGD,
        # instead of the flow.losses.add_loss registration used in the variant above.
        flow.optimizer.SGD(
            flow.optimizer.PiecewiseConstantScheduler([], [1e-3]), momentum=0
        ).minimize(out)

    # Capture the gradients flowing into params and updates for checking.
    flow.watch_diff(params_var, params_grad_watcher)
    flow.watch_diff(updates_var, updates_grad_watcher)
    return out
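# For reference: a minimal NumPy sketch of the scatter-nd-add semantics these
# jobs exercise. This is an assumption based on the usual tensor_scatter_nd_add
# contract (each row of `indices` addresses an element or slice of `params`, and
# the matching row of `updates` is added into it), not OneFlow's actual kernel.
import numpy as np


def tensor_scatter_nd_add_np(params, indices, updates):
    out = np.copy(params)
    for idx, upd in zip(indices, updates):
        # tuple(idx) selects the addressed element/slice; += accumulates, so
        # duplicate index rows add up instead of overwriting each other.
        out[tuple(idx)] += upd
    return out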