def create_model():
    """Build a ``Dlrm`` model from the global ``FLAGS`` and place it on the base device.

    The MLP size flags are converted to lists of ints in place on ``FLAGS``
    before the model configuration is assembled. When
    ``FLAGS.load_checkpoint_path`` is set, the checkpoint is loaded through a
    serial checkpoint loader created with ``device="cpu"`` and the model is
    moved to ``FLAGS.base_device`` again afterwards.

    Returns:
        The constructed model.
    """
    print("Creating model")

    # The size flags are normalised to ints and written back onto FLAGS.
    FLAGS.top_mlp_sizes = list(map(int, FLAGS.top_mlp_sizes))
    FLAGS.bottom_mlp_sizes = list(map(int, FLAGS.bottom_mlp_sizes))

    model_config = dict(
        top_mlp_sizes=FLAGS.top_mlp_sizes,
        bottom_mlp_sizes=FLAGS.bottom_mlp_sizes,
        embedding_type=FLAGS.embedding_type,
        embedding_dim=FLAGS.embedding_dim,
        interaction_op=FLAGS.interaction_op,
        categorical_feature_sizes=get_categorical_feature_sizes(FLAGS),
        num_numerical_features=FLAGS.num_numerical_features,
        hash_indices=FLAGS.hash_indices,
        use_cpp_mlp=FLAGS.optimized_mlp,
        fp16=FLAGS.amp,
        base_device=FLAGS.base_device,
    )

    model = Dlrm.from_dict(model_config)
    print(model)
    model.to(FLAGS.base_device)

    checkpoint_path = FLAGS.load_checkpoint_path
    if checkpoint_path is not None:
        num_embeddings = len(get_categorical_feature_sizes(FLAGS))
        loader = make_serial_checkpoint_loader(
            embedding_indices=range(num_embeddings),
            device="cpu",
        )
        loader.load_checkpoint(model, checkpoint_path)
        # The loader was created with device="cpu"; move the model back afterwards.
        model.to(FLAGS.base_device)

    return model
def create_synthetic_datasets(flags, device_mapping: Optional[Dict] = None):
    """Create the train and test ``SyntheticDataset`` instances.

    Both datasets are configured identically from ``flags`` except for the
    batch size: ``flags.batch_size`` for training and
    ``flags.test_batch_size`` for testing.

    Args:
        flags: object exposing ``synthetic_dataset_num_entries``,
            ``batch_size``, ``test_batch_size`` and ``num_numerical_features``;
            it is also passed to ``get_categorical_feature_sizes``.
        device_mapping: forwarded unchanged to ``SyntheticDataset``.

    Returns:
        A ``(dataset_train, dataset_test)`` tuple.
    """

    def _synthetic(batch_size):
        # The categorical sizes are queried once per dataset.
        return SyntheticDataset(
            num_entries=flags.synthetic_dataset_num_entries,
            batch_size=batch_size,
            numerical_features=flags.num_numerical_features,
            categorical_feature_sizes=get_categorical_feature_sizes(flags),
            device_mapping=device_mapping,
        )

    train_set = _synthetic(flags.batch_size)
    test_set = _synthetic(flags.test_batch_size)
    return train_set, test_set
def create_real_datasets(
    flags,
    path,
    dataset_class: type = SplitCriteoDataset,
    train_dataset_path="train",
    test_dataset_path="test",
    **kwargs
):
    """Create train and test datasets located under ``path``.

    Args:
        flags: object exposing ``batch_size``, ``test_batch_size`` and
            ``num_numerical_features``; also passed to
            ``get_categorical_feature_sizes``.
        path: base directory joined with the train/test sub-paths.
        dataset_class: callable used to construct both datasets.
        train_dataset_path: sub-path of the training data below ``path``.
        test_dataset_path: sub-path of the test data below ``path``.
        **kwargs: extra keyword arguments forwarded to both constructor calls.

    Returns:
        A ``(dataset_train, dataset_test)`` tuple.
    """
    train_location = os.path.join(path, train_dataset_path)
    test_location = os.path.join(path, test_dataset_path)
    feature_sizes = get_categorical_feature_sizes(flags)

    def _build(data_path, batch_size):
        # Every categorical feature is selected, by index.
        return dataset_class(
            data_path=data_path,
            batch_size=batch_size,
            numerical_features=flags.num_numerical_features,
            categorical_features=range(len(feature_sizes)),
            categorical_feature_sizes=feature_sizes,
            **kwargs
        )

    train_set = _build(train_location, flags.batch_size)
    test_set = _build(test_location, flags.test_batch_size)
    return train_set, test_set
