def train_online(self, model, *args, **kwargs):
    """Train ``model`` online on the executor's workers and copy the result back.

    ``model`` is wrapped in a ``RayRemoteModel`` so each worker can rebuild its
    own copy via ``load()``. After ``self.executor.execute`` returns, the
    weights in the first entry of the returned results are loaded into
    ``model`` in place.

    Args:
        model: The local model whose weights are overwritten after training.
        *args: Forwarded positionally to each worker trainer's ``train_online``.
        **kwargs: Forwarded by keyword to each worker trainer's ``train_online``.

    Returns:
        The same ``model`` object, now holding the weights from the first result.
    """
    shared_model = RayRemoteModel(model)

    def _run_on_worker(trainer):
        # Each worker builds its own model instance from the shared remote state.
        return trainer.train_online(shared_model.load(), *args, **kwargs)

    worker_outputs = self.executor.execute(_run_on_worker)

    # NOTE(review): only the first worker's output is used; this presumably
    # relies on all workers ending with equivalent weights. Confirm with the executor.
    trained_weights = worker_outputs[0]
    load_weights_from_buffer(model, trained_weights)
    return model
def load(self):
    """Build a new model instance and populate it with the stored weights.

    A new instance of ``self.cls`` is constructed from ``self.args``. The
    weight buffer referenced by ``self.state`` is then fetched with
    ``ray.get`` and loaded into that instance.

    Returns:
        The newly constructed, weight-loaded instance.
    """
    rebuilt = self.cls(*self.args)
    # Fetch the serialized weights only after construction, preserving the original ordering.
    weight_buffer = ray.get(self.state)
    load_weights_from_buffer(rebuilt, weight_buffer)
    return rebuilt