def RemoteTrainer(estimator, metadata, last_checkpoint_state, run_id, dataset_idx):
    # Estimator parameters
    gradient_compression = estimator.getGradientCompression()
    input_shapes = estimator.getInputShapes()
    label_shapes = estimator.getLabelShapes()
    feature_columns = estimator.getFeatureCols()
    label_columns = estimator.getLabelCols()
    num_labels = len(label_columns)
    should_validate = estimator.getValidation()
    batch_size = estimator.getBatchSize()
    val_batch_size = estimator.getValBatchSize() if estimator.getValBatchSize() else batch_size
    epochs = estimator.getEpochs()
    train_steps_per_epoch = estimator.getTrainStepsPerEpoch()
    validation_steps_per_epoch = estimator.getValidationStepsPerEpoch()
    sample_weight_col = estimator.getSampleWeightCol()
    metric_fn_groups = estimator.getMetrics()
    user_shuffle_buffer_size = estimator.getShufflingBufferSize()
    user_verbose = estimator.getVerbose()
    train_minibatch_fn = estimator.getTrainMinibatchFn()
    train_minibatch = train_minibatch_fn if train_minibatch_fn else _train_minibatch_fn()
    loss_fns_pre_train = to_list(estimator.getLoss(), num_labels)
    loss_constructors = to_list(estimator.getLossConstructors(), num_labels)
    transformation_fn = estimator.getTransformationFn()
    transformation = transformation_fn if transformation_fn else None
    inmemory_cache_all = estimator.getInMemoryCacheAll()

    # If loss weights are not provided, weight all labels equally.
    loss_weights = estimator.getLossWeights()
    if not loss_weights:
        loss_weights = [float(1) / num_labels for _ in range(num_labels)]
    elif not isinstance(loss_weights, list) or len(loss_weights) != len(label_columns):
        raise ValueError('loss_weights needs to be a list with the same '
                         'length as the label_columns.')

    # Data reader parameters
    train_reader_worker_count = estimator.getTrainReaderNumWorker()
    val_reader_worker_count = estimator.getValReaderNumWorker()

    # Utility functions
    deserialize = deserialize_fn()
    get_optimizer_with_unscaled_lr = _get_optimizer_with_unscaled_lr_fn()
    calculate_shuffle_buffer_size = _calculate_shuffle_buffer_size_fn()
    construct_metric_value_holders = _construct_metric_value_holders_fn()
    metric_cls = _metric_cls()
    prepare_np_data = _prepare_np_data_fn()
    get_metric_avgs = _get_metric_avgs_fn()
    update_metrics = _update_metrics_fn(metric_fn_groups)
    write_metrics_summary = _write_metrics_summary_fn()
    calculate_loss = _calculate_loss_fn()

    # Storage
    store = estimator.getStore()
    remote_store = store.to_remote(run_id, dataset_idx)
    is_dbfs = isinstance(store, DBFSLocalStore)

    @contextlib.contextmanager
    def empty_batch_reader():
        yield None

    def train(serialized_model, optimizer_cls, model_opt_state_serialized,
              train_rows, val_rows, avg_row_size):
        from petastorm import TransformSpec, make_reader, make_batch_reader
        from petastorm.pytorch import BatchedDataLoader
        import torch
        import horovod.torch as hvd

        # Deserialize objects.
        model_opt_state = torch.load(model_opt_state_serialized)
        model = deserialize(serialized_model)

        if loss_fns_pre_train:
            loss_fns = loss_fns_pre_train
        if loss_constructors:
            local_vars = locals()
            loss_fns = [loss_constructor(**local_vars)
                        for loss_constructor in loss_constructors]

        # Horovod: initialize library.
        hvd.init()

        if not user_shuffle_buffer_size:
            shuffle_buffer_size = calculate_shuffle_buffer_size(
                hvd, avg_row_size, train_rows / hvd.size())
        else:
            shuffle_buffer_size = user_shuffle_buffer_size

        cuda_available = torch.cuda.is_available()
        if cuda_available:
            # Horovod: pin GPU to the local rank or to the GPU assigned by Spark.
            torch.cuda.set_device(
                _get_assigned_gpu_or_default(default=hvd.local_rank()))
            # Move the model to the GPU.
            model.cuda()

        # The optimizer object needs to be re-instantiated. Internally, it uses memory
        # addresses of objects as their identity, so it cannot simply be serialized and
        # deserialized: the deserialized optimizer stores parameter names keyed by their
        # old memory addresses, which no longer match the reconstructed objects, and
        # that creates problems. The learning rate is a required parameter for the SGD
        # optimizer; the placeholder value is overridden by load_state_dict below.
        optimizer = optimizer_cls(model.parameters(), lr=1)
        optimizer_state = model_opt_state['optimizer']

        if last_checkpoint_state is not None:
            model.load_state_dict(last_checkpoint_state['model'])
            optimizer.load_state_dict(last_checkpoint_state['optimizer'])
        else:
            # Scale the learning rate by the number of Horovod workers.
            for i in range(len(optimizer_state['param_groups'])):
                optimizer_state['param_groups'][i]['lr'] = \
                    optimizer_state['param_groups'][i]['lr'] * hvd.size()

            optimizer.load_state_dict(optimizer_state)

        # Horovod: broadcast parameters & optimizer state.
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)

        for group in optimizer.param_groups:
            for p in group['params']:
                if id(p) not in optimizer.state_dict()['state']:
                    p.grad = p.data.new(p.size()).zero_()
        optimizer.step()
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        dist_optimizer_args = dict(optimizer=optimizer,
                                   named_parameters=model.named_parameters())
        if gradient_compression:
            # Pass the compression arg only if it is specified by the user.
            dist_optimizer_args['compression'] = gradient_compression
        # Horovod: wrap optimizer with DistributedOptimizer.
        optimizer = hvd.DistributedOptimizer(**dist_optimizer_args)

        # `get_optimizer_with_unscaled_lr` takes the current optimizer and constructs a
        # new one with the same state, except the learning rate is scaled back down by
        # the number of Horovod workers. This matters for retraining: the user may
        # retrain the model with a different number of workers, so checkpoints must
        # store the raw learning rate to be rescaled for the new worker count.

        transform_spec = None
        if transformation:
            transform_spec = TransformSpec(transformation)

        schema_fields = feature_columns + label_columns
        if sample_weight_col:
            schema_fields.append(sample_weight_col)

        if train_steps_per_epoch is None:
            steps_per_epoch = int(math.floor(float(train_rows) / batch_size / hvd.size()))
        else:
            steps_per_epoch = train_steps_per_epoch

        with remote_store.get_local_output_dir() as run_output_dir:
            logs_dir = os.path.join(run_output_dir, remote_store.logs_subdir)
            log_writer = SummaryWriter(logs_dir) if hvd.rank() == 0 else None
            ckpt_file = os.path.join(run_output_dir, remote_store.checkpoint_filename)

            def save_checkpoint():
                model.cpu()
                optimizer_with_scaled_down_lr = \
                    get_optimizer_with_unscaled_lr(hvd, optimizer, optimizer_cls, model)
                state = {
                    'model': model.state_dict(),
                    'optimizer': optimizer_with_scaled_down_lr.state_dict(),
                }
                torch.save(state, ckpt_file)
                if cuda_available:
                    model.cuda()

            # In general, make_batch_reader is faster than make_reader for reading the
            # dataset. However, we found that make_reader performs data transformations
            # much faster than make_batch_reader with parallel worker processes.
            # Therefore, the default reader is make_batch_reader unless there are data
            # transformations.
            reader_factory_kwargs = dict()
            if transform_spec:
                reader_factory = make_reader
                reader_factory_kwargs['pyarrow_serialize'] = True
            else:
                reader_factory = make_batch_reader

            # Petastorm: read data from the store with the correct shard for this rank.
            # Setting num_epochs=None creates an infinite iterator, which allows ranks
            # to perform training and validation with an unequal number of samples.
            with reader_factory(remote_store.train_data_path,
                                num_epochs=1 if inmemory_cache_all else None,
                                cur_shard=hvd.rank(),
                                reader_pool_type='process',
                                workers_count=train_reader_worker_count,
                                shard_count=hvd.size(),
                                hdfs_driver=PETASTORM_HDFS_DRIVER,
                                schema_fields=schema_fields,
                                transform_spec=transform_spec,
                                **reader_factory_kwargs) as train_reader:
                with reader_factory(remote_store.val_data_path,
                                    num_epochs=1 if inmemory_cache_all else None,
                                    cur_shard=hvd.rank(),
                                    reader_pool_type='process',
                                    workers_count=val_reader_worker_count,
                                    shard_count=hvd.size(),
                                    hdfs_driver=PETASTORM_HDFS_DRIVER,
                                    schema_fields=schema_fields,
                                    transform_spec=transform_spec,
                                    **reader_factory_kwargs) \
                        if should_validate else empty_batch_reader() as val_reader:

                    train_loader = BatchedDataLoader(
                        train_reader,
                        num_epochs=epochs if inmemory_cache_all else None,
                        batch_size=batch_size,
                        shuffling_queue_capacity=shuffle_buffer_size,
                        inmemory_cache_all=inmemory_cache_all)
                    train_loader_iter = iter(train_loader)

                    def prepare_batch(row):
                        inputs = [
                            prepare_np_data(row[col].float(), col, metadata).reshape(shape)
                            for col, shape in zip(feature_columns, input_shapes)
                        ]
                        labels = [
                            prepare_np_data(row[col].float(), col, metadata)
                            for col in label_columns
                        ]

                        sample_weights = row.get(sample_weight_col, None)
                        if sample_weights is not None:
                            sample_weights = sample_weights.float()
                        if cuda_available:
                            inputs = [inp.cuda() for inp in inputs]
                            labels = [label.cuda() for label in labels]
                            if sample_weights is not None:
                                sample_weights = sample_weights.cuda()
                        return inputs, labels, sample_weights

                    def transform_outputs(outputs, labels):
                        if not isinstance(outputs, (tuple, list)):
                            outputs = [outputs]

                        # Reshape the labels to match the output shape of the model.
                        if hasattr(outputs[0], 'shape'):
                            if label_shapes:
                                labels = [label.reshape(label_shape)
                                          for label, label_shape in zip(labels, label_shapes)]
                            else:
                                # If the label_shapes parameter is not provided, reshape
                                # the label columns to match the model output shapes.
                                labels = [label.reshape(output.shape)
                                          if output.shape.numel() == label.shape.numel()
                                          else label
                                          for label, output in zip(labels, outputs)]
                        return outputs, labels

                    def aggregate_metrics(stage, epoch, loss, metric_value_groups):
                        all_metric_groups_values = get_metric_avgs(metric_value_groups)
                        if remote_store.saving_runs:
                            write_metrics_summary(stage, epoch, loss,
                                                  all_metric_groups_values, log_writer)
                        return {
                            loss.name: loss.avg.item(),
                            'all_metrics': all_metric_groups_values
                        }

                    def loss_fn(outputs, labels, sample_weights):
                        return calculate_loss(outputs, labels, loss_weights, loss_fns,
                                              sample_weights)

                    def print_metrics(batch_idx, loss, metric_value_groups, phase):
                        if user_verbose > 0 and hvd.rank() == 0 and \
                                batch_idx % METRIC_PRINT_FREQUENCY == 0:
                            print("{phase}\tepoch:\t{epoch}\tstep\t{batch_idx}:\t{metrics}"
                                  .format(phase=phase, epoch=epoch, batch_idx=batch_idx,
                                          metrics=aggregate_metrics(phase, epoch, loss,
                                                                    metric_value_groups)))

                    def _train(epoch):
                        model.train()
                        train_loss = metric_cls('loss', hvd)
                        metric_value_groups = construct_metric_value_holders(
                            metric_cls, metric_fn_groups, label_columns, hvd)

                        # Iterate over one epoch.
                        for batch_idx in range(steps_per_epoch):
                            row = next(train_loader_iter)
                            inputs, labels, sample_weights = prepare_batch(row)

                            outputs, loss = train_minibatch(model, optimizer,
                                                            transform_outputs, loss_fn,
                                                            inputs, labels, sample_weights)
                            update_metrics(metric_value_groups, outputs, labels)
                            train_loss.update(loss)
                            print_metrics(batch_idx, train_loss, metric_value_groups, 'train')

                        return aggregate_metrics('train', epoch, train_loss,
                                                 metric_value_groups)

                    if should_validate:
                        val_loader = BatchedDataLoader(
                            val_reader,
                            num_epochs=epochs if inmemory_cache_all else None,
                            batch_size=val_batch_size,
                            inmemory_cache_all=inmemory_cache_all)
                        val_loader_iter = iter(val_loader)

                        if validation_steps_per_epoch is None:
                            validation_steps = int(
                                math.ceil(float(val_rows) / val_batch_size / hvd.size()))
                        else:
                            validation_steps = validation_steps_per_epoch

                        def _validate(epoch):
                            model.eval()
                            val_loss = metric_cls('loss', hvd)
                            metric_value_groups = construct_metric_value_holders(
                                metric_cls, metric_fn_groups, label_columns, hvd)

                            # Iterate over one epoch.
                            for batch_idx in range(validation_steps):
                                row = next(val_loader_iter)
                                inputs, labels, sample_weights = prepare_batch(row)

                                outputs = model(*inputs)
                                outputs, labels = transform_outputs(outputs, labels)

                                loss = calculate_loss(outputs, labels, loss_weights,
                                                      loss_fns, sample_weights)
                                val_loss.update(loss)
                                update_metrics(metric_value_groups, outputs, labels)
                                print_metrics(batch_idx, val_loss, metric_value_groups, 'val')
                            return aggregate_metrics('val', epoch, val_loss,
                                                     metric_value_groups)

                    history = []
                    for epoch in range(epochs):
                        epoch_metrics = {
                            'epoch': epoch,
                            'train': _train(epoch),
                        }
                        if should_validate:
                            epoch_metrics['validation'] = _validate(epoch)
                        if user_verbose > 0:
                            pdt_dt = datetime.now(timezone.utc)
                            pdt_time_str = pdt_dt.strftime("%Y-%b-%d %H:%M:%S UTC")
                            print(pdt_time_str, epoch_metrics)
                        history.append(epoch_metrics)
                        if hvd.rank() == 0:
                            # Save the model after every epoch.
                            save_checkpoint()
                            if remote_store.saving_runs:
                                remote_store.sync(run_output_dir)

            if hvd.rank() == 0:
                best_checkpoint = torch.load(ckpt_file)
                serialized_checkpoint = io.BytesIO()
                torch.save(best_checkpoint, serialized_checkpoint)
                serialized_checkpoint.seek(0)
                return history, serialized_checkpoint

    return train
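

# Usage sketch (illustrative, not part of the original module): the `train`
# function returned by RemoteTrainer is a self-contained closure -- estimator
# parameters, store handles, and utility functions are all captured at
# construction time -- so it can be pickled and executed on each
# Horovod-on-Spark worker. The `serialized_model`, `opt_state`, row counts,
# and `num_workers` below are hypothetical placeholders:
#
#     import horovod.spark
#
#     trainer = RemoteTrainer(estimator, metadata,
#                             last_checkpoint_state=None,
#                             run_id='run_0', dataset_idx=0)
#     results = horovod.spark.run(trainer,
#                                 args=(serialized_model, torch.optim.SGD,
#                                       opt_state, train_rows, val_rows,
#                                       avg_row_size),
#                                 num_proc=num_workers)
#     # Only rank 0 returns (history, serialized_checkpoint);
#     # all other ranks return None.
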
def to_lightning_module(model, optimizer, loss_fns, loss_weights, feature_cols,
                        label_cols, sample_weights_col, validation):
    optimizer_cls = optimizer.__class__
    optimizer_state = optimizer.state_dict()

    loss_weights = loss_weights or [1.0 / len(label_cols)] * len(label_cols)
    loss_fns = to_list(loss_fns, len(label_cols))

    class _EstimatorLightningModule(LightningModule):
        def __init__(self):
            super().__init__()
            self._model = model

        def forward(self, **kwargs):
            return self._model(**kwargs)

        def configure_optimizers(self):
            # The optimizer object needs to be re-instantiated. Internally, it uses
            # memory addresses of objects as their identity, so it cannot simply be
            # serialized and deserialized: the deserialized optimizer stores parameter
            # names keyed by their old memory addresses, which no longer match the
            # reconstructed objects, and that creates problems. The learning rate is a
            # required parameter for the SGD optimizer; the placeholder value is
            # overridden by load_state_dict below.
            optimizer = optimizer_cls(self.parameters(), lr=1)
            optimizer.load_state_dict(optimizer_state)
            return optimizer

        def training_step(self, batch, batch_nb):
            loss = self._step(batch)
            tensorboard_logs = {'train_loss': loss}
            return {'loss': loss, 'log': tensorboard_logs}

        def _step(self, batch):
            inputs = {feature: batch[feature].float() for feature in feature_cols}
            labels = [batch[label].float() for label in label_cols]
            sample_weights = batch[sample_weights_col].float() if sample_weights_col else None
            outputs = self(**inputs)
            outputs, labels = self._transform_outputs(outputs, labels)
            return self._calculate_loss(outputs, labels, sample_weights)

        def _transform_outputs(self, outputs, labels):
            if not isinstance(outputs, (tuple, list)):
                outputs = [outputs]

            # Reshape the labels to match the output shape of the model.
            if hasattr(outputs[0], 'shape'):
                labels = [label.reshape(output.shape)
                          if output.shape.numel() == label.shape.numel() else label
                          for label, output in zip(labels, outputs)]
            return outputs, labels

        def _calculate_loss(self, outputs, labels, sample_weights=None):
            if sample_weights is not None:
                # With reduction='none', the loss function returns one value per
                # sample. We multiply each sample's loss by its weight and then take
                # the mean of the weight-adjusted losses over the batch. Note that
                # this is not a "weighted average", because the sample weights in a
                # batch do not necessarily sum to one. If we instead divided the
                # weighted sum by the sum of weights, two samples with identical
                # weights in different batches would contribute unequally to the
                # computed gradients.
                losses = []
                for output, label, loss_fn, loss_weight in \
                        zip(outputs, labels, loss_fns, loss_weights):
                    weight_adjusted_sample_losses = \
                        loss_fn(output, label, reduction='none').flatten() * sample_weights
                    output_loss = weight_adjusted_sample_losses.mean()
                    losses.append(output_loss * loss_weight)
            else:
                losses = [loss_fn(output, label) * loss_weight
                          for output, label, loss_fn, loss_weight in
                          zip(outputs, labels, loss_fns, loss_weights)]

            return sum(losses)

    lightning_module = _EstimatorLightningModule()

    if validation:
        def validation_step(batch, batch_nb):
            loss = lightning_module._step(batch)
            return {'val_loss': loss}

        lightning_module.validation_step = validation_step

        def validation_epoch_end(outputs):
            avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() \
                if len(outputs) > 0 else torch.tensor(float('inf'))
            tensorboard_logs = {'val_loss': avg_loss}
            return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

        lightning_module.validation_epoch_end = validation_epoch_end

    return lightning_module
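

# Usage sketch (illustrative, not part of the original module): wrapping a
# plain torch model and optimizer into a LightningModule. `Net`, the column
# names, and the Trainer settings below are hypothetical placeholders. Note
# that `loss_fns` should accept a `reduction` keyword (e.g. functional losses
# such as F.mse_loss), since the sample-weighted path calls them with
# reduction='none':
#
#     import torch.nn.functional as F
#
#     model = Net()
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#     module = to_lightning_module(model, optimizer,
#                                  loss_fns=F.mse_loss,
#                                  loss_weights=None,
#                                  feature_cols=['features'],
#                                  label_cols=['label'],
#                                  sample_weights_col=None,
#                                  validation=True)
#     # The module can now be passed to a pytorch_lightning Trainer's fit().
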