def train(config, checkpoint=False):
    for i in range(10):
        tune.report(test=i)
    checkpoint_dir = tune.make_checkpoint_dir(step=10)
    checkpoint_path = os.path.join(checkpoint_dir, "hello")
    with open(checkpoint_path, "w") as f:
        f.write("hello")
    tune.save_checkpoint(checkpoint_path)
def train(config, checkpoint=None):
    for step in range(10):
        if step % 3 == 0:
            checkpoint_dir = tune.make_checkpoint_dir(step=step)
            path = os.path.join(checkpoint_dir, "checkpoint")
            with open(path, "w") as f:
                f.write(json.dumps({"step": step}))
            tune.save_checkpoint(path)
        tune.report(test=step)
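# Assumed usage sketch (not part of the snippets above): a function trainable
# like ``train`` can be passed straight to ``tune.run``. The empty ``config``
# and ``num_samples`` value here are illustrative.
from ray import tune

analysis = tune.run(train, config={}, num_samples=1)
print(analysis.trials[0].last_result["test"])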
def train(config, checkpoint=False):
    itr = 0
    if checkpoint:
        with open(checkpoint, "r") as f:
            itr = int(f.read()) + 1
    for i in range(itr, config["max_iter"]):
        checkpoint_dir = tune.make_checkpoint_dir(step=i)
        checkpoint_path = os.path.join(checkpoint_dir, "goodbye")
        with open(checkpoint_path, "w") as f:
            f.write(str(i))
        tune.save_checkpoint(checkpoint_path)
        tune.report(test=i, training_iteration=i)
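# Assumed usage sketch: run the trainable above, then resume the same
# experiment from the checkpoints written via ``tune.save_checkpoint``. The
# experiment ``name`` and the ``resume=True`` pattern are illustrative, not
# part of the original snippet.
from ray import tune

tune.run(train, name="restore_demo", config={"max_iter": 10})
# A later invocation with resume=True continues from the saved state.
tune.run(train, name="restore_demo", config={"max_iter": 10}, resume=True)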
def train(config, checkpoint=None):
    restored = bool(checkpoint)
    itr = 0
    if checkpoint:
        with open(checkpoint, "r") as f:
            itr = int(f.read()) + 1
    for i in range(itr, 10):
        if i == 5 and not restored:
            raise Exception("try to fail me")
        checkpoint_dir = tune.make_checkpoint_dir()
        checkpoint_path = os.path.join(checkpoint_dir, "goodbye")
        with open(checkpoint_path, "w") as f:
            f.write(str(i))
        tune.save_checkpoint(checkpoint_path)
        tune.report(test=i, training_iteration=i)
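# Assumed usage sketch: with ``max_failures`` set, Tune restarts the trial from
# its most recent checkpoint after the deliberate failure at i == 5, and the
# run finishes on the retry. The argument values are illustrative.
from ray import tune

tune.run(train, config={}, max_failures=1)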
def train(config, checkpoint=None):
    step = 0
    if checkpoint:
        with open(checkpoint) as f:
            step = json.loads(f.read())["timestep"]

    for timestep in range(step, 100):
        v = np.tanh(float(timestep) / config.get("width", 1))
        v *= config.get("height", 1)

        if timestep % 3 == 0:
            checkpoint_dir = tune.make_checkpoint_dir(step=timestep)
            path = os.path.join(checkpoint_dir, "checkpoint")
            with open(path, "w") as f:
                f.write(json.dumps({"timestep": timestep}))
            tune.save_checkpoint(path)

        # Here we use `episode_reward_mean`, but you can also report other
        # objectives such as loss or accuracy.
        tune.report(episode_reward_mean=v)
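# Assumed usage sketch: a small random search over the ``width`` and ``height``
# parameters read via ``config.get`` above. The sampling ranges and
# ``num_samples`` are illustrative values.
from ray import tune

tune.run(
    train,
    config={
        "width": tune.uniform(1, 20),
        "height": tune.uniform(1, 100),
    },
    num_samples=4)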
def on_validation_end(self, trainer, pl_module):
    path = tune.make_checkpoint_dir(trainer.global_step)
    trainer.save_checkpoint(os.path.join(path, "checkpoint"))
    tune.save_checkpoint(path)
def pbt_function(config, checkpoint=None):
    """Toy PBT problem for benchmarking adaptive learning rate.

    The goal is to optimize this trainable's accuracy. The accuracy increases
    fastest at the optimal lr, which is a function of the current accuracy.

    The optimal lr schedule for this problem is the triangle wave below. Note
    that many lr schedules for real models also follow this shape:

     best lr
      ^
      |    /\
      |   /  \
      |  /    \
      | /      \
      ------------> accuracy

    In this problem, using PBT with a population of 2-4 is sufficient to
    roughly approximate this lr schedule. Higher population sizes will yield
    faster convergence. Training will not converge without PBT.
    """
    lr = config["lr"]
    accuracy = 0.0  # end = 1000
    start = 0
    if checkpoint:
        with open(checkpoint) as f:
            state = json.loads(f.read())
            accuracy = state["acc"]
            start = state["step"]
    midpoint = 100  # lr starts decreasing after acc > midpoint
    q_tolerance = 3  # penalize exceeding lr by more than this multiple
    noise_level = 2  # add gaussian noise to the acc increase
    # triangle wave:
    #  - start at 0.001 @ t=0,
    #  - peak at 0.01 @ t=midpoint,
    #  - end at 0.001 @ t=midpoint * 2,
    for step in range(start, 100):
        if accuracy < midpoint:
            optimal_lr = 0.01 * accuracy / midpoint
        else:
            optimal_lr = 0.01 - 0.01 * (accuracy - midpoint) / midpoint
        optimal_lr = min(0.01, max(0.001, optimal_lr))

        # compute accuracy increase
        q_err = max(lr, optimal_lr) / min(lr, optimal_lr)
        if q_err < q_tolerance:
            accuracy += (1.0 / q_err) * random.random()
        elif lr > optimal_lr:
            accuracy -= (q_err - q_tolerance) * random.random()
        accuracy += noise_level * np.random.normal()
        accuracy = max(0, accuracy)

        if step % 3 == 0:
            checkpoint_dir = tune.make_checkpoint_dir(step=step)
            path = os.path.join(checkpoint_dir, "checkpoint")
            with open(path, "w") as f:
                # Save the current step so a restored trial resumes from it.
                f.write(json.dumps({"acc": accuracy, "step": step}))
            tune.save_checkpoint(path)

        tune.report(
            mean_accuracy=accuracy,
            cur_lr=lr,
            optimal_lr=optimal_lr,  # for debugging
            q_err=q_err,  # for debugging
            done=accuracy > midpoint * 2)
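# Assumed usage sketch for the toy problem above: a PopulationBasedTraining
# scheduler that perturbs ``lr``. The population size, perturbation interval,
# and lr range are illustrative values.
import random

from ray import tune
from ray.tune.schedulers import PopulationBasedTraining

pbt = PopulationBasedTraining(
    time_attr="training_iteration",
    metric="mean_accuracy",
    mode="max",
    perturbation_interval=4,
    hyperparam_mutations={
        "lr": lambda: random.uniform(0.0001, 0.02),
    })

tune.run(
    pbt_function,
    scheduler=pbt,
    num_samples=4,
    config={"lr": 0.0001})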
def train_cifar(config, checkpoint=None, data_dir=None):
    net = Net(config["l1"], config["l2"])

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    if checkpoint:
        print("loading checkpoint {}".format(checkpoint))
        model_state, optimizer_state = torch.load(checkpoint)
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    trainset, testset = load_data(data_dir)

    test_abs = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset, [test_abs, len(trainset) - test_abs])

    trainloader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=int(config["batch_size"]),
        shuffle=True,
        num_workers=8)
    valloader = torch.utils.data.DataLoader(
        val_subset,
        batch_size=int(config["batch_size"]),
        shuffle=True,
        num_workers=8)

    for epoch in range(10):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            epoch_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
                                                running_loss / epoch_steps))
                running_loss = 0.0

        # Validation loss
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        for i, data in enumerate(valloader, 0):
            with torch.no_grad():
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                loss = criterion(outputs, labels)
                val_loss += loss.cpu().numpy()
                val_steps += 1

        checkpoint_dir = tune.make_checkpoint_dir(epoch)
        path = os.path.join(checkpoint_dir, "checkpoint")
        torch.save((net.state_dict(), optimizer.state_dict()), path)
        tune.save_checkpoint(path)

        tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
    print("Finished Training")
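# Assumed launch sketch for ``train_cifar`` above: ``data_dir`` is bound with
# ``functools.partial`` and per-trial resources are requested. The search
# space, data path, and resource numbers are illustrative placeholders.
from functools import partial

from ray import tune

config = {
    "l1": tune.choice([64, 128, 256]),
    "l2": tune.choice([64, 128, 256]),
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([4, 8, 16]),
}

tune.run(
    partial(train_cifar, data_dir="/tmp/cifar_data"),
    config=config,
    num_samples=4,
    resources_per_trial={"cpu": 2, "gpu": 0})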
def __exit__(self, type, value, traceback):
    if torch.distributed.get_rank() == 0 and not self.disable:
        tune.save_checkpoint(self.file)