コード例 #1
0
    def train_online(self, model, *args, **kwargs):
        """Train ``model`` online through the executor and copy weights back.

        The model is wrapped in ``RayRemoteModel`` so the training callable can
        rebuild its own copy with ``load()``. The callable is handed to
        ``self.executor.execute``, and the first entry of the returned results
        is loaded into ``model`` in place.

        Args:
            model: Local model; its weights are overwritten with the first
                result returned by the executor.
            *args: Forwarded positionally to ``trainer.train_online``.
            **kwargs: Forwarded as keywords to ``trainer.train_online``.

        Returns:
            The same ``model`` object that was passed in.
        """
        shipped_model = RayRemoteModel(model)

        def _train_with(trainer):
            # Rebuild a model from the shipped state on every invocation.
            # NOTE(review): presumably execute() calls this once per worker — confirm.
            return trainer.train_online(shipped_model.load(), *args, **kwargs)

        per_worker_results = self.executor.execute(_train_with)

        # Only the first result is used; other entries are ignored.
        first_weights = per_worker_results[0]
        load_weights_from_buffer(model, first_weights)
        return model
コード例 #2
0
 def load(self):
     """Construct a fresh object and restore its weights from the object store.

     The object is built as ``self.cls(*self.args)``. The stored buffer is then
     fetched with ``ray.get(self.state)`` and applied via
     ``load_weights_from_buffer``.

     Returns:
         The newly constructed object with the stored weights loaded.
     """
     instance = self.cls(*self.args)
     weight_buffer = ray.get(self.state)
     load_weights_from_buffer(instance, weight_buffer)
     return instance