def create_datasets(self) -> Tuple[Dataset, Dataset]:
    """Create the train and test ``SplitCriteoDataset`` instances.

    Data is read from the ``train`` and ``test`` sub-directories of
    ``self._flags.dataset``. Both datasets share the feature selection stored
    on the instance and the same prefetch depth; they differ only in data path
    and batch size.

    Returns:
        A ``(dataset_train, dataset_test)`` tuple.
    """
    flags = self._flags
    train_dir = os.path.join(flags.dataset, "train")
    test_dir = os.path.join(flags.dataset, "test")
    feature_sizes = get_categorical_feature_sizes(flags)

    # prefetching is currently unsupported if using the batch-wise shuffle
    if flags.shuffle_batch_order:
        prefetch_depth = 0
    else:
        prefetch_depth = 10

    shared_kwargs = dict(
        numerical_features=self._numerical_features,
        categorical_features=self._categorical_features,
        categorical_feature_sizes=feature_sizes,
        prefetch_depth=prefetch_depth,
    )

    train_set = SplitCriteoDataset(
        data_path=train_dir, batch_size=flags.batch_size, **shared_kwargs
    )
    test_set = SplitCriteoDataset(
        data_path=test_dir, batch_size=flags.test_batch_size, **shared_kwargs
    )
    return train_set, test_set
def create_datasets(self) -> Tuple[Dataset, Dataset]:
    """Create the train and test ``SplitCriteoDataset`` instances.

    Data is read from the ``train`` and ``test`` sub-directories of
    ``self._flags.dataset``. The two datasets differ only in data path and
    batch size (``batch_size`` for train, ``test_batch_size`` for test).

    Returns:
        A ``(dataset_train, dataset_test)`` tuple.
    """
    flags = self._flags
    split_dirs = [os.path.join(flags.dataset, split) for split in ("train", "test")]
    feature_sizes = get_categorical_feature_sizes(flags)
    batch_sizes = (flags.batch_size, flags.test_batch_size)

    # Train first, then test: one dataset per (directory, batch size) pair.
    train_set, test_set = (
        SplitCriteoDataset(
            data_path=split_dir,
            batch_size=batch_size,
            numerical_features=self._numerical_features,
            categorical_features=self._categorical_features,
            categorical_feature_sizes=feature_sizes,
        )
        for split_dir, batch_size in zip(split_dirs, batch_sizes)
    )
    return train_set, test_set
def create_real_datasets(flags, path, dataset_class: type = CriteoBinDataset):
    """Create train and test datasets from ``train_data.bin`` / ``test_data.bin``.

    Args:
        flags: object exposing ``batch_size``, ``test_batch_size``,
            ``dataset_subset`` and ``num_numerical_features``; also passed to
            ``get_categorical_feature_sizes``.
        path: directory containing the two binary files.
        dataset_class: callable used to construct both datasets.

    Returns:
        A ``(dataset_train, dataset_test)`` tuple.
    """
    train_file = os.path.join(path, "train_data.bin")
    test_file = os.path.join(path, "test_data.bin")
    num_categorical = len(get_categorical_feature_sizes(flags))

    feature_kwargs = dict(
        numerical_features=flags.num_numerical_features,
        categorical_features=num_categorical,
    )

    # Only the training dataset receives the ``subset`` argument.
    train_set = dataset_class(
        data_file=train_file,
        batch_size=flags.batch_size,
        subset=flags.dataset_subset,
        **feature_kwargs
    )
    test_set = dataset_class(
        data_file=test_file,
        batch_size=flags.test_batch_size,
        **feature_kwargs
    )
    return train_set, test_set
