def parallelize(model):
    if world_size <= 1:
        return model

    if use_gpu:
        model.top_model = parallel.DistributedDataParallel(model.top_model)
    else:  # Use other backend for CPU
        model.top_model = torch.nn.parallel.DistributedDataParallel(model.top_model)
    return model
def set_model_dist(net):
    if has_apex:
        net = parallel.DistributedDataParallel(net, delay_allreduce=True)
    else:
        local_rank = dist.get_rank()
        net = nn.parallel.DistributedDataParallel(
            net, device_ids=[local_rank, ], output_device=local_rank)
    return net
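# Both wrappers above assume that the default process group has already been
# initialized and, for GPU runs, that each process has been pinned to its GPU.
# A minimal sketch of that setup, assuming one process per GPU launched with
# torchrun / torch.distributed.launch (which export RANK, WORLD_SIZE, LOCAL_RANK):
import os

import torch
import torch.distributed as dist


def init_distributed():
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    if torch.cuda.is_available():
        torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
    dist.init_process_group(backend=backend, init_method="env://")
    return dist.get_rank(), dist.get_world_size()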
def main(argv):
    rank, world_size, gpu = dist.init_distributed_mode()

    top_mlp = create_top_mlp().to("cuda")
    print(top_mlp)

    optimizer = torch.optim.SGD(top_mlp.parameters(), lr=1.)
    if FLAGS.fp16:
        top_mlp, optimizer = amp.initialize(top_mlp, optimizer,
                                            opt_level="O1", loss_scale=1)
    if world_size > 1:
        top_mlp = parallel.DistributedDataParallel(top_mlp)
        model_without_ddp = top_mlp.module

    dummy_bottom_mlp_output = torch.rand(FLAGS.batch_size, EMBED_DIM, device="cuda")
    dummy_embedding_output = torch.rand(FLAGS.batch_size, 26 * EMBED_DIM, device="cuda")
    dummy_target = torch.ones(FLAGS.batch_size, device="cuda")
    if FLAGS.fp16:
        dummy_bottom_mlp_output = dummy_bottom_mlp_output.to(torch.half)
        dummy_embedding_output = dummy_embedding_output.to(torch.half)

    # warm up GPU
    for _ in range(100):
        interaction_out = dot_interaction(dummy_bottom_mlp_output,
                                          [dummy_embedding_output], FLAGS.batch_size)
        output = top_mlp(interaction_out)

    start_time = utils.timer_start()
    for _ in range(FLAGS.num_iters):
        interaction_out = dot_interaction(dummy_bottom_mlp_output,
                                          [dummy_embedding_output], FLAGS.batch_size)
        output = top_mlp(interaction_out).squeeze()
        dummy_loss = output.mean()
        optimizer.zero_grad()
        if FLAGS.fp16:
            with amp.scale_loss(dummy_loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            dummy_loss.backward()
        optimizer.step()
    stop_time = utils.timer_stop()

    elapsed_time = (stop_time - start_time) / FLAGS.num_iters * 1e3
    print(F"Average step time: {elapsed_time:.4f} ms.")
def train(rank):
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    dist.init_process_group(backend="nccl",
                            init_method="tcp://localhost:34567",
                            world_size=8,
                            rank=rank)
    dist.barrier()

    model = deeplab.resnext101_aspp_kp(19)
    torch.cuda.set_device(rank)
    if rank == 0:
        writer = SummaryWriter(log_dir=args.checkpoint_dir, flush_secs=20)

    model = parallel.convert_syncbn_model(model)
    model.cuda(rank)
    model.load_state_dict(
        torch.load(args.model_dir / "resnext_cityscapes_2p.pth", map_location=f"cuda:{rank}"),
        strict=False,
    )
    dist.barrier()
    if rank == 0:
        print(model.parameters)
    model = parallel.DistributedDataParallel(model)

    train_dataset = semantic_kitti.SemanticKitti(
        args.semantic_kitti_dir / "dataset/sequences",
        "train",
    )
    train_sampler = utils.dist_utils.TrainingSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=3,
        num_workers=8,
        drop_last=True,
        shuffle=False,
        pin_memory=True,
        sampler=train_sampler,
    )
    val_loader = torch.utils.data.DataLoader(
        dataset=semantic_kitti.SemanticKitti(
            args.semantic_kitti_dir / "dataset/sequences",
            "val",
        ),
        batch_size=1,
        shuffle=False,
        num_workers=4,
        drop_last=False,
    )

    loss_fn = utils.ohem.OhemCrossEntropy(ignore_index=255, thresh=0.9, min_kept=10000)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.00001, momentum=0.9, weight_decay=1e-4)
    scheduler = utils.cosine_schedule.CosineAnnealingWarmUpRestarts(
        optimizer, T_0=96000, T_mult=10, eta_max=0.01875, T_up=1000, gamma=0.5)

    n_iter = 0
    for epoch in range(120):
        model.train()
        for step, items in enumerate(train_loader):
            images = items["image"].cuda(rank, non_blocking=True)
            labels = items["labels"].long().cuda(rank, non_blocking=True)
            py = items["py"].float().cuda(rank, non_blocking=True)
            px = items["px"].float().cuda(rank, non_blocking=True)
            pxyz = items["points_xyz"].float().cuda(rank, non_blocking=True)
            knns = items["knns"].long().cuda(rank, non_blocking=True)

            predictions = model(images, px, py, pxyz, knns)
            loss = loss_fn(predictions, labels)

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 3.0)
            optimizer.step()

            if rank == 0:
                print(
                    f"Epoch: {epoch} Iteration: {step} / {len(train_loader)} Loss: {loss.item()}"
                )
                writer.add_scalar("loss/train", loss.item(), n_iter)
                writer.add_scalar("lr", optimizer.param_groups[0]["lr"], n_iter)
            n_iter += 1
            scheduler.step()

        if rank == 0:
            if (epoch + 1) % 5 == 0:
                run_val(model, val_loader, n_iter, writer)
            torch.save(model.module.state_dict(), args.checkpoint_dir / f"epoch{epoch}.pth")
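# train(rank) above expects to be called once per process with ranks 0..7
# (world_size is hard-coded to 8 and the rendezvous is a local TCP address).
# A minimal, hypothetical launcher for such a per-rank entry point (sketch only):
import torch.multiprocessing as mp


def launch(world_size=8):
    # mp.spawn runs train(i) in a fresh process for each i in [0, world_size).
    mp.spawn(train, nprocs=world_size, join=True)


if __name__ == "__main__":
    launch()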
def main(argv): torch.manual_seed(FLAGS.seed) utils.init_logging(log_path=FLAGS.log_path) use_gpu = "cpu" not in FLAGS.base_device.lower() rank, world_size, gpu = dist.init_distributed_mode(backend=FLAGS.backend, use_gpu=use_gpu) device = FLAGS.base_device if is_main_process(): dllogger.log(data=FLAGS.flag_values_dict(), step='PARAMETER') print("Command line flags:") pprint(FLAGS.flag_values_dict()) print("Creating data loaders") FLAGS.set_default("test_batch_size", FLAGS.test_batch_size // world_size * world_size) categorical_feature_sizes = get_categorical_feature_sizes(FLAGS) world_categorical_feature_sizes = np.asarray(categorical_feature_sizes) device_mapping = get_device_mapping(categorical_feature_sizes, num_gpus=world_size) batch_sizes_per_gpu = get_gpu_batch_sizes(FLAGS.batch_size, num_gpus=world_size) batch_indices = tuple(np.cumsum([0] + list(batch_sizes_per_gpu))) # sizes of embeddings for each GPU categorical_feature_sizes = world_categorical_feature_sizes[ device_mapping['embedding'][rank]].tolist() bottom_mlp_sizes = FLAGS.bottom_mlp_sizes if rank == device_mapping[ 'bottom_mlp'] else None data_loader_train, data_loader_test = get_data_loaders( FLAGS, device_mapping=device_mapping) model = DistributedDlrm( vectors_per_gpu=device_mapping['vectors_per_gpu'], embedding_device_mapping=device_mapping['embedding'], embedding_type=FLAGS.embedding_type, embedding_dim=FLAGS.embedding_dim, world_num_categorical_features=len(world_categorical_feature_sizes), categorical_feature_sizes=categorical_feature_sizes, num_numerical_features=FLAGS.num_numerical_features, hash_indices=FLAGS.hash_indices, bottom_mlp_sizes=bottom_mlp_sizes, top_mlp_sizes=FLAGS.top_mlp_sizes, interaction_op=FLAGS.interaction_op, fp16=FLAGS.amp, use_cpp_mlp=FLAGS.optimized_mlp, bottom_features_ordered=FLAGS.bottom_features_ordered, device=device) print(model) print(device_mapping) print(f"Batch sizes per gpu: {batch_sizes_per_gpu}") dist.setup_distributed_print(is_main_process()) # DDP introduces a gradient average through allreduce(mean), which doesn't apply to bottom model. 
# Compensate it with further scaling lr scaled_lr = FLAGS.lr / FLAGS.loss_scale if FLAGS.amp else FLAGS.lr if FLAGS.Adam_embedding_optimizer: embedding_model_parallel_lr = scaled_lr else: embedding_model_parallel_lr = scaled_lr / world_size if FLAGS.Adam_MLP_optimizer: MLP_model_parallel_lr = scaled_lr else: MLP_model_parallel_lr = scaled_lr / world_size data_parallel_lr = scaled_lr if is_main_process(): mlp_params = [{ 'params': list(model.top_model.parameters()), 'lr': data_parallel_lr }, { 'params': list(model.bottom_model.mlp.parameters()), 'lr': MLP_model_parallel_lr }] mlp_lrs = [data_parallel_lr, MLP_model_parallel_lr] else: mlp_params = [{ 'params': list(model.top_model.parameters()), 'lr': data_parallel_lr }] mlp_lrs = [data_parallel_lr] if FLAGS.Adam_MLP_optimizer: mlp_optimizer = apex_optim.FusedAdam(mlp_params) else: mlp_optimizer = apex_optim.FusedSGD(mlp_params) embedding_params = [{ 'params': list(model.bottom_model.embeddings.parameters()), 'lr': embedding_model_parallel_lr }] embedding_lrs = [embedding_model_parallel_lr] if FLAGS.Adam_embedding_optimizer: embedding_optimizer = torch.optim.SparseAdam(embedding_params) else: embedding_optimizer = torch.optim.SGD(embedding_params) checkpoint_writer = make_distributed_checkpoint_writer( device_mapping=device_mapping, rank=rank, is_main_process=is_main_process(), config=FLAGS.flag_values_dict()) checkpoint_loader = make_distributed_checkpoint_loader( device_mapping=device_mapping, rank=rank) if FLAGS.load_checkpoint_path: checkpoint_loader.load_checkpoint(model, FLAGS.load_checkpoint_path) model.to(device) if FLAGS.amp: (model.top_model, model.bottom_model.mlp), mlp_optimizer = amp.initialize( [model.top_model, model.bottom_model.mlp], mlp_optimizer, opt_level="O2", loss_scale=1) if use_gpu: model.top_model = parallel.DistributedDataParallel(model.top_model) else: # Use other backend for CPU model.top_model = torch.nn.parallel.DistributedDataParallel( model.top_model) if FLAGS.mode == 'test': auc = dist_evaluate(model, data_loader_test) results = {'auc': auc} dllogger.log(data=results, step=tuple()) if auc is not None: print(f"Finished testing. Test auc {auc:.4f}") return if FLAGS.save_checkpoint_path and not FLAGS.bottom_features_ordered and is_main_process( ): logging.warning( "Saving checkpoint without --bottom_features_ordered flag will result in " "a device-order dependent model. Consider using --bottom_features_ordered " "if you plan to load the checkpoint in different device configurations." 
) loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean") # Print per 16384 * 2000 samples by default default_print_freq = 16384 * 2000 // FLAGS.batch_size print_freq = default_print_freq if FLAGS.print_freq is None else FLAGS.print_freq steps_per_epoch = len(data_loader_train) test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch - 1 metric_logger = utils.MetricLogger(delimiter=" ") metric_logger.add_meter( 'loss', utils.SmoothedValue(window_size=1, fmt='{avg:.4f}')) metric_logger.add_meter( 'step_time', utils.SmoothedValue(window_size=1, fmt='{avg:.6f}')) metric_logger.add_meter( 'lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) # Accumulating loss on GPU to avoid memcpyD2H every step moving_loss = torch.zeros(1, device=device) moving_loss_stream = torch.cuda.Stream() lr_scheduler = utils.LearningRateScheduler( optimizers=[mlp_optimizer, embedding_optimizer], base_lrs=[mlp_lrs, embedding_lrs], warmup_steps=FLAGS.warmup_steps, warmup_factor=FLAGS.warmup_factor, decay_start_step=FLAGS.decay_start_step, decay_steps=FLAGS.decay_steps, decay_power=FLAGS.decay_power, end_lr_factor=FLAGS.decay_end_lr / FLAGS.lr) data_stream = torch.cuda.Stream() timer = utils.StepTimer() best_auc = 0 best_epoch = 0 start_time = time() stop_time = time() for epoch in range(FLAGS.epochs): epoch_start_time = time() batch_iter = prefetcher(iter(data_loader_train), data_stream) for step in range(len(data_loader_train)): timer.click() numerical_features, categorical_features, click = next(batch_iter) torch.cuda.synchronize() global_step = steps_per_epoch * epoch + step if FLAGS.max_steps and global_step > FLAGS.max_steps: print( f"Reached max global steps of {FLAGS.max_steps}. Stopping." ) break lr_scheduler.step() if click.shape[0] != FLAGS.batch_size: # last batch logging.error("The last batch with size %s is not supported", click.shape[0]) else: output = model(numerical_features, categorical_features, batch_sizes_per_gpu).squeeze() loss = loss_fn( output, click[batch_indices[rank]:batch_indices[rank + 1]]) if FLAGS.Adam_embedding_optimizer or FLAGS.Adam_MLP_optimizer: model.zero_grad() else: # We don't need to accumulate gradient. Set grad to None is faster than optimizer.zero_grad() for param_group in itertools.chain( embedding_optimizer.param_groups, mlp_optimizer.param_groups): for param in param_group['params']: param.grad = None if FLAGS.amp: loss *= FLAGS.loss_scale with amp.scale_loss(loss, mlp_optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() if FLAGS.Adam_MLP_optimizer: scale_MLP_gradients(mlp_optimizer, world_size) mlp_optimizer.step() if FLAGS.Adam_embedding_optimizer: scale_embeddings_gradients(embedding_optimizer, world_size) embedding_optimizer.step() moving_loss_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(moving_loss_stream): moving_loss += loss if timer.measured is None: # first iteration, no step time etc. to print continue if step == 0: print(f"Started epoch {epoch}...") elif step % print_freq == 0: torch.cuda.current_stream().wait_stream(moving_loss_stream) # Averaging across a print_freq period to reduce the error. # An accurate timing needs synchronize which would slow things down. 
if global_step < FLAGS.benchmark_warmup_steps: metric_logger.update( loss=moving_loss.item() / print_freq / (FLAGS.loss_scale if FLAGS.amp else 1), lr=mlp_optimizer.param_groups[0]["lr"] * (FLAGS.loss_scale if FLAGS.amp else 1)) else: metric_logger.update( step_time=timer.measured, loss=moving_loss.item() / print_freq / (FLAGS.loss_scale if FLAGS.amp else 1), lr=mlp_optimizer.param_groups[0]["lr"] * (FLAGS.loss_scale if FLAGS.amp else 1)) stop_time = time() eta_str = datetime.timedelta( seconds=int(metric_logger.step_time.global_avg * (steps_per_epoch - step))) metric_logger.print( header= f"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}] eta: {eta_str}" ) with torch.cuda.stream(moving_loss_stream): moving_loss = 0. if global_step % test_freq == 0 and global_step > 0 and global_step / steps_per_epoch >= FLAGS.test_after: auc = dist_evaluate(model, data_loader_test) if auc is None: continue print(f"Epoch {epoch} step {step}. auc {auc:.6f}") stop_time = time() if auc > best_auc: best_auc = auc best_epoch = epoch + ((step + 1) / steps_per_epoch) if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold: run_time_s = int(stop_time - start_time) print( f"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch " f"{global_step / steps_per_epoch:.2f} in {run_time_s}s. " f"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s." ) sys.exit() epoch_stop_time = time() epoch_time_s = epoch_stop_time - epoch_start_time print( f"Finished epoch {epoch} in {datetime.timedelta(seconds=int(epoch_time_s))}. " f"Average speed {steps_per_epoch * FLAGS.batch_size / epoch_time_s:.1f} records/s." ) avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg if FLAGS.save_checkpoint_path: checkpoint_writer.save_checkpoint(model, FLAGS.save_checkpoint_path, epoch, step) results = { 'best_auc': best_auc, 'best_epoch': best_epoch, 'average_train_throughput': avg_throughput } dllogger.log(data=results, step=tuple())
random.seed(hparams.seed)

reader = Reader(hparams)
start = time.time()
logger.info("Loading data...")
reader.load_data("train")
end = time.time()
logger.info("Loaded. {} secs".format(end - start))

model = DST(hparams).cuda()
optimizer = Adam(model.parameters(), hparams.lr)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)
model = parallel.DistributedDataParallel(model)

# load saved model, optimizer
if hparams.save_path is not None:
    load(model, optimizer, hparams.save_path)
    torch.distributed.barrier()

train.max_iter = len(list(reader.make_batch(reader.train)))
validate.max_iter = len(list(reader.make_batch(reader.dev)))
train.warmup_steps = train.max_iter * hparams.max_epochs * hparams.warmup_steps
train.global_step = 0

max_joint_acc = 0
early_stop_count = hparams.early_stop_count
for epoch in range(hparams.max_epochs):
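# The snippet above follows the order apex expects: move the model to GPU,
# run amp.initialize on the model and optimizer, then wrap with apex DDP.
# A minimal, self-contained sketch of that ordering (toy model; a process
# group is assumed to be initialized already):
import torch
import torch.nn as nn
from apex import amp, parallel


def build_distributed_model(lr=1e-3):
    model = nn.Linear(16, 1).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # amp.initialize should run before DDP wrapping so the wrapper sees the
    # parameters (possibly cast for mixed precision) it will allreduce.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    model = parallel.DistributedDataParallel(model, delay_allreduce=True)
    return model, optimizer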
def train(self, args): # Reset amp if args.use_apex: from apex import amp amp.init(False) # Get dataloaders train_dataloader = ds_utils.get_dataloader(args, 'train') if not args.skip_test: test_dataloader = ds_utils.get_dataloader(args, 'test') model = runner = self.runner if args.use_half: runner.half() # Initialize optimizers, schedulers and apex opts = runner.get_optimizers(args) # Load pre-trained params for optimizers and schedulers (if needed) if args.which_epoch != 'none' and not args.init_experiment_dir: for net_name, opt in opts.items(): opt.load_state_dict(torch.load(self.checkpoints_dir / f'{args.which_epoch}_opt_{net_name}.pth', map_location='cpu')) if args.use_apex and args.num_gpus > 0 and args.num_gpus <= 8: # Enfornce apex mixed precision settings nets_list, opts_list = [], [] for net_name in sorted(opts.keys()): nets_list.append(runner.nets[net_name]) opts_list.append(opts[net_name]) loss_scale = float(args.amp_loss_scale) if args.amp_loss_scale != 'dynamic' else args.amp_loss_scale nets_list, opts_list = amp.initialize(nets_list, opts_list, opt_level=args.amp_opt_level, num_losses=1, loss_scale=loss_scale) # Unpack opts_list into optimizers for net_name, net, opt in zip(sorted(opts.keys()), nets_list, opts_list): runner.nets[net_name] = net opts[net_name] = opt if args.which_epoch != 'none' and not args.init_experiment_dir and os.path.exists(self.checkpoints_dir / f'{args.which_epoch}_amp.pth'): amp.load_state_dict(torch.load(self.checkpoints_dir / f'{args.which_epoch}_amp.pth', map_location='cpu')) # Initialize apex distributed data parallel wrapper if args.num_gpus > 1 and args.num_gpus <= 8: from apex import parallel model = parallel.DistributedDataParallel(runner, delay_allreduce=True) epoch_start = 1 if args.which_epoch == 'none' else int(args.which_epoch) + 1 # Initialize logging train_iter = epoch_start - 1 if args.visual_freq != -1: train_iter /= args.visual_freq logger = Logger(args, self.experiment_dir) logger.set_num_iter( train_iter=train_iter, test_iter=(epoch_start - 1) // args.test_freq) if args.debug and not args.use_apex: torch.autograd.set_detect_anomaly(True) total_iters = 1 for epoch in range(epoch_start, args.num_epochs + 1): if args.rank == 0: print('epoch %d' % epoch) # Train for one epoch model.train() time_start = time.time() # Shuffle the dataset before the epoch train_dataloader.dataset.shuffle() for i, data_dict in enumerate(train_dataloader, 1): # Prepare input data if args.num_gpus > 0 and args.num_gpus > 0: for key, value in data_dict.items(): data_dict[key] = value.cuda() # Convert inputs to FP16 if args.use_half: for key, value in data_dict.items(): data_dict[key] = value.half() output_logs = i == len(train_dataloader) if args.visual_freq != -1: output_logs = not (total_iters % args.visual_freq) output_visuals = output_logs and not args.no_disk_write_ops # Accumulate list of optimizers that will perform opt step for opt in opts.values(): opt.zero_grad() # Perform a forward pass if not args.use_closure: loss = model(data_dict) closure = None if args.use_apex and args.num_gpus > 0 and args.num_gpus <= 8: # Mixed precision requires a special wrapper for the loss with amp.scale_loss(loss, opts.values()) as scaled_loss: scaled_loss.backward() elif not args.use_closure: loss.backward() else: def closure(): loss = model(data_dict) loss.backward() return loss # Perform steps for all optimizers for opt in opts.values(): opt.step(closure) if output_logs: logger.output_logs('train', runner.output_visuals(), runner.output_losses(), time.time() - 
time_start) if args.debug: break if args.visual_freq != -1: total_iters += 1 total_iters %= args.visual_freq # Increment the epoch counter in the training dataset train_dataloader.dataset.epoch += 1 # If testing is not required -- continue if epoch % args.test_freq: continue # If skip test flag is set -- only check if a checkpoint if required if not args.skip_test: # Calculate "standing" stats for the batch normalization if args.calc_stats: runner.calculate_batchnorm_stats(train_dataloader, args.debug) # Test time_start = time.time() model.eval() for data_dict in test_dataloader: # Prepare input data if args.num_gpus > 0: for key, value in data_dict.items(): data_dict[key] = value.cuda() # Forward pass with torch.no_grad(): model(data_dict) if args.debug: break # Output logs logger.output_logs('test', runner.output_visuals(), runner.output_losses(), time.time() - time_start) # If creation of checkpoint is not required -- continue if epoch % args.checkpoint_freq and not args.debug: continue # Create or load a checkpoint if args.rank == 0 and not args.no_disk_write_ops: with torch.no_grad(): for net_name in runner.nets_names_to_train: # Save a network torch.save(runner.nets[net_name].state_dict(), self.checkpoints_dir / f'{epoch}_{net_name}.pth') # Save an optimizer torch.save(opts[net_name].state_dict(), self.checkpoints_dir / f'{epoch}_opt_{net_name}.pth') # Save amp if args.use_apex: torch.save(amp.state_dict(), self.checkpoints_dir / f'{epoch}_amp.pth') return runner
def main(argv): if FLAGS.seed is not None: torch.manual_seed(FLAGS.seed) np.random.seed(FLAGS.seed) # Initialize distributed mode use_gpu = "cpu" not in FLAGS.device.lower() rank, world_size, gpu = dist.init_distributed_mode(backend=FLAGS.backend, use_gpu=use_gpu) if world_size == 1: raise NotImplementedError( "This file is only for distributed training.") mlperf_logger.mlperf_submission_log('dlrm') mlperf_logger.log_event(key=mlperf_logger.constants.SEED, value=FLAGS.seed) mlperf_logger.log_event(key=mlperf_logger.constants.GLOBAL_BATCH_SIZE, value=FLAGS.batch_size) # Only print cmd args on rank 0 if rank == 0: print("Command line flags:") pprint(FLAGS.flag_values_dict()) # Check arguments sanity if FLAGS.batch_size % world_size != 0: raise ValueError( F"Batch size {FLAGS.batch_size} is not divisible by world_size {world_size}." ) if FLAGS.test_batch_size % world_size != 0: raise ValueError( F"Test batch size {FLAGS.test_batch_size} is not divisible by world_size {world_size}." ) # Load config file, create sub config for each rank with open(FLAGS.model_config, "r") as f: config = json.loads(f.read()) wolrd_categorical_feature_sizes = np.asarray( config.pop('categorical_feature_sizes')) device_mapping = dist_model.get_criteo_device_mapping(world_size) vectors_per_gpu = device_mapping['vectors_per_gpu'] # Get sizes of embeddings each GPU is gonna create categorical_feature_sizes = wolrd_categorical_feature_sizes[ device_mapping['embedding'][rank]].tolist() bottom_mlp_sizes = config.pop('bottom_mlp_sizes') if rank != device_mapping['bottom_mlp']: bottom_mlp_sizes = None model = dist_model.DistDlrm( categorical_feature_sizes=categorical_feature_sizes, bottom_mlp_sizes=bottom_mlp_sizes, world_num_categorical_features=len(wolrd_categorical_feature_sizes), **config, device=FLAGS.device, use_embedding_ext=FLAGS.use_embedding_ext) print(model) dist.setup_distributed_print(rank == 0) # DDP introduces a gradient average through allreduce(mean), which doesn't apply to bottom model. # Compensate it with further scaling lr scaled_lr = FLAGS.lr / FLAGS.loss_scale if FLAGS.fp16 else FLAGS.lr scaled_lrs = [scaled_lr / world_size, scaled_lr] embedding_optimizer = torch.optim.SGD([ { 'params': model.bottom_model.joint_embedding.parameters(), 'lr': scaled_lrs[0] }, ]) mlp_optimizer = apex_optim.FusedSGD([{ 'params': model.bottom_model.bottom_mlp.parameters(), 'lr': scaled_lrs[0] }, { 'params': model.top_model.parameters(), 'lr': scaled_lrs[1] }]) if FLAGS.fp16: (model.top_model, model.bottom_model.bottom_mlp), mlp_optimizer = amp.initialize( [model.top_model, model.bottom_model.bottom_mlp], mlp_optimizer, opt_level="O2", loss_scale=1, cast_model_outputs=torch.float16) if use_gpu: model.top_model = parallel.DistributedDataParallel(model.top_model) else: # Use other backend for CPU model.top_model = torch.nn.parallel.DistributedDataParallel( model.top_model) loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean") # Too many arguments to pass for distributed training. 
Use plain train code here instead of # defining a train function # Print per 16384 * 2000 samples by default default_print_freq = 16384 * 2000 // FLAGS.batch_size print_freq = default_print_freq if FLAGS.print_freq is None else FLAGS.print_freq metric_logger = utils.MetricLogger(delimiter=" ") metric_logger.add_meter( 'loss', utils.SmoothedValue(window_size=1, fmt='{avg:.4f}')) metric_logger.add_meter( 'step_time', utils.SmoothedValue(window_size=1, fmt='{avg:.4f} ms')) metric_logger.add_meter( 'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) # Accumulating loss on GPU to avoid memcpyD2H every step moving_loss = torch.zeros(1, device=FLAGS.device) moving_loss_stream = torch.cuda.Stream() local_embedding_device_mapping = torch.tensor( device_mapping['embedding'][rank], device=FLAGS.device, dtype=torch.long) # LR is logged twice for now because of a compliance checker bug mlperf_logger.log_event(key=mlperf_logger.constants.OPT_BASE_LR, value=FLAGS.lr) mlperf_logger.log_event(key=mlperf_logger.constants.OPT_LR_WARMUP_STEPS, value=FLAGS.warmup_steps) # use logging keys from the official HP table and not from the logging library mlperf_logger.log_event(key='sgd_opt_base_learning_rate', value=FLAGS.lr) mlperf_logger.log_event(key='lr_decay_start_steps', value=FLAGS.decay_start_step) mlperf_logger.log_event(key='sgd_opt_learning_rate_decay_steps', value=FLAGS.decay_steps) mlperf_logger.log_event(key='sgd_opt_learning_rate_decay_poly_power', value=FLAGS.decay_power) lr_scheduler = utils.LearningRateScheduler( optimizers=[mlp_optimizer, embedding_optimizer], base_lrs=[scaled_lrs, [scaled_lrs[0]]], warmup_steps=FLAGS.warmup_steps, warmup_factor=FLAGS.warmup_factor, decay_start_step=FLAGS.decay_start_step, decay_steps=FLAGS.decay_steps, decay_power=FLAGS.decay_power, end_lr_factor=FLAGS.decay_end_lr / FLAGS.lr) data_stream = torch.cuda.Stream() eval_data_cache = [] if FLAGS.cache_eval_data else None start_time = time() stop_time = time() print("Creating data loaders") dist_dataset_args = { "numerical_features": rank == 0, "categorical_features": device_mapping['embedding'][rank] } mlperf_logger.barrier() mlperf_logger.log_end(key=mlperf_logger.constants.INIT_STOP) mlperf_logger.barrier() mlperf_logger.log_start(key=mlperf_logger.constants.RUN_START) mlperf_logger.barrier() data_loader_train, data_loader_test = dataset.get_data_loader( FLAGS.dataset, FLAGS.batch_size, FLAGS.test_batch_size, FLAGS.device, dataset_type=FLAGS.dataset_type, shuffle=FLAGS.shuffle, **dist_dataset_args) steps_per_epoch = len(data_loader_train) # Default 20 tests per epoch test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch // 20 for epoch in range(FLAGS.epochs): epoch_start_time = time() mlperf_logger.barrier() mlperf_logger.log_start(key=mlperf_logger.constants.BLOCK_START, metadata={ mlperf_logger.constants.FIRST_EPOCH_NUM: epoch + 1, mlperf_logger.constants.EPOCH_COUNT: 1 }) mlperf_logger.barrier() mlperf_logger.log_start( key=mlperf_logger.constants.EPOCH_START, metadata={mlperf_logger.constants.EPOCH_NUM: epoch + 1}) if FLAGS.profile_steps is not None: torch.cuda.profiler.start() for step, (numerical_features, categorical_features, click) in enumerate( dataset.prefetcher(iter(data_loader_train), data_stream)): torch.cuda.current_stream().wait_stream(data_stream) global_step = steps_per_epoch * epoch + step lr_scheduler.step() # Slice out categorical features if not using the "dist" dataset if FLAGS.dataset_type != "dist": categorical_features = categorical_features[:, 
local_embedding_device_mapping] if FLAGS.fp16 and categorical_features is not None: numerical_features = numerical_features.to(torch.float16) last_batch_size = None if click.shape[0] != FLAGS.batch_size: # last batch last_batch_size = click.shape[0] logging.debug("Pad the last batch of size %d to %d", last_batch_size, FLAGS.batch_size) padding_size = FLAGS.batch_size - last_batch_size padding_numiercal = torch.empty( padding_size, numerical_features.shape[1], device=numerical_features.device, dtype=numerical_features.dtype) numerical_features = torch.cat( (numerical_features, padding_numiercal), dim=0) if categorical_features is not None: padding_categorical = torch.ones( padding_size, categorical_features.shape[1], device=categorical_features.device, dtype=categorical_features.dtype) categorical_features = torch.cat( (categorical_features, padding_categorical), dim=0) padding_click = torch.empty(padding_size, device=click.device, dtype=click.dtype) click = torch.cat((click, padding_click)) bottom_out = model.bottom_model(numerical_features, categorical_features) batch_size_per_gpu = FLAGS.batch_size // world_size from_bottom = dist_model.bottom_to_top(bottom_out, batch_size_per_gpu, config['embedding_dim'], vectors_per_gpu) if last_batch_size is not None: partial_rank = math.ceil(last_batch_size / batch_size_per_gpu) if rank == partial_rank: top_out = model.top_model( from_bottom[:last_batch_size % batch_size_per_gpu]).squeeze().float() loss = loss_fn( top_out, click[rank * batch_size_per_gpu:(rank + 1) * batch_size_per_gpu][:last_batch_size % batch_size_per_gpu]) elif rank < partial_rank: loss = loss_fn( model.top_model(from_bottom).squeeze().float(), click[rank * batch_size_per_gpu:(rank + 1) * batch_size_per_gpu]) else: # Back propgate nothing for padded samples loss = 0. * model.top_model( from_bottom).squeeze().float().mean() else: loss = loss_fn( model.top_model(from_bottom).squeeze().float(), click[rank * batch_size_per_gpu:(rank + 1) * batch_size_per_gpu]) # We don't need to accumulate gradient. Set grad to None is faster than optimizer.zero_grad() for param_group in itertools.chain( embedding_optimizer.param_groups, mlp_optimizer.param_groups): for param in param_group['params']: param.grad = None if FLAGS.fp16: loss *= FLAGS.loss_scale with amp.scale_loss(loss, mlp_optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() mlp_optimizer.step() embedding_optimizer.step() moving_loss_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(moving_loss_stream): moving_loss += loss if step == 0: print(F"Started epoch {epoch}...") elif step % print_freq == 0: torch.cuda.synchronize() # Averaging cross a print_freq period to reduce the error. # An accurate timing needs synchronize which would slow things down. metric_logger.update(step_time=(time() - stop_time) * 1000 / print_freq, loss=moving_loss.item() / print_freq / (FLAGS.loss_scale if FLAGS.fp16 else 1), lr=mlp_optimizer.param_groups[1]["lr"] * (FLAGS.loss_scale if FLAGS.fp16 else 1)) stop_time = time() eta_str = datetime.timedelta( seconds=int(metric_logger.step_time.avg / 1000 * (steps_per_epoch - step))) metric_logger.print( header= F"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}] eta: {eta_str}" ) moving_loss = 0. with torch.cuda.stream(moving_loss_stream): moving_loss = 0. 
if global_step % test_freq == 0 and global_step > 0 and global_step / steps_per_epoch >= FLAGS.test_after: mlperf_epoch_index = global_step / steps_per_epoch + 1 mlperf_logger.barrier() mlperf_logger.log_start(key=mlperf_logger.constants.EVAL_START, metadata={ mlperf_logger.constants.EPOCH_NUM: mlperf_epoch_index }) auc = dist_evaluate(model, data_loader_test, eval_data_cache) mlperf_logger.log_event( key=mlperf_logger.constants.EVAL_ACCURACY, value=float(auc), metadata={ mlperf_logger.constants.EPOCH_NUM: mlperf_epoch_index }) print(F"Epoch {epoch} step {step}. auc {auc:.6f}") stop_time = time() mlperf_logger.barrier() mlperf_logger.log_end(key=mlperf_logger.constants.EVAL_STOP, metadata={ mlperf_logger.constants.EPOCH_NUM: mlperf_epoch_index }) if auc > FLAGS.auc_threshold: mlperf_logger.barrier() mlperf_logger.log_end(key=mlperf_logger.constants.RUN_STOP, metadata={ mlperf_logger.constants.STATUS: mlperf_logger.constants.SUCCESS }) mlperf_logger.barrier() mlperf_logger.log_end( key=mlperf_logger.constants.EPOCH_STOP, metadata={ mlperf_logger.constants.EPOCH_NUM: epoch + 1 }) mlperf_logger.barrier() mlperf_logger.log_end( key=mlperf_logger.constants.BLOCK_STOP, metadata={ mlperf_logger.constants.FIRST_EPOCH_NUM: epoch + 1 }) run_time_s = int(stop_time - start_time) print( F"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch " F"{global_step/steps_per_epoch:.2f} in {run_time_s}s. " F"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s." ) return if FLAGS.profile_steps is not None and global_step == FLAGS.profile_steps: torch.cuda.profiler.stop() logging.warning("Profile run, stopped at step %d.", global_step) return mlperf_logger.barrier() mlperf_logger.log_end( key=mlperf_logger.constants.EPOCH_STOP, metadata={mlperf_logger.constants.EPOCH_NUM: epoch + 1}) mlperf_logger.barrier() mlperf_logger.log_end( key=mlperf_logger.constants.BLOCK_STOP, metadata={mlperf_logger.constants.FIRST_EPOCH_NUM: epoch + 1}) epoch_stop_time = time() epoch_time_s = epoch_stop_time - epoch_start_time print( F"Finished epoch {epoch} in {datetime.timedelta(seconds=int(epoch_time_s))}. " F"Average speed {steps_per_epoch * FLAGS.batch_size / epoch_time_s:.1f} records/s." ) mlperf_logger.barrier() mlperf_logger.log_end(key=mlperf_logger.constants.RUN_STOP, metadata={ mlperf_logger.constants.STATUS: mlperf_logger.constants.ABORTED })
def train(cfg): num_gpus = torch.cuda.device_count() if num_gpus > 1: torch.distributed.init_process_group(backend="nccl", world_size=num_gpus) # set logger log_dir = os.path.join("logs/", cfg.source_dataset, cfg.prefix) if not os.path.isdir(log_dir): os.makedirs(log_dir, exist_ok=True) logging.basicConfig(format="%(asctime)s %(message)s", filename=log_dir + "/" + "log.txt", filemode="a") logger = logging.getLogger() logger.setLevel(logging.INFO) stream_handler = logging.StreamHandler() stream_handler.setLevel(logging.INFO) logger.addHandler(stream_handler) # writer = SummaryWriter(log_dir, purge_step=0) if dist.is_initialized() and dist.get_rank() != 0: logger = writer = None else: logger.info(pprint.pformat(cfg)) # training data loader if not cfg.joint_training: # single domain train_loader = get_train_loader(root=os.path.join( cfg.source.root, cfg.source.train), batch_size=cfg.batch_size, image_size=cfg.image_size, random_flip=cfg.random_flip, random_crop=cfg.random_crop, random_erase=cfg.random_erase, color_jitter=cfg.color_jitter, padding=cfg.padding, num_workers=4) else: # cross domain source_root = os.path.join(cfg.source.root, cfg.source.train) target_root = os.path.join(cfg.target.root, cfg.target.train) train_loader = get_cross_domain_train_loader( source_root=source_root, target_root=target_root, batch_size=cfg.batch_size, random_flip=cfg.random_flip, random_crop=cfg.random_crop, random_erase=cfg.random_erase, color_jitter=cfg.color_jitter, padding=cfg.padding, image_size=cfg.image_size, num_workers=8) # evaluation data loader query_loader = None gallery_loader = None if cfg.eval_interval > 0: query_loader = get_test_loader(root=os.path.join( cfg.target.root, cfg.target.query), batch_size=512, image_size=cfg.image_size, num_workers=4) gallery_loader = get_test_loader(root=os.path.join( cfg.target.root, cfg.target.gallery), batch_size=512, image_size=cfg.image_size, num_workers=4) # model num_classes = cfg.source.num_id num_cam = cfg.source.num_cam + cfg.target.num_cam cam_ids = train_loader.dataset.target_dataset.cam_ids if cfg.joint_training else train_loader.dataset.cam_ids num_instances = len( train_loader.dataset.target_dataset) if cfg.joint_training else None model = Model(num_classes=num_classes, drop_last_stride=cfg.drop_last_stride, joint_training=cfg.joint_training, num_instances=num_instances, cam_ids=cam_ids, num_cam=num_cam, neighbor_mode=cfg.neighbor_mode, neighbor_eps=cfg.neighbor_eps, scale=cfg.scale, mix=cfg.mix, alpha=cfg.alpha) model.cuda() # optimizer ft_params = model.backbone.parameters() new_params = [ param for name, param in model.named_parameters() if not name.startswith("backbone.") ] param_groups = [{ 'params': ft_params, 'lr': cfg.ft_lr }, { 'params': new_params, 'lr': cfg.new_params_lr }] optimizer = optim.SGD(param_groups, momentum=0.9, weight_decay=cfg.wd) # convert model for mixed precision distributed training model, optimizer = amp.initialize(model, optimizer, enabled=cfg.fp16, opt_level="O2") lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=cfg.lr_step, gamma=0.1) if dist.is_initialized(): model = parallel.DistributedDataParallel(model, delay_allreduce=True) # engine checkpoint_dir = os.path.join("checkpoints", cfg.source_dataset, cfg.prefix) engine = get_trainer( model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, logger=logger, # writer=writer, non_blocking=True, log_period=cfg.log_period, save_interval=10, save_dir=checkpoint_dir, prefix=cfg.prefix, eval_interval=cfg.eval_interval, 
query_loader=query_loader, gallery_loader=gallery_loader) # training engine.run(train_loader, max_epochs=cfg.num_epoch) if dist.is_initialized(): dist.destroy_process_group()
def main(): distributed.init_process_group(backend='nccl') with open(args.config) as file: config = Dict(json.load(file)) config.update(vars(args)) config.update( dict(world_size=distributed.get_world_size(), global_rank=distributed.get_rank(), device_count=torch.cuda.device_count())) config = apply(Dict, config) print(f'config: {config}') torch.manual_seed(0) torch.cuda.set_device(config.local_rank) generator = models.Generator(linear_params=[ Dict(in_features=44, out_features=128), *[Dict(in_features=128, out_features=128)] * 8, Dict(in_features=128, out_features=1) ]).cuda() discriminator = models.Discriminator( conv_params=[ Dict(in_channels=1, out_channels=32, kernel_size=3, stride=2, bias=False), Dict(in_channels=32, out_channels=64, kernel_size=3, stride=2, bias=False) ], linear_param=Dict(in_features=64, out_features=11)).cuda() generator_optimizer = torch.optim.Adam(params=generator.parameters(), lr=config.generator_lr, betas=(config.generator_beta1, config.generator_beta2)) discriminator_optimizer = torch.optim.Adam( params=discriminator.parameters(), lr=config.discriminator_lr, betas=(config.discriminator_beta1, config.discriminator_beta2)) [generator, discriminator ], [generator_optimizer, discriminator_optimizer] = amp.initialize( models=[generator, discriminator], optimizers=[generator_optimizer, discriminator_optimizer], opt_level=config.opt_level) generator = parallel.DistributedDataParallel(generator, delay_allreduce=True) discriminator = parallel.DistributedDataParallel(discriminator, delay_allreduce=True) epoch = 0 global_step = 0 if config.checkpoint: checkpoint = Dict(torch.load(config.checkpoint), map_location=lambda storage, location: storage.cuda( config.local_rank)) generator.load_state_dict(checkpoint.generator_state_dict) generator_optimizer.load_state_dict( checkpoint.generator_optimizer_state_dict) discriminator.load_state_dict(checkpoint.discriminator_state_dict) discriminator_optimizer.load_state_dict( checkpoint.discriminator_optimizer_state_dict) epoch = checkpoint.last_epoch + 1 global_step = checkpoint.last_global_step + 1 if config.global_rank == 0: os.makedirs(config.checkpoint_directory, exist_ok=True) os.makedirs(config.event_directory, exist_ok=True) summary_writer = SummaryWriter(config.event_directory) if config.train: dataset = datasets.MNIST(root='mnist', train=True, download=True, transform=transforms.Compose( [transforms.ToTensor()])) distributed_sampler = utils.data.distributed.DistributedSampler( dataset) data_loader = utils.data.DataLoader(dataset=dataset, batch_size=config.local_batch_size, num_workers=config.num_workers, sampler=distributed_sampler, pin_memory=True, drop_last=True) for epoch in range(epoch, config.num_epochs): discriminator.train() distributed_sampler.set_epoch(epoch) for step, (real_images, real_labels) in enumerate(data_loader): real_images = real_images.cuda() real_labels = real_labels.cuda() labels = nn.functional.embedding(real_labels, torch.eye(10, device='cuda')) labels = labels.repeat(1, config.image_size**2).reshape(-1, 10) latents = torch.randn(config.local_batch_size, 32, device='cuda') latents = latents.repeat(1, config.image_size**2).reshape(-1, 32) y = torch.linspace(-1, 1, config.image_size, device='cuda') x = torch.linspace(-1, 1, config.image_size, device='cuda') y, x = torch.meshgrid(y, x) positions = torch.stack((y.reshape(-1), x.reshape(-1)), dim=1) positions = positions.repeat(config.local_batch_size, 1) fake_images = generator( torch.cat((labels, latents, positions), dim=1)) fake_images = 
fake_images.reshape(-1, 1, config.image_size, config.image_size) real_logits = discriminator(real_images, real_labels) real_adversarial_logits, real_classification_logits = torch.split( real_logits, [1, 10], dim=1) fake_logits = discriminator(fake_images.detach(), real_labels) fake_adversarial_logits, fake_classification_logits = torch.split( fake_logits, [1, 10], dim=1) discriminator_loss = torch.mean( nn.functional.softplus(-real_adversarial_logits)) discriminator_loss += torch.mean( nn.functional.softplus(fake_adversarial_logits)) discriminator_loss += nn.functional.cross_entropy( real_classification_logits, real_labels) discriminator_loss += nn.functional.cross_entropy( fake_classification_logits, real_labels) discriminator_optimizer.zero_grad() with amp.scale_loss( discriminator_loss, discriminator_optimizer) as scaled_discriminator_loss: scaled_discriminator_loss.backward() discriminator_optimizer.step() fake_logits = discriminator(fake_images, real_labels) fake_adversarial_logits, fake_classification_logits = torch.split( fake_logits, [1, 10], dim=1) generator_loss = torch.mean( nn.functional.softplus(-fake_adversarial_logits)) generator_loss += nn.functional.cross_entropy( fake_classification_logits, real_labels) generator_optimizer.zero_grad() with amp.scale_loss( generator_loss, generator_optimizer) as scaled_generator_loss: scaled_generator_loss.backward() generator_optimizer.step() global_step += 1 if step % 100 == 0 and config.global_rank == 0: summary_writer.add_images(tag='real_images', img_tensor=real_images.repeat( 1, 3, 1, 1), global_step=global_step) summary_writer.add_images(tag='fake_images', img_tensor=fake_images.repeat( 1, 3, 1, 1), global_step=global_step) summary_writer.add_scalars( main_tag='training', tag_scalar_dict=dict( generator_loss=generator_loss, discriminator_loss=discriminator_loss, global_step=global_step)) print( f'[training] epoch: {epoch} step: {step} generator_loss: {generator_loss:.4f} discriminator_loss: {discriminator_loss:.4f}' ) torch.save( dict( generator_state_dict=generator.state_dict(), generator_optimizer_state_dict=generator_optimizer. state_dict(), discriminator_state_dict=discriminator.state_dict(), discriminator_optimizer_state_dict=discriminator_optimizer. state_dict(), last_epoch=epoch, last_global_step=global_step), f'{config.checkpoint_directory}/epoch_{epoch}') if config.generate: with torch.no_grad(): labels = torch.multinomial(torch.ones(10, device='cuda'), num_samples=1) labels = nn.functional.embedding(labels, torch.eye(10, device='cuda')) labels = labels.repeat(1, config.image_size**2).reshape(-1, 10) latents = torch.randn(1, 32, device='cuda') latents = latents.repeat(1, config.image_size**2).reshape(-1, 32) y = torch.linspace(-1, 1, config.image_size, device='cuda') x = torch.linspace(-1, 1, config.image_size, device='cuda') y, x = torch.meshgrid(y, x) positions = torch.stack((y.reshape(-1), x.reshape(-1)), dim=1) positions = positions.repeat(1, 1) images = generator(torch.cat((labels, latents, positions), dim=1)) images = images.reshape(-1, config.image_size, config.image_size) for i, image in enumerate(images.cpu().numpy()): io.imsave(f"{i}.jpg", image) summary_writer.close()
# (excerpt: this fragment starts inside the branch that loads existing
#  generator/discriminator checkpoints)
    generator.load_state_dict(g_checkpoint['model_state_dict'], strict=False)
    discriminator.load_state_dict(d_checkpoint['model_state_dict'], strict=False)
    step = g_checkpoint['step']
    alpha = g_checkpoint['alpha']
    iteration = g_checkpoint['iteration']
    print('pre-trained model is loaded step:%d, iteration:%d' % (step, iteration))
else:
    iteration = 0
    step = 1

if args.distributed:
    generator = parallel.DistributedDataParallel(generator)
    discriminator = parallel.DistributedDataParallel(discriminator)
    vgg = parallel.DistributedDataParallel(vgg)
    face_align_net = parallel.DistributedDataParallel(
        torch.load('./checkpoints/compressed_model_011000.pth',
                   map_location=lambda storage, loc: storage.cuda(
                       args.local_rank)).to(device))
else:
    if len(args.gpu_ids) > 1:
        generator = nn.DataParallel(generator, args.gpu_ids)
        discriminator = nn.DataParallel(discriminator, args.gpu_ids)
        vgg = nn.DataParallel(vgg, args.gpu_ids)
        face_align_net = nn.DataParallel(
            torch.load('./checkpoints/compressed_model_011000.pth').to(
                device), args.gpu_ids)
    else:
# (excerpt: this fragment starts inside an `if args.distributed:` branch)
    args.world_size = torch.distributed.get_world_size()
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True,
                            sampler=train_sampler)
else:
    # load the dataset using the DataLoader class (part of torch.utils.data)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)

# instantiate Generator(nn.Module) and move it to CPU/GPU
generator = Generator().to(device)

## DEFINE CHECKPOINT
# Checkpoints save the model parameters during training; here we only test
# the pre-trained model, so we load it with torch.load.
if args.distributed:
    g_checkpoint = torch.load(args.checkpoint_path,
                              map_location=lambda storage, loc: storage.cuda(args.local_rank))
    generator = parallel.DistributedDataParallel(generator)
    generator = parallel.convert_syncbn_model(generator)
else:
    g_checkpoint = torch.load(args.checkpoint_path)

generator.load_state_dict(g_checkpoint['model_state_dict'], strict=False)
step = g_checkpoint['step']
alpha = g_checkpoint['alpha']
iteration = g_checkpoint['iteration']
print('pre-trained model is loaded step:%d, alpha:%d iteration:%d' % (step, alpha, iteration))

MSE_Loss = nn.MSELoss()

# notify all layers that you are in eval mode instead of training mode
generator.eval()
test(dataloader, generator, MSE_Loss, step, alpha)
def main():
    global n_eval_epoch

    ## dataloader
    dataset_train = ImageNet(datapth, mode='train', cropsize=cropsize)
    sampler_train = torch.utils.data.distributed.DistributedSampler(
        dataset_train, shuffle=True)
    batch_sampler_train = torch.utils.data.sampler.BatchSampler(
        sampler_train, batchsize, drop_last=True
    )
    dl_train = DataLoader(
        dataset_train, batch_sampler=batch_sampler_train,
        num_workers=num_workers, pin_memory=True
    )
    dataset_eval = ImageNet(datapth, mode='val', cropsize=cropsize)
    sampler_val = torch.utils.data.distributed.DistributedSampler(
        dataset_eval, shuffle=False)
    batch_sampler_val = torch.utils.data.sampler.BatchSampler(
        sampler_val, batchsize * 2, drop_last=False
    )
    dl_eval = DataLoader(
        dataset_eval, batch_sampler=batch_sampler_val,
        num_workers=4, pin_memory=True
    )
    n_iters_per_epoch = len(dataset_train) // n_gpus // batchsize
    n_iters = n_epoches * n_iters_per_epoch

    ## model
    # model = EfficientNet(model_type, n_classes)
    model = build_model(**model_args)
    ## sync bn
    # if use_sync_bn: model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    init_model_weights(model)
    model.cuda()
    if use_sync_bn:
        model = parallel.convert_syncbn_model(model)
    crit = nn.CrossEntropyLoss()
    # crit = LabelSmoothSoftmaxCEV3(lb_smooth)
    # crit = SoftmaxCrossEntropyV2()

    ## optimizer
    optim = set_optimizer(model, lr, opt_wd, momentum, nesterov=nesterov)

    ## apex
    model, optim = amp.initialize(model, optim, opt_level=fp16_level)

    ## ema
    ema = EMA(model, ema_alpha)

    ## ddp training
    model = parallel.DistributedDataParallel(model, delay_allreduce=True)
    # local_rank = dist.get_rank()
    # model = nn.parallel.DistributedDataParallel(
    #     model, device_ids=[local_rank, ], output_device=local_rank
    # )

    ## log meters
    time_meter = TimeMeter(n_iters)
    loss_meter = AvgMeter()
    logger = logging.getLogger()

    # for mixup
    label_encoder = OnehotEncoder(n_classes=model_args['n_classes'], lb_smooth=lb_smooth)
    mixuper = MixUper(mixup_alpha, mixup=mixup)

    ## train loop
    for e in range(n_epoches):
        sampler_train.set_epoch(e)
        model.train()
        for idx, (im, lb) in enumerate(dl_train):
            im, lb = im.cuda(), lb.cuda()
            # lb = label_encoder(lb)
            # im, lb = mixuper(im, lb)
            optim.zero_grad()
            logits = model(im)
            loss = crit(logits, lb)  # + cal_l2_loss(model, weight_decay)
            # loss.backward()
            with amp.scale_loss(loss, optim) as scaled_loss:
                scaled_loss.backward()
            optim.step()
            torch.cuda.synchronize()
            ema.update_params()

            time_meter.update()
            loss_meter.update(loss.item())
            if (idx + 1) % 200 == 0:
                t_intv, eta = time_meter.get()
                lr_log = scheduler.get_lr_ratio() * lr
                msg = 'epoch: {}, iter: {}, lr: {:.4f}, loss: {:.4f}, time: {:.2f}, eta: {}'.format(
                    e + 1, idx + 1, lr_log, loss_meter.get()[0], t_intv, eta)
                logger.info(msg)
            scheduler.step()
        torch.cuda.empty_cache()

        if (e + 1) % n_eval_epoch == 0:
            if e > 50:
                n_eval_epoch = 5
            logger.info('evaluating...')
            acc_1, acc_5, acc_1_ema, acc_5_ema = evaluate(ema, dl_eval)
            msg = 'epoch: {}, naive_acc1: {:.4}, naive_acc5: {:.4}, ema_acc1: {:.4}, ema_acc5: {:.4}'.format(
                e + 1, acc_1, acc_5, acc_1_ema, acc_5_ema)
            logger.info(msg)

    if dist.is_initialized() and dist.get_rank() == 0:
        torch.save(model.module.state_dict(), './res/model_final.pth')
        torch.save(ema.ema_model.state_dict(), './res/model_final_ema.pth')
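# EMA in the example above is project-specific. A minimal sketch of what such a
# parameter EMA typically does (hypothetical helper, not the original code):
import copy

import torch


class SimpleEMA:
    def __init__(self, model, alpha=0.999):
        self.alpha = alpha
        self.model = model
        self.ema_model = copy.deepcopy(model).eval()
        for p in self.ema_model.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update_params(self):
        # ema = alpha * ema + (1 - alpha) * current
        for ema_p, p in zip(self.ema_model.parameters(), self.model.parameters()):
            ema_p.mul_(self.alpha).add_(p.detach(), alpha=1.0 - self.alpha)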
def main(): with open(args.config) as file: config = Dict(json.load(file)) distributed.init_process_group(backend='nccl') world_size = distributed.get_world_size() global_rank = distributed.get_rank() device_count = torch.cuda.device_count() local_rank = args.local_rank torch.cuda.set_device(local_rank) print( f'Enabled distributed training. (global_rank: {global_rank}/{world_size}, local_rank: {local_rank}/{device_count})' ) torch.manual_seed(0) model = models.resnet50() model.fc = nn.Linear(in_features=2048, out_features=10, bias=True) model = model.cuda() config.global_batch_size = config.local_batch_size * world_size config.lr = config.base_lr * config.global_batch_size / 256 optimizer = torch.optim.SGD(params=model.parameters(), lr=config.lr, momentum=config.momentum, weight_decay=config.weight_decay) model, optimizer = amp.initialize(model, optimizer, opt_level=config.opt_level) model = parallel.DistributedDataParallel(model, delay_allreduce=True) last_epoch = -1 if args.checkpoint: checkpoint = Dict( torch.load(args.checkpoint), map_location=lambda storage, location: storage.cuda(local_rank)) model.load_state_dict(checkpoint.model_state_dict) optimizer.load_state_dict(checkpoint.optimizer_state_dict) last_epoch = checkpoint.last_epoch criterion = nn.CrossEntropyLoss(reduction='mean').cuda() scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=config.lr_milestones, gamma=config.lr_gamma, last_epoch=last_epoch) summary_writer = SummaryWriter(config.event_directory) if args.training: os.makedirs(config.checkpoint_directory, exist_ok=True) os.makedirs(config.event_directory, exist_ok=True) # NOTE: When partition for distributed training executed? # NOTE: Should random seed be the same in the same node? train_pipeline = TrainPipeline(root=config.train_root, batch_size=config.local_batch_size, num_threads=config.num_workers, device_id=local_rank, num_shards=world_size, shard_id=global_rank, image_size=224) train_pipeline.build() # NOTE: What's `epoch_size`? # NOTE: Is that len(dataset) ? 
train_data_loader = pytorch.DALIClassificationIterator( pipelines=train_pipeline, size=list(train_pipeline.epoch_size().values())[0] // world_size, auto_reset=True, stop_at_epoch=True) val_pipeline = ValPipeline(root=config.val_root, batch_size=config.local_batch_size, num_threads=config.num_workers, device_id=local_rank, num_shards=world_size, shard_id=global_rank, image_size=224) val_pipeline.build() val_data_loader = pytorch.DALIClassificationIterator( pipelines=val_pipeline, size=list(val_pipeline.epoch_size().values())[0] // world_size, auto_reset=True, stop_at_epoch=True) for epoch in range(last_epoch + 1, config.num_epochs): model.train() scheduler.step() for step, data in enumerate(train_data_loader): images = data[0]["data"] labels = data[0]["label"] images = images.cuda() labels = labels.cuda() labels = labels.squeeze().long() logits = model(images) loss = criterion(logits, labels) optimizer.zero_grad() with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() optimizer.step() if global_rank == 0: summary_writer.add_scalars(main_tag='training', tag_scalar_dict=dict(loss=loss)) print( f'[training] epoch: {epoch} step: {step} loss: {loss}') torch.save( dict(model_state_dict=model.state_dict(), optimizer_state_dict=optimizer.state_dict(), last_epoch=epoch), f'{config.checkpoint_directory}/epoch_{epoch}') model.eval() total_steps = 0 total_loss = 0 total_accurtacy = 0 with torch.no_grad(): for step, data in enumerate(val_data_loader): images = data[0]["data"] labels = data[0]["label"] images = images.cuda() labels = labels.cuda() labels = labels.squeeze().long() logits = model(images) loss = criterion(logits, labels) / world_size distributed.all_reduce(loss) predictions = logits.topk(1)[1].squeeze() accuracy = torch.mean( (predictions == labels).float()) / world_size distributed.all_reduce(accuracy) total_steps += 1 total_loss += loss total_accurtacy += accuracy loss = total_loss / total_steps accuracy = total_accurtacy / total_steps if global_rank == 0: summary_writer.add_scalars(main_tag='validation', tag_scalar_dict=dict( loss=loss, accuracy=accuracy)) print( f'[validation] epoch: {epoch} loss: {loss} accuracy: {accuracy}' ) if args.evaluation: test_pipeline = ValPipeline(root=config.val_root, batch_size=config.local_batch_size, num_threads=config.num_workers, device_id=local_rank, num_shards=world_size, shard_id=global_rank, image_size=224) test_pipeline.build() test_data_loader = pytorch.DALIClassificationIterator( pipelines=test_pipeline, size=list(test_pipeline.epoch_size().values())[0] // world_size, auto_reset=True, stop_at_epoch=True) model.eval() total_steps = 0 total_loss = 0 total_accurtacy = 0 with torch.no_grad(): for step, data in enumerate(val_data_loader): images = data[0]["data"] labels = data[0]["label"] images = images.cuda() labels = labels.cuda() labels = labels.squeeze().long() logits = model(images) loss = criterion(logits, labels) / world_size distributed.all_reduce(loss) predictions = logits.topk(1)[1].squeeze() accuracy = torch.mean( (predictions == labels).float()) / world_size distributed.all_reduce(accuracy) total_steps += 1 total_loss += loss total_accurtacy += accuracy loss = total_loss / total_steps accuracy = total_accurtacy / total_steps if global_rank == 0: summary_writer.add_scalars(main_tag='validation', tag_scalar_dict=dict(loss=loss, accuracy=accuracy)) print( f'[evaluation] epoch: {last_epoch} loss: {loss} accuracy: {accuracy}' ) summary_writer.close()