def to_air_checkpoint(self) -> Optional[Checkpoint]:
    from ray.tune.trainable.util import TrainableUtil

    checkpoint_data = self.dir_or_data

    if not checkpoint_data:
        return None

    if isinstance(checkpoint_data, ray.ObjectRef):
        checkpoint_data = ray.get(checkpoint_data)

    if isinstance(checkpoint_data, str):
        checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_data)
        checkpoint = Checkpoint.from_directory(checkpoint_dir)
    elif isinstance(checkpoint_data, bytes):
        with tempfile.TemporaryDirectory() as tmpdir:
            TrainableUtil.create_from_pickle(checkpoint_data, tmpdir)
            # Double wrap in checkpoint so we hold the data in memory and
            # can remove the temp directory
            checkpoint = Checkpoint.from_dict(
                Checkpoint.from_directory(tmpdir).to_dict()
            )
    elif isinstance(checkpoint_data, dict):
        checkpoint = Checkpoint.from_dict(checkpoint_data)
    else:
        raise RuntimeError(f"Unknown checkpoint data type: {type(checkpoint_data)}")

    return checkpoint
def train_func(config):
    itr = 0
    ckpt = session.get_checkpoint()
    if ckpt is not None:
        ckpt = ckpt.to_dict()
        itr = ckpt["iter"] + 1

    for i in range(itr, config["max_iter"]):
        session.report(
            dict(test=i, training_iteration=i),
            checkpoint=Checkpoint.from_dict(dict(iter=i)),
        )
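
# Minimal sketch (not part of the original snippet) of running the train_func
# above as a Tune function trainable. `session.get_checkpoint()` is only
# populated when the trial is restored (e.g. after a failure or via
# `Tuner.restore`); on a fresh run it returns None and the loop starts at 0.
# The "max_iter" value below is illustrative.
from ray import tune

tuner = tune.Tuner(train_func, param_space={"max_iter": 5})
results = tuner.fit()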
def train_func():
    ckpt = session.get_checkpoint()
    restored = bool(ckpt)  # Does a previous checkpoint exist?
    itr = 0
    if ckpt:
        ckpt = ckpt.to_dict()
        itr = ckpt["iter"] + 1

    for i in range(itr, 4):
        if i == 2 and not restored:
            raise Exception("try to fail me")
        session.report(
            dict(test=i, training_iteration=i),
            checkpoint=Checkpoint.from_dict(dict(iter=i)),
        )
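
# Minimal sketch (not part of the original snippet) of driving the failing
# train_func above with automatic retries, so the injected exception at i == 2
# is recovered from the latest checkpoint on the second attempt. Assumes Ray
# AIR's DataParallelTrainer and FailureConfig; max_failures=2 is illustrative.
from ray.air.config import FailureConfig, RunConfig, ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer

trainer = DataParallelTrainer(
    train_loop_per_worker=train_func,
    scaling_config=ScalingConfig(num_workers=1),
    run_config=RunConfig(failure_config=FailureConfig(max_failures=2)),
)
result = trainer.fit()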
def test_resume_from_checkpoint():
    with pytest.raises(ValueError):
        DummyTrainer(resume_from_checkpoint="invalid")

    with pytest.raises(ValueError):
        DummyTrainer(resume_from_checkpoint=False)

    with pytest.raises(ValueError):
        DummyTrainer(resume_from_checkpoint=True)

    with pytest.raises(ValueError):
        DummyTrainer(resume_from_checkpoint={})

    # Succeed
    DummyTrainer(resume_from_checkpoint=None)

    # Succeed
    DummyTrainer(resume_from_checkpoint=Checkpoint.from_dict({"empty": ""}))
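
# Hypothetical sketch of the DummyTrainer used by the test above; the real
# fixture is defined elsewhere in the test suite. Assumes a minimal
# BaseTrainer subclass: the ValueError checks in the test come from argument
# validation at construction time, so the training_loop body never runs here.
from ray.air import session
from ray.train.trainer import BaseTrainer


class DummyTrainer(BaseTrainer):
    def training_loop(self) -> None:
        # `self.resume_from_checkpoint` holds the validated Checkpoint or None.
        session.report({"resumed": int(self.resume_from_checkpoint is not None)})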
def commit(self, path: Optional[Path] = None) -> None:
    """Commit checkpoint to disk, if needed.

    Args:
        path: Path to commit checkpoint to.
    """
    if self.storage_mode == CheckpointStorage.MEMORY:
        # Do not persist memory checkpoints
        return

    if not path:
        # If no path is given, skip
        return

    if not isinstance(self.dir_or_data, dict):
        # Only persist dictionaries
        return

    checkpoint = Checkpoint.from_dict(self.dir_or_data)
    self.dir_or_data = checkpoint.to_directory(str(path))
def to_air_checkpoint(self) -> Optional[Checkpoint]:
    from ray.tune.trainable.util import TrainableUtil

    checkpoint_data = self.dir_or_data

    if not checkpoint_data:
        return None

    if isinstance(checkpoint_data, ray.ObjectRef):
        checkpoint_data = ray.get(checkpoint_data)

    if isinstance(checkpoint_data, str):
        try:
            checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_data)
        except FileNotFoundError:
            if log_once("checkpoint_not_available"):
                logger.error(
                    f"The requested checkpoint is not available on this node, "
                    f"most likely because you are using Ray client or disabled "
                    f"checkpoint synchronization. To avoid this, enable checkpoint "
                    f"synchronization to cloud storage by specifying a "
                    f"`SyncConfig`. The checkpoint may be available on a different "
                    f"node - please check this location on worker nodes: "
                    f"{checkpoint_data}"
                )
            return None
        checkpoint = Checkpoint.from_directory(checkpoint_dir)
    elif isinstance(checkpoint_data, bytes):
        checkpoint = Checkpoint.from_bytes(checkpoint_data)
    elif isinstance(checkpoint_data, dict):
        checkpoint = Checkpoint.from_dict(checkpoint_data)
    else:
        raise RuntimeError(f"Unknown checkpoint data type: {type(checkpoint_data)}")

    return checkpoint
def function_trainable_dict(config):
    session.report(
        {"metric": 2}, checkpoint=Checkpoint.from_dict({"checkpoint_data": 3})
    )
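
# Illustrative sketch (not part of the original snippet): running the function
# trainable above with Tune and reading its dict checkpoint back from the
# single resulting trial.
from ray import tune

results = tune.Tuner(function_trainable_dict).fit()
assert results[0].checkpoint.to_dict()["checkpoint_data"] == 3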
def train_func():
    for i in range(9):
        session.report(dict(test=i))
    session.report(
        dict(test=i + 1), checkpoint=Checkpoint.from_dict(dict(hello="world"))
    )
def train_func(config):
    use_gpu = config["use_gpu"]
    num_epochs = config["num_epochs"]
    batch_size = config["batch_size"]
    num_layers = config["num_layers"]
    num_hidden = config["num_hidden"]
    dropout_every = config["dropout_every"]
    dropout_prob = config["dropout_prob"]
    num_features = config["num_features"]

    print("Defining model, loss, and optimizer...")

    # Setup device.
    device = torch.device(
        f"cuda:{session.get_local_rank()}"
        if use_gpu and torch.cuda.is_available()
        else "cpu"
    )
    print(f"Device: {device}")

    # Setup data.
    train_dataset_pipeline = session.get_dataset_shard("train")
    train_dataset_epoch_iterator = train_dataset_pipeline.iter_epochs()
    test_dataset = session.get_dataset_shard("test")
    test_torch_dataset = test_dataset.to_torch(
        label_column="label", batch_size=batch_size, drop_last=True
    )

    net = Net(
        n_layers=num_layers,
        n_features=num_features,
        num_hidden=num_hidden,
        dropout_every=dropout_every,
        drop_prob=dropout_prob,
    ).to(device)
    print(net.parameters)

    net = train.torch.prepare_model(net)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(net.parameters(), weight_decay=0.0001)

    print("Starting training...")
    for epoch in range(num_epochs):
        train_dataset = next(train_dataset_epoch_iterator)
        train_torch_dataset = train_dataset.to_torch(
            label_column="label", batch_size=batch_size
        )

        train_running_loss, train_num_correct, train_num_total = train_epoch(
            train_torch_dataset, net, device, criterion, optimizer
        )
        train_acc = train_num_correct / train_num_total
        print(
            f"epoch [{epoch + 1}]: training accuracy: "
            f"{train_num_correct} / {train_num_total} = {train_acc:.4f}"
        )

        test_running_loss, test_num_correct, test_num_total = test_epoch(
            test_torch_dataset, net, device, criterion
        )
        test_acc = test_num_correct / test_num_total
        print(
            f"epoch [{epoch + 1}]: testing accuracy: "
            f"{test_num_correct} / {test_num_total} = {test_acc:.4f}"
        )

        # Checkpoint model.
        module = net.module if isinstance(net, DistributedDataParallel) else net
        checkpoint = Checkpoint.from_dict(dict(model=module.state_dict()))

        # Record and log stats.
        print(f"session report on {session.get_world_rank()}")
        session.report(
            dict(
                train_acc=train_acc,
                train_loss=train_running_loss,
                test_acc=test_acc,
                test_loss=test_running_loss,
            ),
            checkpoint=checkpoint,
        )
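
# Minimal follow-up sketch (not from the original script): restoring the model
# weights saved by the loop above from an AIR dict checkpoint, e.g.
# `result.checkpoint` after training. The Net hyperparameters below are
# illustrative placeholders and must match the ones used during training.
def load_model_from_checkpoint(checkpoint: Checkpoint) -> Net:
    state_dict = checkpoint.to_dict()["model"]
    model = Net(
        n_layers=1, n_features=8, num_hidden=16, dropout_every=2, drop_prob=0.1
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model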