def create_dataset_factory(flags, device_mapping: Optional[dict] = None) -> DatasetFactory:
    """
    Select and construct the dataset factory matching `flags.dataset_type`.

    By default each dataset can be used in single GPU or distributed setting - please keep that in mind when adding
    new datasets. Distributed case requires selection of categorical features provided in `device_mapping`
    (see `DatasetFactory#create_collate_fn`).

    Supported values of `flags.dataset_type` are "binary", "split", "synthetic_gpu" and "synthetic_disk".

    :param flags: parsed flags; `dataset_type` selects the factory
    :param device_mapping: dict, information about model bottom mlp and embeddings devices assignment
    :return: the dataset factory for the requested dataset type
    :raises NotImplementedError: when `flags.dataset_type` is not one of the supported values
    """
    kind = flags.dataset_type

    if kind == "split":
        if not is_distributed():
            # Single-process run: numerical features and every categorical feature are used.
            all_categorical = range(len(get_categorical_feature_sizes(flags)))
            return SplitBinaryDatasetFactory(
                flags=flags,
                numerical_features=True,
                categorical_features=all_categorical)

        assert device_mapping is not None, "Distributed dataset requires information about model device mapping."
        rank = get_rank()
        # Numerical features are requested only when this rank equals the "bottom_mlp" entry.
        return SplitBinaryDatasetFactory(
            flags=flags,
            numerical_features=device_mapping["bottom_mlp"] == rank,
            categorical_features=device_mapping["embedding"][rank])
    elif kind == "binary":
        return BinaryDatasetFactory(flags, device_mapping)
    elif kind == "synthetic_gpu":
        return SyntheticGpuDatasetFactory(flags, device_mapping)
    elif kind == "synthetic_disk":
        return SyntheticDiskDatasetFactory(flags, device_mapping)
    else:
        raise NotImplementedError(f"unknown dataset type: {kind}")
def _average_speed(num_records, elapsed_s):
    """Return records per second, or NaN when ``elapsed_s`` is not positive."""
    if elapsed_s <= 0:
        return float("nan")
    return num_records / elapsed_s


