import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

# Ray SGD v2 utilities; the exact import path may vary by Ray version.
from ray.util.sgd import v2 as sgd


def train_func(config):
    batch_size = config.get("batch_size", 32)
    hidden_size = config.get("hidden_size", 1)
    lr = config.get("lr", 1e-2)
    epochs = config.get("epochs", 3)

    # Each worker receives its own shard of the dataset pipelines.
    train_dataset_pipeline_shard = sgd.get_dataset_shard("train")
    validation_dataset_pipeline_shard = sgd.get_dataset_shard("validation")

    device = torch.device(f"cuda:{sgd.local_rank()}"
                          if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(device)

    model = nn.Linear(1, hidden_size)
    model = model.to(device)
    model = DistributedDataParallel(
        model,
        device_ids=[sgd.local_rank()] if torch.cuda.is_available() else None)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    results = []
    # Each epoch consumes the next window of the dataset pipeline.
    train_dataset_iterator = train_dataset_pipeline_shard.iter_datasets()
    validation_dataset_iterator = \
        validation_dataset_pipeline_shard.iter_datasets()

    for _ in range(epochs):
        train_dataset = next(train_dataset_iterator)
        validation_dataset = next(validation_dataset_iterator)

        train_torch_dataset = train_dataset.to_torch(
            label_column="y",
            feature_columns=["x"],
            label_column_dtype=torch.float,
            feature_column_dtypes=[torch.float],
            batch_size=batch_size,
        )
        validation_torch_dataset = validation_dataset.to_torch(
            label_column="y",
            feature_columns=["x"],
            label_column_dtype=torch.float,
            feature_column_dtypes=[torch.float],
            batch_size=batch_size)

        train(train_torch_dataset, model, loss_fn, optimizer, device)
        result = validate(validation_torch_dataset, model, loss_fn, device)
        sgd.report(**result)
        results.append(result)

    return results
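
# `train_func` calls per-epoch `train` and `validate` helpers that are not
# defined in this section (note these take arguments and are distinct from
# the zero-argument test `train()` further below). A minimal sketch of what
# they might look like, assuming standard PyTorch train/eval loops and that
# `validate` returns a dict that `sgd.report(**result)` can consume; the
# "loss" key is illustrative, not taken from the source:
def train(dataloader, model, loss_fn, optimizer, device):
    model.train()
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        loss = loss_fn(model(X), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def validate(dataloader, model, loss_fn, device):
    model.eval()
    total_loss, num_batches = 0.0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            total_loss += loss_fn(model(X), y).item()
            num_batches += 1
    # Average validation loss over the epoch's batches.
    return {"loss": total_loss / max(num_batches, 1)}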
def train_actor_failure():
    # Simulates a hard worker failure for fault-tolerance tests: sys.exit()
    # kills the training actor, so the return statement is never reached.
    import sys
    sys.exit(0)
    return sgd.local_rank()
def train():
    return sgd.local_rank()
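
# A minimal sketch of launching one of these functions on two workers,
# assuming the Ray SGD v2 `Trainer` API (backend/num_workers arguments and
# start/run/shutdown methods); treat the import path and signatures as
# assumptions rather than a definitive usage:
from ray.util.sgd.v2 import Trainer

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
ranks = trainer.run(train)  # expected: one local rank per worker, e.g. [0, 1]
trainer.shutdown()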