def validate_flags(cat_feature_count):
    if FLAGS.max_table_size is not None and not FLAGS.hash_indices:
        raise ValueError('Hash indices must be True when setting a max_table_size')

    if FLAGS.base_device == 'cpu':
        if FLAGS.embedding_type in ('joint_fused', 'joint_sparse'):
            print('WARNING: CUDA joint embeddings are not supported on CPU')
            FLAGS.embedding_type = 'joint'

        if FLAGS.amp:
            print('WARNING: Automatic mixed precision not supported on CPU')
            FLAGS.amp = False

        if FLAGS.optimized_mlp:
            print('WARNING: Optimized MLP is not supported on CPU')
            FLAGS.optimized_mlp = False

    if FLAGS.embedding_type == 'custom_cuda':
        if (not is_distributed()) and FLAGS.embedding_dim == 128 and cat_feature_count == 26:
            FLAGS.embedding_type = 'joint_fused'
        else:
            FLAGS.embedding_type = 'joint_sparse'

    if FLAGS.embedding_type == 'joint_fused' and FLAGS.embedding_dim != 128:
        print('WARNING: Joint fused can be used only with embedding_dim=128. '
              'Changed embedding type to joint_sparse.')
        FLAGS.embedding_type = 'joint_sparse'

    if FLAGS.dataset is None and (FLAGS.dataset_type != 'synthetic_gpu' or
                                  FLAGS.synthetic_dataset_use_feature_spec):
        raise ValueError('Dataset argument has to specify a path to the dataset')

    FLAGS.inference_benchmark_batch_sizes = [int(x) for x in FLAGS.inference_benchmark_batch_sizes]
    FLAGS.top_mlp_sizes = [int(x) for x in FLAGS.top_mlp_sizes]
    FLAGS.bottom_mlp_sizes = [int(x) for x in FLAGS.bottom_mlp_sizes]
def create_embeddings(embedding_type: str,
                      categorical_feature_sizes: Sequence[int],
                      embedding_dim: int,
                      device: str = "cuda",
                      hash_indices: bool = False,
                      fp16: bool = False) -> Embeddings:
    if embedding_type == "joint":
        return JointEmbedding(categorical_feature_sizes, embedding_dim,
                              device=device, hash_indices=hash_indices)
    elif embedding_type == "joint_fused":
        assert not is_distributed(), "Joint fused embedding is not supported in the distributed mode. " \
                                     "You may want to use 'joint_sparse' option instead."
        return FusedJointEmbedding(categorical_feature_sizes, embedding_dim,
                                   device=device, hash_indices=hash_indices,
                                   amp_train=fp16)
    elif embedding_type == "joint_sparse":
        return JointSparseEmbedding(categorical_feature_sizes, embedding_dim,
                                    device=device, hash_indices=hash_indices)
    elif embedding_type == "multi_table":
        return MultiTableEmbeddings(categorical_feature_sizes, embedding_dim,
                                    hash_indices=hash_indices, device=device)
    else:
        raise NotImplementedError(f"unknown embedding type: {embedding_type}")
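A minimal usage sketch of create_embeddings, for illustration only: the 26-table setup and the cardinalities below are assumptions (they mirror the Criteo-style configuration referenced in validate_flags), not values taken from the repository, and a CUDA device is assumed to be available.

# Hypothetical usage sketch (not part of the repository code); requires a CUDA device.
import torch

categorical_feature_sizes = [100_000] * 26      # assumed cardinalities, one per categorical feature
embeddings = create_embeddings(
    embedding_type="joint_sparse",
    categorical_feature_sizes=categorical_feature_sizes,
    embedding_dim=128,
    device="cuda",
    hash_indices=False,
    fp16=False,
)

# One batch of random category ids, shape [batch_size, num_categorical_features].
categorical_indices = torch.randint(0, 100_000, (2048, 26), device="cuda")
embedded = embeddings(categorical_indices)      # per-feature embedding vectors (exact return structure depends on the Embeddings subclass)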
def create_datasets(self) -> Tuple[Dataset, Dataset]:
    synthetic_train, synthetic_test = create_synthetic_datasets(self._flags)

    if is_distributed():
        self._synchronized_write(synthetic_train, synthetic_test)
    else:
        self._write(synthetic_train, synthetic_test)

    return create_real_datasets(self._flags, self._flags.synthetic_dataset_dir)
def create_datasets(self) -> Tuple[Dataset, Dataset]:
    synthetic_train, synthetic_test = create_synthetic_datasets(self._flags)

    if is_distributed():
        self._synchronized_write(synthetic_train, synthetic_test)
    else:
        self._write(synthetic_train, synthetic_test)

    return create_real_datasets(
        self._flags,
        self._flags.synthetic_dataset_dir,
        SplitCriteoDataset,
        "train",
        "test",
        prefetch_depth=10
    )
def create_dataset_factory(flags, feature_spec: FeatureSpec,
                           device_mapping: Optional[dict] = None) -> DatasetFactory:
    """
    By default each dataset can be used in a single-GPU or distributed setting - please keep that in mind when
    adding new datasets. The distributed case requires a selection of categorical features, provided in
    `device_mapping` (see `DatasetFactory#create_collate_fn`).

    :param flags:
    :param feature_spec: feature specification describing the dataset's numerical and categorical features
    :param device_mapping: dict, information about model bottom mlp and embeddings devices assignment
    :return:
    """
    dataset_type = flags.dataset_type
    num_numerical_features = feature_spec.get_number_of_numerical_features()

    if is_distributed() or device_mapping:
        assert device_mapping is not None, "Distributed dataset requires information about model device mapping."
        rank = get_rank()
        local_categorical_positions = device_mapping["embedding"][rank]
        numerical_features_enabled = device_mapping["bottom_mlp"] == rank
    else:
        local_categorical_positions = list(range(len(feature_spec.get_categorical_feature_names())))
        numerical_features_enabled = True

    if dataset_type == "parametric":
        local_categorical_names = feature_spec.cat_positions_to_names(local_categorical_positions)
        return ParametricDatasetFactory(flags=flags,
                                        feature_spec=feature_spec,
                                        numerical_features_enabled=numerical_features_enabled,
                                        categorical_features_to_read=local_categorical_names)

    if dataset_type == "synthetic_gpu":
        local_numerical_features = num_numerical_features if numerical_features_enabled else 0
        world_categorical_sizes = feature_spec.get_categorical_sizes()
        local_categorical_sizes = [world_categorical_sizes[i] for i in local_categorical_positions]
        return SyntheticGpuDatasetFactory(flags,
                                          local_numerical_features_num=local_numerical_features,
                                          local_categorical_feature_sizes=local_categorical_sizes)

    raise NotImplementedError(f"unknown dataset type: {dataset_type}")
def create_dataset_factory(flags, device_mapping: Optional[dict] = None) -> DatasetFactory:
    """
    By default each dataset can be used in a single-GPU or distributed setting - please keep that in mind when
    adding new datasets. The distributed case requires a selection of categorical features, provided in
    `device_mapping` (see `DatasetFactory#create_collate_fn`).

    :param flags:
    :param device_mapping: dict, information about model bottom mlp and embeddings devices assignment
    :return:
    """
    dataset_type = flags.dataset_type

    if dataset_type == "binary":
        return BinaryDatasetFactory(flags, device_mapping)

    if dataset_type == "split":
        if is_distributed():
            assert device_mapping is not None, "Distributed dataset requires information about model device mapping."
            rank = get_rank()
            return SplitBinaryDatasetFactory(
                flags=flags,
                numerical_features=device_mapping["bottom_mlp"] == rank,
                categorical_features=device_mapping["embedding"][rank])

        return SplitBinaryDatasetFactory(
            flags=flags,
            numerical_features=True,
            categorical_features=range(len(get_categorical_feature_sizes(flags))))

    if dataset_type == "synthetic_gpu":
        return SyntheticGpuDatasetFactory(flags, device_mapping)

    if dataset_type == "synthetic_disk":
        return SyntheticDiskDatasetFactory(flags, device_mapping)

    raise NotImplementedError(f"unknown dataset type: {dataset_type}")
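Both variants of create_dataset_factory above consume the same device_mapping dictionary that main() below builds with get_device_mapping. The sketch that follows shows its assumed shape for a 4-GPU run with 26 categorical features; the keys are the ones actually read by this code ("bottom_mlp", "embedding", "vectors_per_gpu"), while the concrete split of features across ranks is illustrative, not the output of the real helper.

# Hypothetical device_mapping for world_size = 4 and 26 categorical features (illustrative split):
device_mapping = {
    # rank that owns the bottom MLP and therefore reads the numerical features
    "bottom_mlp": 0,
    # per-rank lists of categorical feature positions whose embedding tables live on that rank
    "embedding": [
        list(range(0, 7)),    # rank 0
        list(range(7, 14)),   # rank 1
        list(range(14, 20)),  # rank 2
        list(range(20, 26)),  # rank 3
    ],
    # embedding vectors produced per rank (+1 on the rank that also runs the bottom MLP)
    "vectors_per_gpu": [8, 7, 6, 6],
}

# The dataset factories only read "bottom_mlp" and "embedding";
# "vectors_per_gpu" is consumed by DistributedDlrm in main() below.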
def create_sampler(self, dataset: Dataset) -> Optional[Sampler]:
    return RandomDistributedSampler(dataset) if is_distributed() else RandomSampler(dataset)
def main(argv):
    torch.manual_seed(FLAGS.seed)

    utils.init_logging(log_path=FLAGS.log_path)

    use_gpu = "cpu" not in FLAGS.base_device.lower()
    rank, world_size, gpu = dist.init_distributed_mode(backend=FLAGS.backend, use_gpu=use_gpu)
    device = FLAGS.base_device

    if not is_distributed():
        raise NotImplementedError("This file is only for distributed training.")

    if is_main_process():
        dllogger.log(data=FLAGS.flag_values_dict(), step='PARAMETER')

        print("Command line flags:")
        pprint(FLAGS.flag_values_dict())

    print("Creating data loaders")

    FLAGS.set_default("test_batch_size", FLAGS.test_batch_size // world_size * world_size)

    categorical_feature_sizes = get_categorical_feature_sizes(FLAGS)
    world_categorical_feature_sizes = np.asarray(categorical_feature_sizes)
    device_mapping = get_device_mapping(categorical_feature_sizes, num_gpus=world_size)

    batch_sizes_per_gpu = get_gpu_batch_sizes(FLAGS.batch_size, num_gpus=world_size)
    batch_indices = tuple(np.cumsum([0] + list(batch_sizes_per_gpu)))

    # sizes of embeddings for each GPU
    categorical_feature_sizes = world_categorical_feature_sizes[device_mapping['embedding'][rank]].tolist()

    bottom_mlp_sizes = FLAGS.bottom_mlp_sizes if rank == device_mapping['bottom_mlp'] else None

    data_loader_train, data_loader_test = get_data_loaders(FLAGS, device_mapping=device_mapping)

    model = DistributedDlrm(
        vectors_per_gpu=device_mapping['vectors_per_gpu'],
        embedding_device_mapping=device_mapping['embedding'],
        embedding_type=FLAGS.embedding_type,
        embedding_dim=FLAGS.embedding_dim,
        world_num_categorical_features=len(world_categorical_feature_sizes),
        categorical_feature_sizes=categorical_feature_sizes,
        num_numerical_features=FLAGS.num_numerical_features,
        hash_indices=FLAGS.hash_indices,
        bottom_mlp_sizes=bottom_mlp_sizes,
        top_mlp_sizes=FLAGS.top_mlp_sizes,
        interaction_op=FLAGS.interaction_op,
        fp16=FLAGS.amp,
        use_cpp_mlp=FLAGS.optimized_mlp,
        bottom_features_ordered=FLAGS.bottom_features_ordered,
        device=device
    )

    print(model)
    print(device_mapping)
    print(f"Batch sizes per gpu: {batch_sizes_per_gpu}")

    dist.setup_distributed_print(is_main_process())

    # DDP introduces a gradient average through allreduce(mean), which doesn't apply to bottom model.
    # Compensate it with further scaling lr
    scaled_lr = FLAGS.lr / FLAGS.loss_scale if FLAGS.amp else FLAGS.lr
    scaled_lrs = [scaled_lr / world_size, scaled_lr]

    embedding_optimizer = torch.optim.SGD([
        {'params': model.bottom_model.embeddings.parameters(), 'lr': scaled_lrs[0]},
    ])
    mlp_optimizer = apex_optim.FusedSGD([
        {'params': model.bottom_model.mlp.parameters(), 'lr': scaled_lrs[0]},
        {'params': model.top_model.parameters(), 'lr': scaled_lrs[1]}
    ])

    checkpoint_writer = make_distributed_checkpoint_writer(
        device_mapping=device_mapping,
        rank=rank,
        is_main_process=is_main_process(),
        config=FLAGS.flag_values_dict()
    )
    checkpoint_loader = make_distributed_checkpoint_loader(device_mapping=device_mapping, rank=rank)

    if FLAGS.load_checkpoint_path:
        checkpoint_loader.load_checkpoint(model, FLAGS.load_checkpoint_path)
        model.to(device)

    if FLAGS.amp:
        (model.top_model, model.bottom_model.mlp), mlp_optimizer = amp.initialize(
            [model.top_model, model.bottom_model.mlp], mlp_optimizer,
            opt_level="O2", loss_scale=1)

    if use_gpu:
        model.top_model = parallel.DistributedDataParallel(model.top_model)
    else:  # Use other backend for CPU
        model.top_model = torch.nn.parallel.DistributedDataParallel(model.top_model)

    if FLAGS.mode == 'test':
        auc = dist_evaluate(model, data_loader_test)

        results = {'auc': auc}
        dllogger.log(data=results, step=tuple())

        if auc is not None:
            print(f"Finished testing. Test auc {auc:.4f}")
        return

    if FLAGS.save_checkpoint_path and not FLAGS.bottom_features_ordered and is_main_process():
        logging.warning("Saving checkpoint without --bottom_features_ordered flag will result in "
                        "a device-order dependent model. Consider using --bottom_features_ordered "
                        "if you plan to load the checkpoint in different device configurations.")

    loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean")

    # Print per 16384 * 2000 samples by default
    default_print_freq = 16384 * 2000 // FLAGS.batch_size
    print_freq = default_print_freq if FLAGS.print_freq is None else FLAGS.print_freq

    steps_per_epoch = len(data_loader_train)
    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch - 1

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{avg:.4f}'))
    metric_logger.add_meter('step_time', utils.SmoothedValue(window_size=1, fmt='{avg:.6f}'))
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    # Accumulating loss on GPU to avoid memcpyD2H every step
    moving_loss = torch.zeros(1, device=device)
    moving_loss_stream = torch.cuda.Stream()

    lr_scheduler = utils.LearningRateScheduler(
        optimizers=[mlp_optimizer, embedding_optimizer],
        base_lrs=[scaled_lrs, [scaled_lrs[0]]],
        warmup_steps=FLAGS.warmup_steps,
        warmup_factor=FLAGS.warmup_factor,
        decay_start_step=FLAGS.decay_start_step,
        decay_steps=FLAGS.decay_steps,
        decay_power=FLAGS.decay_power,
        end_lr_factor=FLAGS.decay_end_lr / FLAGS.lr
    )

    data_stream = torch.cuda.Stream()
    timer = utils.StepTimer()

    best_auc = 0
    best_epoch = 0
    start_time = time()
    stop_time = time()

    for epoch in range(FLAGS.epochs):
        epoch_start_time = time()

        batch_iter = prefetcher(iter(data_loader_train), data_stream)

        for step in range(len(data_loader_train)):
            timer.click()

            numerical_features, categorical_features, click = next(batch_iter)
            torch.cuda.synchronize()

            global_step = steps_per_epoch * epoch + step

            if FLAGS.max_steps and global_step > FLAGS.max_steps:
                print(f"Reached max global steps of {FLAGS.max_steps}. Stopping.")
                break
            lr_scheduler.step()

            if click.shape[0] != FLAGS.batch_size:  # last batch
                logging.error("The last batch with size %s is not supported", click.shape[0])
            else:
                output = model(numerical_features, categorical_features, batch_sizes_per_gpu).squeeze()

                loss = loss_fn(output, click[batch_indices[rank]:batch_indices[rank + 1]])

                # We don't need to accumulate gradient. Setting grad to None is faster than optimizer.zero_grad()
                for param_group in itertools.chain(embedding_optimizer.param_groups,
                                                   mlp_optimizer.param_groups):
                    for param in param_group['params']:
                        param.grad = None

                if FLAGS.amp:
                    loss *= FLAGS.loss_scale
                    with amp.scale_loss(loss, mlp_optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                mlp_optimizer.step()
                embedding_optimizer.step()

                moving_loss_stream.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(moving_loss_stream):
                    moving_loss += loss

            if timer.measured is None:
                # first iteration, no step time etc. to print
                continue

            if step == 0:
                print(f"Started epoch {epoch}...")
            elif step % print_freq == 0:
                torch.cuda.current_stream().wait_stream(moving_loss_stream)

                # Averaging across a print_freq period to reduce the error.
                # An accurate timing needs a synchronize, which would slow things down.
                if global_step < FLAGS.benchmark_warmup_steps:
                    metric_logger.update(
                        loss=moving_loss.item() / print_freq / (FLAGS.loss_scale if FLAGS.amp else 1),
                        lr=mlp_optimizer.param_groups[1]["lr"] * (FLAGS.loss_scale if FLAGS.amp else 1))
                else:
                    metric_logger.update(
                        step_time=timer.measured,
                        loss=moving_loss.item() / print_freq / (FLAGS.loss_scale if FLAGS.amp else 1),
                        lr=mlp_optimizer.param_groups[1]["lr"] * (FLAGS.loss_scale if FLAGS.amp else 1))
                stop_time = time()

                eta_str = datetime.timedelta(
                    seconds=int(metric_logger.step_time.global_avg * (steps_per_epoch - step)))
                metric_logger.print(
                    header=f"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}] eta: {eta_str}")

                with torch.cuda.stream(moving_loss_stream):
                    moving_loss = 0.

            if global_step % test_freq == 0 and global_step > 0 and global_step / steps_per_epoch >= FLAGS.test_after:
                auc = dist_evaluate(model, data_loader_test)

                if auc is None:
                    continue

                print(f"Epoch {epoch} step {step}. auc {auc:.6f}")
                stop_time = time()

                if auc > best_auc:
                    best_auc = auc
                    best_epoch = epoch + ((step + 1) / steps_per_epoch)

                if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
                    run_time_s = int(stop_time - start_time)
                    print(f"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                          f"{global_step / steps_per_epoch:.2f} in {run_time_s}s. "
                          f"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s.")
                    sys.exit()

        epoch_stop_time = time()
        epoch_time_s = epoch_stop_time - epoch_start_time
        print(f"Finished epoch {epoch} in {datetime.timedelta(seconds=int(epoch_time_s))}. "
              f"Average speed {steps_per_epoch * FLAGS.batch_size / epoch_time_s:.1f} records/s.")

    avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg

    if FLAGS.save_checkpoint_path:
        checkpoint_writer.save_checkpoint(model, FLAGS.save_checkpoint_path, epoch, step)

    results = {
        'best_auc': best_auc,
        'best_epoch': best_epoch,
        'average_train_throughput': avg_throughput
    }

    dllogger.log(data=results, step=tuple())
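The training loop above pulls batches through prefetcher(iter(data_loader_train), data_stream), a helper defined elsewhere in the repository. The sketch below illustrates the stream-based prefetch pattern that call implies, assuming each batch is a tuple of CPU tensors; it matches the call site's name and signature, but it is an illustration of the copy/compute overlap technique, not the repository's implementation.

import torch

def prefetcher(load_iterator, prefetch_stream):
    """Hypothetical sketch: copy the next batch to the GPU on a side stream while the
    current batch is being consumed on the default stream."""
    def _load_next():
        try:
            batch = next(load_iterator)
        except StopIteration:
            return None
        with torch.cuda.stream(prefetch_stream):
            # Assumes each batch is a tuple of CPU tensors (pinned memory makes the copy asynchronous).
            return tuple(t.cuda(non_blocking=True) for t in batch)

    next_batch = _load_next()
    while next_batch is not None:
        # Ensure the default stream sees the finished copy before the batch is used.
        torch.cuda.current_stream().wait_stream(prefetch_stream)
        current_batch, next_batch = next_batch, _load_next()
        yield current_batch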