def train(model, loss_fn, optimizer, data_loader_train, data_loader_test, scaled_lr):
    """Train and evaluate the model

    Iterates ``FLAGS.epochs`` times over ``data_loader_train``. The learning
    rate is updated through ``utils.lr_step`` before every step, metrics are
    printed every ``FLAGS.print_freq`` steps, and the model is evaluated with
    ``evaluate`` every ``test_freq`` global steps once ``FLAGS.test_after``
    epochs have elapsed. A checkpoint may be written (via
    ``maybe_save_checkpoint``) whenever the test AUC improves.

    Training ends when all epochs are done, when ``FLAGS.max_steps`` is
    exceeded, or when the test AUC reaches ``FLAGS.auc_threshold``. In the
    last case the function returns immediately, without logging the summary
    results to ``dllogger``.

    Args:
        model (dlrm):
        loss_fn (torch.nn.Module): Loss function
        optimizer (torch.nn.optim):
        data_loader_train (torch.utils.data.DataLoader):
        data_loader_test (torch.utils.data.DataLoader):
        scaled_lr: base learning rate passed to ``utils.lr_step`` as ``base_lr``
    """
    model.train()
    prefetching_enabled = is_data_prefetching_enabled()
    print_freq = FLAGS.print_freq
    steps_per_epoch = len(data_loader_train)

    checkpoint_writer = make_serial_checkpoint_writer(
        embedding_indices=range(len(get_categorical_feature_sizes(FLAGS))),
        config=FLAGS.flag_values_dict())

    # Without an explicit test frequency, evaluate roughly once per epoch.
    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch - 1

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter(
        'loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
    metric_logger.add_meter(
        'step_time', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    if prefetching_enabled:
        data_stream = torch.cuda.Stream()

    timer = utils.StepTimer()

    best_auc = 0
    best_epoch = 0
    # Pre-bind values that are read after the loops so that zero epochs or an
    # empty loader cannot leave them undefined.
    global_step = 0
    test_step_time = None
    reached_max_steps = False

    start_time = time()
    timer.click()

    for epoch in range(FLAGS.epochs):
        input_pipeline = iter(data_loader_train)

        if prefetching_enabled:
            input_pipeline = prefetcher(input_pipeline, data_stream)

        for step, batch in enumerate(input_pipeline):
            global_step = steps_per_epoch * epoch + step
            numerical_features, categorical_features, click = batch

            utils.lr_step(optimizer,
                          num_warmup_iter=FLAGS.warmup_steps,
                          current_step=global_step + 1,
                          base_lr=scaled_lr,
                          warmup_factor=FLAGS.warmup_factor,
                          decay_steps=FLAGS.decay_steps,
                          decay_start_step=FLAGS.decay_start_step)

            if FLAGS.max_steps and global_step > FLAGS.max_steps:
                print(
                    f"Reached max global steps of {FLAGS.max_steps}. Stopping."
                )
                # Remember the stop request so the epoch loop ends as well.
                reached_max_steps = True
                break

            if prefetching_enabled:
                torch.cuda.synchronize()

            output = model(numerical_features, categorical_features).squeeze().float()
            loss = loss_fn(output, click.squeeze())

            # Setting grad to None is faster than zero_grad()
            for param_group in optimizer.param_groups:
                for param in param_group['params']:
                    param.grad = None

            if FLAGS.amp:
                # The loss is scaled manually here; it is divided by the same
                # factor again before being reported below.
                loss *= FLAGS.loss_scale
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()

            if step % print_freq == 0 and step > 0:
                loss_value = loss.item()

                timer.click()

                if global_step < FLAGS.benchmark_warmup_steps:
                    # Step time is not recorded during benchmark warm-up.
                    metric_logger.update(loss=loss_value,
                                         lr=optimizer.param_groups[0]["lr"])
                else:
                    unscale_factor = FLAGS.loss_scale if FLAGS.amp else 1
                    # NOTE(review): lr is multiplied by the loss scale for
                    # reporting - presumably the optimizer lr was divided by it
                    # elsewhere; confirm against the optimizer setup.
                    metric_logger.update(
                        loss=loss_value / unscale_factor,
                        step_time=timer.measured / FLAGS.print_freq,
                        lr=optimizer.param_groups[0]["lr"] * unscale_factor)

                if global_step < FLAGS.benchmark_warmup_steps:
                    print(
                        f'Warming up, step [{global_step}/{FLAGS.benchmark_warmup_steps}]'
                    )
                    # Skips the ETA print and the evaluation check for this step.
                    continue

                eta_str = datetime.timedelta(
                    seconds=int(metric_logger.step_time.global_avg * (steps_per_epoch - step)))
                metric_logger.print(
                    header=
                    f"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}] eta: {eta_str}"
                )

            if (global_step % test_freq == 0 and global_step > 0
                    and global_step / steps_per_epoch >= FLAGS.test_after):
                test_loss, auc, test_step_time = evaluate(model, loss_fn, data_loader_test)
                print(
                    f"Epoch {epoch} step {step}. Test loss {test_loss:.5f}, auc {auc:.6f}"
                )

                if auc > best_auc:
                    best_auc = auc
                    best_epoch = epoch + ((step + 1) / steps_per_epoch)
                    maybe_save_checkpoint(checkpoint_writer, model, FLAGS.save_checkpoint_path)

                if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
                    stop_time = time()
                    run_time_s = int(stop_time - start_time)
                    # run_time_s truncates to 0 for sub-second runs; the helper
                    # avoids dividing by zero in that case.
                    average_speed = _average_speed(global_step * FLAGS.batch_size, run_time_s)
                    print(
                        f"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                        f"{global_step/steps_per_epoch:.2f} in {run_time_s}s. "
                        f"Average speed {average_speed:.1f} records/s."
                    )
                    return

        if reached_max_steps:
            break

    stop_time = time()
    run_time_s = int(stop_time - start_time)
    average_speed = _average_speed(global_step * FLAGS.batch_size, run_time_s)

    print(
        f"Finished training in {run_time_s}s. "
        f"Average speed {average_speed:.1f} records/s."
    )

    avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg

    results = {
        'best_auc': best_auc,
        'best_epoch': best_epoch,
        'average_train_throughput': avg_throughput
    }

    # test_step_time stays None when no evaluation ran during training.
    if test_step_time is not None:
        avg_test_throughput = FLAGS.test_batch_size / test_step_time
        results['average_test_throughput'] = avg_test_throughput

    dllogger.log(data=results, step=tuple())