def main(argv):
    rank, world_size, gpu = dist.init_distributed_mode()

    top_mlp = create_top_mlp().to("cuda")
    print(top_mlp)

    optimizer = torch.optim.SGD(top_mlp.parameters(), lr=1.)
    if FLAGS.fp16:
        top_mlp, optimizer = amp.initialize(top_mlp, optimizer, opt_level="O1", loss_scale=1)

    if world_size > 1:
        top_mlp = parallel.DistributedDataParallel(top_mlp)
        model_without_ddp = top_mlp.module

    dummy_bottom_mlp_output = torch.rand(FLAGS.batch_size, EMBED_DIM, device="cuda")
    dummy_embedding_output = torch.rand(FLAGS.batch_size, 26 * EMBED_DIM, device="cuda")
    dummy_target = torch.ones(FLAGS.batch_size, device="cuda")
    if FLAGS.fp16:
        dummy_bottom_mlp_output = dummy_bottom_mlp_output.to(torch.half)
        dummy_embedding_output = dummy_embedding_output.to(torch.half)

    # warm up GPU
    for _ in range(100):
        interaction_out = dot_interaction(dummy_bottom_mlp_output, [dummy_embedding_output], FLAGS.batch_size)
        output = top_mlp(interaction_out)

    start_time = utils.timer_start()
    for _ in range(FLAGS.num_iters):
        interaction_out = dot_interaction(dummy_bottom_mlp_output, [dummy_embedding_output], FLAGS.batch_size)
        output = top_mlp(interaction_out).squeeze()
        dummy_loss = output.mean()

        optimizer.zero_grad()
        if FLAGS.fp16:
            with amp.scale_loss(dummy_loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            dummy_loss.backward()
        optimizer.step()
    stop_time = utils.timer_stop()

    elapsed_time = (stop_time - start_time) / FLAGS.num_iters * 1e3
    print(F"Average step time: {elapsed_time:.4f} ms.")

def main(argv):
    if FLAGS.seed is not None:
        torch.manual_seed(FLAGS.seed)
        np.random.seed(FLAGS.seed)

    # Initialize distributed mode
    use_gpu = "cpu" not in FLAGS.device.lower()
    rank, world_size, gpu = dist.init_distributed_mode(backend=FLAGS.backend, use_gpu=use_gpu)
    if world_size == 1:
        raise NotImplementedError("This file is only for distributed training.")

    mlperf_logger.mlperf_submission_log('dlrm')
    mlperf_logger.log_event(key=mlperf_logger.constants.SEED, value=FLAGS.seed)
    mlperf_logger.log_event(key=mlperf_logger.constants.GLOBAL_BATCH_SIZE, value=FLAGS.batch_size)

    # Only print cmd args on rank 0
    if rank == 0:
        print("Command line flags:")
        pprint(FLAGS.flag_values_dict())

    # Check arguments sanity
    if FLAGS.batch_size % world_size != 0:
        raise ValueError(
            F"Batch size {FLAGS.batch_size} is not divisible by world_size {world_size}.")
    if FLAGS.test_batch_size % world_size != 0:
        raise ValueError(
            F"Test batch size {FLAGS.test_batch_size} is not divisible by world_size {world_size}.")

    # Load config file, create a sub config for each rank
    with open(FLAGS.model_config, "r") as f:
        config = json.loads(f.read())
    world_categorical_feature_sizes = np.asarray(config.pop('categorical_feature_sizes'))
    device_mapping = dist_model.get_criteo_device_mapping(world_size)
    vectors_per_gpu = device_mapping['vectors_per_gpu']

    # Get the sizes of the embeddings each GPU is going to create
    categorical_feature_sizes = world_categorical_feature_sizes[
        device_mapping['embedding'][rank]].tolist()

    bottom_mlp_sizes = config.pop('bottom_mlp_sizes')
    if rank != device_mapping['bottom_mlp']:
        bottom_mlp_sizes = None

    model = dist_model.DistDlrm(
        categorical_feature_sizes=categorical_feature_sizes,
        bottom_mlp_sizes=bottom_mlp_sizes,
        world_num_categorical_features=len(world_categorical_feature_sizes),
        **config,
        device=FLAGS.device,
        use_embedding_ext=FLAGS.use_embedding_ext)
    print(model)

    dist.setup_distributed_print(rank == 0)

    # DDP introduces a gradient average through allreduce(mean), which doesn't apply to the bottom model.
    # Compensate for it by further scaling the lr.
    scaled_lr = FLAGS.lr / FLAGS.loss_scale if FLAGS.fp16 else FLAGS.lr
    scaled_lrs = [scaled_lr / world_size, scaled_lr]

    embedding_optimizer = torch.optim.SGD([
        {'params': model.bottom_model.joint_embedding.parameters(), 'lr': scaled_lrs[0]},
    ])
    mlp_optimizer = apex_optim.FusedSGD([
        {'params': model.bottom_model.bottom_mlp.parameters(), 'lr': scaled_lrs[0]},
        {'params': model.top_model.parameters(), 'lr': scaled_lrs[1]},
    ])

    if FLAGS.fp16:
        (model.top_model, model.bottom_model.bottom_mlp), mlp_optimizer = amp.initialize(
            [model.top_model, model.bottom_model.bottom_mlp],
            mlp_optimizer,
            opt_level="O2",
            loss_scale=1,
            cast_model_outputs=torch.float16)

    if use_gpu:
        model.top_model = parallel.DistributedDataParallel(model.top_model)
    else:  # Use another backend for CPU
        model.top_model = torch.nn.parallel.DistributedDataParallel(model.top_model)

    loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean")

    # Too many arguments to pass for distributed training. Use plain training code here instead of
    # defining a train function.

    # Print per 16384 * 2000 samples by default
    default_print_freq = 16384 * 2000 // FLAGS.batch_size
    print_freq = default_print_freq if FLAGS.print_freq is None else FLAGS.print_freq

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{avg:.4f}'))
    metric_logger.add_meter('step_time', utils.SmoothedValue(window_size=1, fmt='{avg:.4f} ms'))
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))

    # Accumulate the loss on the GPU to avoid a memcpyD2H every step
    moving_loss = torch.zeros(1, device=FLAGS.device)
    moving_loss_stream = torch.cuda.Stream()

    local_embedding_device_mapping = torch.tensor(
        device_mapping['embedding'][rank], device=FLAGS.device, dtype=torch.long)

    # LR is logged twice for now because of a compliance checker bug
    mlperf_logger.log_event(key=mlperf_logger.constants.OPT_BASE_LR, value=FLAGS.lr)
    mlperf_logger.log_event(key=mlperf_logger.constants.OPT_LR_WARMUP_STEPS, value=FLAGS.warmup_steps)

    # Use logging keys from the official HP table and not from the logging library
    mlperf_logger.log_event(key='sgd_opt_base_learning_rate', value=FLAGS.lr)
    mlperf_logger.log_event(key='lr_decay_start_steps', value=FLAGS.decay_start_step)
    mlperf_logger.log_event(key='sgd_opt_learning_rate_decay_steps', value=FLAGS.decay_steps)
    mlperf_logger.log_event(key='sgd_opt_learning_rate_decay_poly_power', value=FLAGS.decay_power)

    lr_scheduler = utils.LearningRateScheduler(
        optimizers=[mlp_optimizer, embedding_optimizer],
        base_lrs=[scaled_lrs, [scaled_lrs[0]]],
        warmup_steps=FLAGS.warmup_steps,
        warmup_factor=FLAGS.warmup_factor,
        decay_start_step=FLAGS.decay_start_step,
        decay_steps=FLAGS.decay_steps,
        decay_power=FLAGS.decay_power,
        end_lr_factor=FLAGS.decay_end_lr / FLAGS.lr)

    data_stream = torch.cuda.Stream()
    eval_data_cache = [] if FLAGS.cache_eval_data else None

    start_time = time()
    stop_time = time()

    print("Creating data loaders")
    dist_dataset_args = {
        "numerical_features": rank == 0,
        "categorical_features": device_mapping['embedding'][rank]
    }

    mlperf_logger.barrier()
    mlperf_logger.log_end(key=mlperf_logger.constants.INIT_STOP)
    mlperf_logger.barrier()
    mlperf_logger.log_start(key=mlperf_logger.constants.RUN_START)
    mlperf_logger.barrier()

    data_loader_train, data_loader_test = dataset.get_data_loader(
        FLAGS.dataset, FLAGS.batch_size, FLAGS.test_batch_size, FLAGS.device,
        dataset_type=FLAGS.dataset_type, shuffle=FLAGS.shuffle, **dist_dataset_args)
    steps_per_epoch = len(data_loader_train)

    # Default to 20 tests per epoch
    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch // 20

    for epoch in range(FLAGS.epochs):
        epoch_start_time = time()

        mlperf_logger.barrier()
        mlperf_logger.log_start(key=mlperf_logger.constants.BLOCK_START,
                                metadata={mlperf_logger.constants.FIRST_EPOCH_NUM: epoch + 1,
                                          mlperf_logger.constants.EPOCH_COUNT: 1})
        mlperf_logger.barrier()
        mlperf_logger.log_start(key=mlperf_logger.constants.EPOCH_START,
                                metadata={mlperf_logger.constants.EPOCH_NUM: epoch + 1})

        if FLAGS.profile_steps is not None:
            torch.cuda.profiler.start()

        for step, (numerical_features, categorical_features, click) in enumerate(
                dataset.prefetcher(iter(data_loader_train), data_stream)):
            torch.cuda.current_stream().wait_stream(data_stream)
            global_step = steps_per_epoch * epoch + step
            lr_scheduler.step()

            # Slice out categorical features if not using the "dist" dataset
            if FLAGS.dataset_type != "dist":
                categorical_features = categorical_features[:, local_embedding_device_mapping]

            if FLAGS.fp16 and categorical_features is not None:
                numerical_features = numerical_features.to(torch.float16)

            last_batch_size = None
            if click.shape[0] != FLAGS.batch_size:  # last batch
                last_batch_size = click.shape[0]
                logging.debug("Pad the last batch of size %d to %d", last_batch_size, FLAGS.batch_size)
                padding_size = FLAGS.batch_size - last_batch_size
                padding_numerical = torch.empty(
                    padding_size, numerical_features.shape[1],
                    device=numerical_features.device, dtype=numerical_features.dtype)
                numerical_features = torch.cat((numerical_features, padding_numerical), dim=0)
                if categorical_features is not None:
                    padding_categorical = torch.ones(
                        padding_size, categorical_features.shape[1],
                        device=categorical_features.device, dtype=categorical_features.dtype)
                    categorical_features = torch.cat((categorical_features, padding_categorical), dim=0)
                padding_click = torch.empty(padding_size, device=click.device, dtype=click.dtype)
                click = torch.cat((click, padding_click))

            bottom_out = model.bottom_model(numerical_features, categorical_features)
            batch_size_per_gpu = FLAGS.batch_size // world_size
            from_bottom = dist_model.bottom_to_top(
                bottom_out, batch_size_per_gpu, config['embedding_dim'], vectors_per_gpu)

            if last_batch_size is not None:
                partial_rank = math.ceil(last_batch_size / batch_size_per_gpu)
                if rank == partial_rank:
                    top_out = model.top_model(
                        from_bottom[:last_batch_size % batch_size_per_gpu]).squeeze().float()
                    loss = loss_fn(
                        top_out,
                        click[rank * batch_size_per_gpu:(rank + 1) * batch_size_per_gpu][
                            :last_batch_size % batch_size_per_gpu])
                elif rank < partial_rank:
                    loss = loss_fn(
                        model.top_model(from_bottom).squeeze().float(),
                        click[rank * batch_size_per_gpu:(rank + 1) * batch_size_per_gpu])
                else:
                    # Backpropagate nothing for padded samples
                    loss = 0. * model.top_model(from_bottom).squeeze().float().mean()
            else:
                loss = loss_fn(
                    model.top_model(from_bottom).squeeze().float(),
                    click[rank * batch_size_per_gpu:(rank + 1) * batch_size_per_gpu])

            # We don't need to accumulate gradients. Setting grads to None is faster than optimizer.zero_grad()
            for param_group in itertools.chain(embedding_optimizer.param_groups, mlp_optimizer.param_groups):
                for param in param_group['params']:
                    param.grad = None

            if FLAGS.fp16:
                loss *= FLAGS.loss_scale
                with amp.scale_loss(loss, mlp_optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            mlp_optimizer.step()
            embedding_optimizer.step()

            moving_loss_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(moving_loss_stream):
                moving_loss += loss

            if step == 0:
                print(F"Started epoch {epoch}...")
            elif step % print_freq == 0:
                torch.cuda.synchronize()
                # Averaging across a print_freq period to reduce the error.
                # An accurate timing needs a synchronize, which would slow things down.
                metric_logger.update(
                    step_time=(time() - stop_time) * 1000 / print_freq,
                    loss=moving_loss.item() / print_freq / (FLAGS.loss_scale if FLAGS.fp16 else 1),
                    lr=mlp_optimizer.param_groups[1]["lr"] * (FLAGS.loss_scale if FLAGS.fp16 else 1))
                stop_time = time()
                eta_str = datetime.timedelta(
                    seconds=int(metric_logger.step_time.avg / 1000 * (steps_per_epoch - step)))
                metric_logger.print(
                    header=F"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}] eta: {eta_str}")
                moving_loss = 0.
                with torch.cuda.stream(moving_loss_stream):
                    moving_loss = 0.
            if global_step % test_freq == 0 and global_step > 0 and global_step / steps_per_epoch >= FLAGS.test_after:
                mlperf_epoch_index = global_step / steps_per_epoch + 1

                mlperf_logger.barrier()
                mlperf_logger.log_start(key=mlperf_logger.constants.EVAL_START,
                                        metadata={mlperf_logger.constants.EPOCH_NUM: mlperf_epoch_index})

                auc = dist_evaluate(model, data_loader_test, eval_data_cache)

                mlperf_logger.log_event(key=mlperf_logger.constants.EVAL_ACCURACY,
                                        value=float(auc),
                                        metadata={mlperf_logger.constants.EPOCH_NUM: mlperf_epoch_index})

                print(F"Epoch {epoch} step {step}. auc {auc:.6f}")
                stop_time = time()

                mlperf_logger.barrier()
                mlperf_logger.log_end(key=mlperf_logger.constants.EVAL_STOP,
                                      metadata={mlperf_logger.constants.EPOCH_NUM: mlperf_epoch_index})

                if auc > FLAGS.auc_threshold:
                    mlperf_logger.barrier()
                    mlperf_logger.log_end(key=mlperf_logger.constants.RUN_STOP,
                                          metadata={mlperf_logger.constants.STATUS:
                                                    mlperf_logger.constants.SUCCESS})
                    mlperf_logger.barrier()
                    mlperf_logger.log_end(key=mlperf_logger.constants.EPOCH_STOP,
                                          metadata={mlperf_logger.constants.EPOCH_NUM: epoch + 1})
                    mlperf_logger.barrier()
                    mlperf_logger.log_end(key=mlperf_logger.constants.BLOCK_STOP,
                                          metadata={mlperf_logger.constants.FIRST_EPOCH_NUM: epoch + 1})

                    run_time_s = int(stop_time - start_time)
                    print(F"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                          F"{global_step/steps_per_epoch:.2f} in {run_time_s}s. "
                          F"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s.")
                    return

            if FLAGS.profile_steps is not None and global_step == FLAGS.profile_steps:
                torch.cuda.profiler.stop()
                logging.warning("Profile run, stopped at step %d.", global_step)
                return

        mlperf_logger.barrier()
        mlperf_logger.log_end(key=mlperf_logger.constants.EPOCH_STOP,
                              metadata={mlperf_logger.constants.EPOCH_NUM: epoch + 1})
        mlperf_logger.barrier()
        mlperf_logger.log_end(key=mlperf_logger.constants.BLOCK_STOP,
                              metadata={mlperf_logger.constants.FIRST_EPOCH_NUM: epoch + 1})

        epoch_stop_time = time()
        epoch_time_s = epoch_stop_time - epoch_start_time
        print(F"Finished epoch {epoch} in {datetime.timedelta(seconds=int(epoch_time_s))}. "
              F"Average speed {steps_per_epoch * FLAGS.batch_size / epoch_time_s:.1f} records/s.")

    mlperf_logger.barrier()
    mlperf_logger.log_end(key=mlperf_logger.constants.RUN_STOP,
                          metadata={mlperf_logger.constants.STATUS: mlperf_logger.constants.ABORTED})

def main(argv):
    torch.manual_seed(FLAGS.seed)
    utils.init_logging(log_path=FLAGS.log_path)

    use_gpu = "cpu" not in FLAGS.base_device.lower()
    rank, world_size, gpu = dist.init_distributed_mode(backend=FLAGS.backend, use_gpu=use_gpu)
    device = FLAGS.base_device

    if is_main_process():
        dllogger.log(data=FLAGS.flag_values_dict(), step='PARAMETER')
        print("Command line flags:")
        pprint(FLAGS.flag_values_dict())

    print("Creating data loaders")

    FLAGS.set_default("test_batch_size", FLAGS.test_batch_size // world_size * world_size)

    categorical_feature_sizes = get_categorical_feature_sizes(FLAGS)
    world_categorical_feature_sizes = np.asarray(categorical_feature_sizes)
    device_mapping = get_device_mapping(categorical_feature_sizes, num_gpus=world_size)

    batch_sizes_per_gpu = get_gpu_batch_sizes(FLAGS.batch_size, num_gpus=world_size)
    batch_indices = tuple(np.cumsum([0] + list(batch_sizes_per_gpu)))

    # sizes of embeddings for each GPU
    categorical_feature_sizes = world_categorical_feature_sizes[
        device_mapping['embedding'][rank]].tolist()

    bottom_mlp_sizes = FLAGS.bottom_mlp_sizes if rank == device_mapping['bottom_mlp'] else None

    data_loader_train, data_loader_test = get_data_loaders(FLAGS, device_mapping=device_mapping)

    model = DistributedDlrm(
        vectors_per_gpu=device_mapping['vectors_per_gpu'],
        embedding_device_mapping=device_mapping['embedding'],
        embedding_type=FLAGS.embedding_type,
        embedding_dim=FLAGS.embedding_dim,
        world_num_categorical_features=len(world_categorical_feature_sizes),
        categorical_feature_sizes=categorical_feature_sizes,
        num_numerical_features=FLAGS.num_numerical_features,
        hash_indices=FLAGS.hash_indices,
        bottom_mlp_sizes=bottom_mlp_sizes,
        top_mlp_sizes=FLAGS.top_mlp_sizes,
        interaction_op=FLAGS.interaction_op,
        fp16=FLAGS.amp,
        use_cpp_mlp=FLAGS.optimized_mlp,
        bottom_features_ordered=FLAGS.bottom_features_ordered,
        device=device)
    print(model)
    print(device_mapping)
    print(f"Batch sizes per gpu: {batch_sizes_per_gpu}")

    dist.setup_distributed_print(is_main_process())

    # DDP introduces a gradient average through allreduce(mean), which doesn't apply to bottom model.
    # Compensate for it by further scaling the lr.
    scaled_lr = FLAGS.lr / FLAGS.loss_scale if FLAGS.amp else FLAGS.lr

    if FLAGS.Adam_embedding_optimizer:
        embedding_model_parallel_lr = scaled_lr
    else:
        embedding_model_parallel_lr = scaled_lr / world_size

    if FLAGS.Adam_MLP_optimizer:
        MLP_model_parallel_lr = scaled_lr
    else:
        MLP_model_parallel_lr = scaled_lr / world_size

    data_parallel_lr = scaled_lr

    if is_main_process():
        mlp_params = [
            {'params': list(model.top_model.parameters()), 'lr': data_parallel_lr},
            {'params': list(model.bottom_model.mlp.parameters()), 'lr': MLP_model_parallel_lr}
        ]
        mlp_lrs = [data_parallel_lr, MLP_model_parallel_lr]
    else:
        mlp_params = [
            {'params': list(model.top_model.parameters()), 'lr': data_parallel_lr}
        ]
        mlp_lrs = [data_parallel_lr]

    if FLAGS.Adam_MLP_optimizer:
        mlp_optimizer = apex_optim.FusedAdam(mlp_params)
    else:
        mlp_optimizer = apex_optim.FusedSGD(mlp_params)

    embedding_params = [{
        'params': list(model.bottom_model.embeddings.parameters()),
        'lr': embedding_model_parallel_lr
    }]
    embedding_lrs = [embedding_model_parallel_lr]

    if FLAGS.Adam_embedding_optimizer:
        embedding_optimizer = torch.optim.SparseAdam(embedding_params)
    else:
        embedding_optimizer = torch.optim.SGD(embedding_params)

    checkpoint_writer = make_distributed_checkpoint_writer(
        device_mapping=device_mapping,
        rank=rank,
        is_main_process=is_main_process(),
        config=FLAGS.flag_values_dict())

    checkpoint_loader = make_distributed_checkpoint_loader(device_mapping=device_mapping, rank=rank)

    if FLAGS.load_checkpoint_path:
        checkpoint_loader.load_checkpoint(model, FLAGS.load_checkpoint_path)
        model.to(device)

    if FLAGS.amp:
        (model.top_model, model.bottom_model.mlp), mlp_optimizer = amp.initialize(
            [model.top_model, model.bottom_model.mlp],
            mlp_optimizer,
            opt_level="O2",
            loss_scale=1)

    if use_gpu:
        model.top_model = parallel.DistributedDataParallel(model.top_model)
    else:  # Use another backend for CPU
        model.top_model = torch.nn.parallel.DistributedDataParallel(model.top_model)

    if FLAGS.mode == 'test':
        auc = dist_evaluate(model, data_loader_test)

        results = {'auc': auc}
        dllogger.log(data=results, step=tuple())

        if auc is not None:
            print(f"Finished testing. Test auc {auc:.4f}")
        return

    if FLAGS.save_checkpoint_path and not FLAGS.bottom_features_ordered and is_main_process():
        logging.warning("Saving checkpoint without --bottom_features_ordered flag will result in "
                        "a device-order dependent model. Consider using --bottom_features_ordered "
                        "if you plan to load the checkpoint in different device configurations.")

    loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean")

    # Print per 16384 * 2000 samples by default
    default_print_freq = 16384 * 2000 // FLAGS.batch_size
    print_freq = default_print_freq if FLAGS.print_freq is None else FLAGS.print_freq

    steps_per_epoch = len(data_loader_train)
    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch - 1

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{avg:.4f}'))
    metric_logger.add_meter('step_time', utils.SmoothedValue(window_size=1, fmt='{avg:.6f}'))
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    # Accumulate the loss on the GPU to avoid a memcpyD2H every step
    moving_loss = torch.zeros(1, device=device)
    moving_loss_stream = torch.cuda.Stream()

    lr_scheduler = utils.LearningRateScheduler(
        optimizers=[mlp_optimizer, embedding_optimizer],
        base_lrs=[mlp_lrs, embedding_lrs],
        warmup_steps=FLAGS.warmup_steps,
        warmup_factor=FLAGS.warmup_factor,
        decay_start_step=FLAGS.decay_start_step,
        decay_steps=FLAGS.decay_steps,
        decay_power=FLAGS.decay_power,
        end_lr_factor=FLAGS.decay_end_lr / FLAGS.lr)

    data_stream = torch.cuda.Stream()
    timer = utils.StepTimer()

    best_auc = 0
    best_epoch = 0
    start_time = time()
    stop_time = time()

    for epoch in range(FLAGS.epochs):
        epoch_start_time = time()

        batch_iter = prefetcher(iter(data_loader_train), data_stream)

        for step in range(len(data_loader_train)):
            timer.click()

            numerical_features, categorical_features, click = next(batch_iter)
            torch.cuda.synchronize()

            global_step = steps_per_epoch * epoch + step

            if FLAGS.max_steps and global_step > FLAGS.max_steps:
                print(f"Reached max global steps of {FLAGS.max_steps}. Stopping.")
                break

            lr_scheduler.step()

            if click.shape[0] != FLAGS.batch_size:  # last batch
                logging.error("The last batch with size %s is not supported", click.shape[0])
            else:
                output = model(numerical_features, categorical_features, batch_sizes_per_gpu).squeeze()
                loss = loss_fn(output, click[batch_indices[rank]:batch_indices[rank + 1]])

                if FLAGS.Adam_embedding_optimizer or FLAGS.Adam_MLP_optimizer:
                    model.zero_grad()
                else:
                    # We don't need to accumulate gradients. Setting grads to None is faster than optimizer.zero_grad()
                    for param_group in itertools.chain(embedding_optimizer.param_groups,
                                                       mlp_optimizer.param_groups):
                        for param in param_group['params']:
                            param.grad = None

                if FLAGS.amp:
                    loss *= FLAGS.loss_scale
                    with amp.scale_loss(loss, mlp_optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                if FLAGS.Adam_MLP_optimizer:
                    scale_MLP_gradients(mlp_optimizer, world_size)
                mlp_optimizer.step()

                if FLAGS.Adam_embedding_optimizer:
                    scale_embeddings_gradients(embedding_optimizer, world_size)
                embedding_optimizer.step()

                moving_loss_stream.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(moving_loss_stream):
                    moving_loss += loss

            if timer.measured is None:
                # first iteration, no step time etc. to print
                continue

            if step == 0:
                print(f"Started epoch {epoch}...")
            elif step % print_freq == 0:
                torch.cuda.current_stream().wait_stream(moving_loss_stream)
                # Averaging across a print_freq period to reduce the error.
                # An accurate timing needs a synchronize, which would slow things down.
                if global_step < FLAGS.benchmark_warmup_steps:
                    metric_logger.update(
                        loss=moving_loss.item() / print_freq / (FLAGS.loss_scale if FLAGS.amp else 1),
                        lr=mlp_optimizer.param_groups[0]["lr"] * (FLAGS.loss_scale if FLAGS.amp else 1))
                else:
                    metric_logger.update(
                        step_time=timer.measured,
                        loss=moving_loss.item() / print_freq / (FLAGS.loss_scale if FLAGS.amp else 1),
                        lr=mlp_optimizer.param_groups[0]["lr"] * (FLAGS.loss_scale if FLAGS.amp else 1))
                stop_time = time()
                eta_str = datetime.timedelta(
                    seconds=int(metric_logger.step_time.global_avg * (steps_per_epoch - step)))
                metric_logger.print(
                    header=f"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}] eta: {eta_str}")

                with torch.cuda.stream(moving_loss_stream):
                    moving_loss = 0.

            if global_step % test_freq == 0 and global_step > 0 and global_step / steps_per_epoch >= FLAGS.test_after:
                auc = dist_evaluate(model, data_loader_test)

                if auc is None:
                    continue

                print(f"Epoch {epoch} step {step}. auc {auc:.6f}")
                stop_time = time()

                if auc > best_auc:
                    best_auc = auc
                    best_epoch = epoch + ((step + 1) / steps_per_epoch)

                if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
                    run_time_s = int(stop_time - start_time)
                    print(f"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                          f"{global_step / steps_per_epoch:.2f} in {run_time_s}s. "
                          f"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s.")
                    sys.exit()

        epoch_stop_time = time()
        epoch_time_s = epoch_stop_time - epoch_start_time
        print(f"Finished epoch {epoch} in {datetime.timedelta(seconds=int(epoch_time_s))}. "
              f"Average speed {steps_per_epoch * FLAGS.batch_size / epoch_time_s:.1f} records/s.")

    avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg

    if FLAGS.save_checkpoint_path:
        checkpoint_writer.save_checkpoint(model, FLAGS.save_checkpoint_path, epoch, step)

    results = {'best_auc': best_auc,
               'best_epoch': best_epoch,
               'average_train_throughput': avg_throughput}

    dllogger.log(data=results, step=tuple())

def main(argv):
    torch.manual_seed(FLAGS.seed)
    utils.init_logging(log_path=FLAGS.log_path)

    use_gpu = "cpu" not in FLAGS.base_device.lower()
    rank, world_size, gpu = dist.init_distributed_mode(backend=FLAGS.backend, use_gpu=use_gpu)
    device = FLAGS.base_device

    feature_spec = load_feature_spec(FLAGS)

    cat_feature_count = len(get_embedding_sizes(feature_spec, None))
    validate_flags(cat_feature_count)

    if is_main_process():
        dllogger.log(data=FLAGS.flag_values_dict(), step='PARAMETER')

    FLAGS.set_default("test_batch_size", FLAGS.test_batch_size // world_size * world_size)

    feature_spec = load_feature_spec(FLAGS)
    world_embedding_sizes = get_embedding_sizes(feature_spec, max_table_size=FLAGS.max_table_size)
    world_categorical_feature_sizes = np.asarray(world_embedding_sizes)
    device_mapping = get_device_mapping(world_embedding_sizes, num_gpus=world_size)

    batch_sizes_per_gpu = get_gpu_batch_sizes(FLAGS.batch_size, num_gpus=world_size)
    batch_indices = tuple(np.cumsum([0] + list(batch_sizes_per_gpu)))  # todo what does this do

    # Embedding sizes for each GPU
    categorical_feature_sizes = world_categorical_feature_sizes[device_mapping['embedding'][rank]].tolist()
    num_numerical_features = feature_spec.get_number_of_numerical_features()

    bottom_mlp_sizes = FLAGS.bottom_mlp_sizes if rank == device_mapping['bottom_mlp'] else None

    data_loader_train, data_loader_test = get_data_loaders(FLAGS, device_mapping=device_mapping,
                                                           feature_spec=feature_spec)

    model = DistributedDlrm(
        vectors_per_gpu=device_mapping['vectors_per_gpu'],
        embedding_device_mapping=device_mapping['embedding'],
        embedding_type=FLAGS.embedding_type,
        embedding_dim=FLAGS.embedding_dim,
        world_num_categorical_features=len(world_categorical_feature_sizes),
        categorical_feature_sizes=categorical_feature_sizes,
        num_numerical_features=num_numerical_features,
        hash_indices=FLAGS.hash_indices,
        bottom_mlp_sizes=bottom_mlp_sizes,
        top_mlp_sizes=FLAGS.top_mlp_sizes,
        interaction_op=FLAGS.interaction_op,
        fp16=FLAGS.amp,
        use_cpp_mlp=FLAGS.optimized_mlp,
        bottom_features_ordered=FLAGS.bottom_features_ordered,
        device=device
    )

    dist.setup_distributed_print(is_main_process())

    # DDP introduces a gradient average through allreduce(mean), which doesn't apply to bottom model.
    # Compensate for it by further scaling the lr.
    if FLAGS.Adam_embedding_optimizer:
        embedding_model_parallel_lr = FLAGS.lr
    else:
        embedding_model_parallel_lr = FLAGS.lr / world_size

    if FLAGS.Adam_MLP_optimizer:
        MLP_model_parallel_lr = FLAGS.lr
    else:
        MLP_model_parallel_lr = FLAGS.lr / world_size

    data_parallel_lr = FLAGS.lr

    if is_main_process():
        mlp_params = [
            {'params': list(model.top_model.parameters()), 'lr': data_parallel_lr},
            {'params': list(model.bottom_model.mlp.parameters()), 'lr': MLP_model_parallel_lr}
        ]
        mlp_lrs = [data_parallel_lr, MLP_model_parallel_lr]
    else:
        mlp_params = [
            {'params': list(model.top_model.parameters()), 'lr': data_parallel_lr}
        ]
        mlp_lrs = [data_parallel_lr]

    if FLAGS.Adam_MLP_optimizer:
        mlp_optimizer = apex_optim.FusedAdam(mlp_params)
    else:
        mlp_optimizer = apex_optim.FusedSGD(mlp_params)

    embedding_params = [{
        'params': list(model.bottom_model.embeddings.parameters()),
        'lr': embedding_model_parallel_lr
    }]
    embedding_lrs = [embedding_model_parallel_lr]

    if FLAGS.Adam_embedding_optimizer:
        embedding_optimizer = torch.optim.SparseAdam(embedding_params)
    else:
        embedding_optimizer = torch.optim.SGD(embedding_params)

    checkpoint_writer = make_distributed_checkpoint_writer(
        device_mapping=device_mapping,
        rank=rank,
        is_main_process=is_main_process(),
        config=FLAGS.flag_values_dict()
    )

    checkpoint_loader = make_distributed_checkpoint_loader(device_mapping=device_mapping, rank=rank)

    if FLAGS.load_checkpoint_path:
        checkpoint_loader.load_checkpoint(model, FLAGS.load_checkpoint_path)
        model.to(device)

    scaler = torch.cuda.amp.GradScaler(enabled=FLAGS.amp, growth_interval=int(1e9))

    def parallelize(model):
        if world_size <= 1:
            return model

        if use_gpu:
            model.top_model = parallel.DistributedDataParallel(model.top_model)
        else:  # Use another backend for CPU
            model.top_model = torch.nn.parallel.DistributedDataParallel(model.top_model)
        return model

    if FLAGS.mode == 'test':
        model = parallelize(model)
        auc = dist_evaluate(model, data_loader_test)

        results = {'auc': auc}
        dllogger.log(data=results, step=tuple())

        if auc is not None:
            print(f"Finished testing. Test auc {auc:.4f}")
        return
    elif FLAGS.mode == 'inference_benchmark':
        if world_size > 1:
            raise ValueError('Inference benchmark only supports singleGPU mode.')

        results = {}

        if FLAGS.amp:
            # can use pure FP16 for inference
            model = model.half()

        for batch_size in FLAGS.inference_benchmark_batch_sizes:
            FLAGS.test_batch_size = batch_size
            _, data_loader_test = get_data_loaders(FLAGS, device_mapping=device_mapping,
                                                   feature_spec=feature_spec)

            latencies = inference_benchmark(model=model, data_loader=data_loader_test,
                                            num_batches=FLAGS.inference_benchmark_steps,
                                            cuda_graphs=FLAGS.cuda_graphs)

            # drop the first 10 as a warmup
            latencies = latencies[10:]

            mean_latency = np.mean(latencies)
            mean_inference_throughput = batch_size / mean_latency
            subresult = {f'mean_inference_latency_batch_{batch_size}': mean_latency,
                         f'mean_inference_throughput_batch_{batch_size}': mean_inference_throughput}
            results.update(subresult)
        dllogger.log(data=results, step=tuple())
        return

    if FLAGS.save_checkpoint_path and not FLAGS.bottom_features_ordered and is_main_process():
        logging.warning("Saving checkpoint without --bottom_features_ordered flag will result in "
                        "a device-order dependent model. Consider using --bottom_features_ordered "
                        "if you plan to load the checkpoint in different device configurations.")

    loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean")

    # Print per 16384 * 2000 samples by default
    default_print_freq = 16384 * 2000 // FLAGS.batch_size
    print_freq = default_print_freq if FLAGS.print_freq is None else FLAGS.print_freq

    # the last batch will be dropped in the training loop
    steps_per_epoch = len(data_loader_train) - 1
    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch - 2

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{avg:.8f}'))
    metric_logger.add_meter('step_time', utils.SmoothedValue(window_size=1, fmt='{avg:.6f}'))
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    # Accumulate the loss on the GPU to avoid a memcpyD2H every step
    moving_loss = torch.zeros(1, device=device)

    lr_scheduler = utils.LearningRateScheduler(optimizers=[mlp_optimizer, embedding_optimizer],
                                               base_lrs=[mlp_lrs, embedding_lrs],
                                               warmup_steps=FLAGS.warmup_steps,
                                               warmup_factor=FLAGS.warmup_factor,
                                               decay_start_step=FLAGS.decay_start_step,
                                               decay_steps=FLAGS.decay_steps,
                                               decay_power=FLAGS.decay_power,
                                               end_lr_factor=FLAGS.decay_end_lr / FLAGS.lr)

    def zero_grad(model):
        if FLAGS.Adam_embedding_optimizer or FLAGS.Adam_MLP_optimizer:
            model.zero_grad()
        else:
            # We don't need to accumulate gradients. Setting grads to None is faster than optimizer.zero_grad()
            for param_group in itertools.chain(embedding_optimizer.param_groups, mlp_optimizer.param_groups):
                for param in param_group['params']:
                    param.grad = None

    def forward_backward(model, *args):
        numerical_features, categorical_features, click = args

        with torch.cuda.amp.autocast(enabled=FLAGS.amp):
            output = model(numerical_features, categorical_features, batch_sizes_per_gpu).squeeze()
            loss = loss_fn(output, click[batch_indices[rank]:batch_indices[rank + 1]])

        scaler.scale(loss).backward()
        return loss

    def weight_update():
        if not FLAGS.freeze_mlps:
            if FLAGS.Adam_MLP_optimizer:
                scale_MLP_gradients(mlp_optimizer, world_size)
            scaler.step(mlp_optimizer)

        if not FLAGS.freeze_embeddings:
            if FLAGS.Adam_embedding_optimizer:
                scale_embeddings_gradients(embedding_optimizer, world_size)
            scaler.unscale_(embedding_optimizer)
            embedding_optimizer.step()

        scaler.update()

    trainer = CudaGraphWrapper(model, forward_backward, parallelize, zero_grad,
                               cuda_graphs=FLAGS.cuda_graphs)

    data_stream = torch.cuda.Stream()
    timer = utils.StepTimer()

    best_auc = 0
    best_epoch = 0
    start_time = time()

    for epoch in range(FLAGS.epochs):
        epoch_start_time = time()

        batch_iter = prefetcher(iter(data_loader_train), data_stream)

        for step in range(len(data_loader_train)):
            timer.click()

            numerical_features, categorical_features, click = next(batch_iter)
            torch.cuda.synchronize()

            global_step = steps_per_epoch * epoch + step

            if FLAGS.max_steps and global_step > FLAGS.max_steps:
                print(f"Reached max global steps of {FLAGS.max_steps}. Stopping.")
                break

            # One of the batches will be smaller because the dataset size
            # isn't necessarily a multiple of the batch size.
            # TODO: isn't dropping the batch here a change of behavior?
            if click.shape[0] != FLAGS.batch_size:
                continue

            lr_scheduler.step()

            loss = trainer.train_step(numerical_features, categorical_features, click)

            # need to wait for the gradients before the weight update
            torch.cuda.current_stream().wait_stream(trainer.stream)

            weight_update()
            moving_loss += loss

            if timer.measured is None:
                # first iteration, no step time etc. to print
                continue

            if step == 0:
                print(f"Started epoch {epoch}...")
            elif step % print_freq == 0:
                # Averaging across a print_freq period to reduce the error.
                # An accurate timing needs a synchronize, which would slow things down.
                if global_step < FLAGS.benchmark_warmup_steps:
                    metric_logger.update(
                        loss=moving_loss.item() / print_freq,
                        lr=mlp_optimizer.param_groups[0]["lr"])
                else:
                    metric_logger.update(
                        step_time=timer.measured,
                        loss=moving_loss.item() / print_freq,
                        lr=mlp_optimizer.param_groups[0]["lr"])

                eta_str = datetime.timedelta(
                    seconds=int(metric_logger.step_time.global_avg * (steps_per_epoch - step)))
                metric_logger.print(
                    header=f"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}] eta: {eta_str}")

                moving_loss = 0.

            if global_step % test_freq == 0 and global_step > 0 and global_step / steps_per_epoch >= FLAGS.test_after:
                auc = dist_evaluate(trainer.model, data_loader_test)

                if auc is None:
                    continue

                print(f"Epoch {epoch} step {step}. auc {auc:.6f}")
                stop_time = time()

                if auc > best_auc:
                    best_auc = auc
                    best_epoch = epoch + ((step + 1) / steps_per_epoch)

                if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
                    run_time_s = int(stop_time - start_time)
                    print(f"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                          f"{global_step / steps_per_epoch:.2f} in {run_time_s}s. ")
                    sys.exit()

        epoch_stop_time = time()
        epoch_time_s = epoch_stop_time - epoch_start_time
        print(f"Finished epoch {epoch} in {datetime.timedelta(seconds=int(epoch_time_s))}. ")

    avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg

    if FLAGS.save_checkpoint_path:
        checkpoint_writer.save_checkpoint(model, FLAGS.save_checkpoint_path, epoch, step)

    results = {'best_auc': best_auc,
               'best_epoch': best_epoch,
               'average_train_throughput': avg_throughput}

    if is_main_process():
        dllogger.log(data=results, step=tuple())