Ejemplo n.º 1
0
 def pull_embedding_param(self, name, ids):
     """Request rows ``ids`` of embedding table ``name`` from a pserver.

     The query is sent to ``self.pserver_stubs[0]`` only.

     Args:
         name: name of the embedding table to read.
         ids: ids of the rows to fetch.

     Returns:
         A ``Tensor`` filled from the pserver's response message.
     """
     # NOTE(review): always talks to the first pserver stub; ids are not
     # sharded across multiple pservers here — confirm this is intended.
     query = Tensor(name, None, ids)  # no value attached to the query
     request_pb = core_pb2.Tensor()
     serialize_to_pb(query, request_pb)
     stub = self.pserver_stubs[0]
     response_pb = stub.pull_embedding_param(request_pb)
     result = Tensor()
     deserialize_from_pb(response_pb, result)
     return result
Ejemplo n.º 2
0
 def push_embedding_param(self, request, _):
     """Create an embedding table from ``request`` and store it in the kvstore.

     Args:
         request: incoming message; its ``name``, ``dim`` and ``initializer``
             fields are used to build the ``EmbeddingTable``.
         _: unused second argument (presumably the RPC context — confirm).

     Returns:
         An ``empty_pb2.Empty`` message.
     """
     # NOTE(review): the decoded tensor is never used below — the table is
     # built straight from request fields. The decode is kept so any
     # validation/raising done by deserialize_from_pb still happens.
     decoded = Tensor()
     deserialize_from_pb(request, decoded)
     table_name = request.name
     # Second positional arg 0 is presumably the initial version — confirm
     # against the EmbeddingTable constructor.
     table = EmbeddingTable(table_name, 0, request.dim, request.initializer)
     self.kvstore.set_embedding_table(table_name, table)
     return empty_pb2.Empty()
Ejemplo n.º 3
0
 def pull_param(self, name):
     """Fetch the parameter called ``name`` from the first pserver stub.

     Args:
         name: parameter name placed in the request message.

     Returns:
         A ``Tensor`` filled from the pserver's response message.
     """
     request_pb = core_pb2.Tensor()
     request_pb.name = name
     stub = self.pserver_stubs[0]
     response_pb = stub.pull_param(request_pb)
     param = Tensor()
     deserialize_from_pb(response_pb, param)
     return param
Ejemplo n.º 4
0
    def get(self, ids):
        """Return the rows for ``ids``, creating any that do not exist yet.

        A missing id gets a fresh row produced by the Keras initializer named
        by ``self.initializer`` with ``shape=self.dim``. That row is cached in
        ``self.vectors`` so later lookups return the same array.

        Args:
            ids: iterable of row ids to look up.

        Returns:
            ``Tensor(self.name, None, None, self.version)`` when ``ids`` is
            empty. Otherwise ``Tensor(self.name, stacked_rows, ids,
            self.version)``, where the rows are stacked in the order of
            ``ids``.
        """
        rows = []
        for key in ids:
            if key in self.vectors:
                row = self.vectors[key]
            else:
                # NOTE: the initializer is looked up per missing id on purpose.
                # Newer Keras versions may return identical values from repeated
                # calls to one unseeded initializer instance, so hoisting this
                # lookup out of the loop could change results — verify first.
                init_fn = tf.keras.initializers.get(self.initializer)
                row = init_fn(shape=self.dim).numpy()
                self.vectors[key] = row
            rows.append(row)

        if not rows:
            return Tensor(self.name, None, None, self.version)
        return Tensor(self.name, np.stack(rows), ids, self.version)
Ejemplo n.º 5
0
 def push_embedding_grad(self, request, _):
     """Decode an embedding gradient from ``request`` and hand it to the grad queue.

     Args:
         request: incoming message decoded into a ``Tensor``.
         _: unused second argument (presumably the RPC context — confirm).

     Returns:
         An ``empty_pb2.Empty`` message.
     """
     grad = Tensor()
     deserialize_from_pb(request, grad)
     self.grad_queue.put_grad(grad)
     return empty_pb2.Empty()
Ejemplo n.º 6
0
 def push_param(self, request, _):
     """Decode the tensor in ``request`` and store it in the kvstore under ``request.name``.

     Args:
         request: incoming message decoded into a ``Tensor``.
         _: unused second argument (presumably the RPC context — confirm).

     Returns:
         An ``empty_pb2.Empty`` message.
     """
     param = Tensor()
     deserialize_from_pb(request, param)
     key = request.name
     self.kvstore.set_param(key, param)
     return empty_pb2.Empty()
Ejemplo n.º 7
0
        serialize_to_pb(grad, pb)
        self.pserver_stubs[0].push_grad(pb)


if __name__ == "__main__":
    # Manual smoke test for KVStoreClient: register an embedding table, pull
    # some of its rows, then push a gradient for a few ids.
    parser = argparse.ArgumentParser()
    # One or more pserver endpoints (required).
    parser.add_argument('-p', '--pserver_endpoints', nargs='+', required=True)
    # Optional integer worker id; argparse leaves it as None when omitted.
    parser.add_argument("-i", "--worker_id", type=int)
    args = parser.parse_args()

    worker = KVStoreClient(args.pserver_endpoints, args.worker_id)
    print("Starting worker. Connecting to pserver %s." %
          " ".join(args.pserver_endpoints))

    # Register embedding table "tom". The 3-element value and the "uniform"
    # initializer presumably tell the pserver the row width and how to create
    # unseen rows — TODO confirm against the server-side handler.
    init_param = Tensor(name="tom",
                        value=np.ones(shape=(3, )),
                        initializer="uniform")
    worker.push_embedding_param(init_param)

    # Pull ids 0..5 of "tom" and print the returned values and indices.
    ids = [0, 1, 2, 3, 4, 5]
    res = worker.pull_embedding_param("tom", ids)
    print(res.value)
    print(res.indices)

    # NOTE(review): fixed 2-second pause; its purpose is not evident here
    # (possibly to let the pserver settle) — confirm or remove.
    time.sleep(2)

    # Gradient of all 2.0s, shape (3, 3), for ids [0, 3, 3]. Presumably one
    # row per id; id 3 is repeated, likely to exercise duplicate-id handling.
    ids = [0, 3, 3]
    value = np.full((3, 3), 2.0)
    grad = Tensor("tom", value, ids)

    worker.push_embedding_grad(grad)