def setup_distributed(num_images=None):
    """Set up distributed-related parameters."""
    # init distributed
    if FLAGS.use_distributed:
        udist.init_dist()
        FLAGS.batch_size = udist.get_world_size() * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.per_gpu_batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = \
                FLAGS.bn_calibration_per_gpu_batch_size
        # Split the data-loader workers across the GPUs on this node;
        # round() returns the nearest integer.
        FLAGS.data_loader_workers = round(
            FLAGS.data_loader_workers / udist.get_local_size())
    else:
        count = torch.cuda.device_count()
        FLAGS.batch_size = count * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = \
                FLAGS.bn_calibration_per_gpu_batch_size * count
    if hasattr(FLAGS, 'base_lr'):
        # Linear learning-rate scaling with the total batch size.
        FLAGS.lr = FLAGS.base_lr * (FLAGS.batch_size / FLAGS.base_total_batch)
    if num_images:
        # NOTE: the last batch is not dropped, so use math.ceil (the smallest
        # integer not less than x); otherwise the learning rate can go
        # negative at the end of the schedule.
        FLAGS._steps_per_epoch = math.ceil(num_images / FLAGS.batch_size)
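# A worked example of the scaling rules above; all numbers are illustrative,
# not taken from any real config.
import math

_total_batch = 8 * 64                     # 8 GPUs x 64 images per GPU
assert _total_batch == 512
_lr = 0.1 * (_total_batch / 256)          # linear LR scaling rule
assert _lr == 0.2
assert math.ceil(1281167 / _total_batch) == 2503  # ImageNet-1k steps/epoch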
def _train_one_batch(self, x, y, optimizer, lr_scheduler, meters, criterions,
                     end):
    top1_meter, top5_meter, loss_meter, data_time = meters
    criterion = criterions[0]
    world_size = dist.get_world_size()

    lr_scheduler.step(self.cur_step)
    self.cur_step += 1
    data_time.update(time.time() - end)

    self.model.zero_grad()
    out = self.model(x)
    loss = criterion(out, y)
    # Normalize by world size so the summed all-reduce yields an average.
    loss /= world_size
    top1, top5 = accuracy(out, y, top_k=(1, 5))

    reduced_loss = dist.all_reduce(loss.clone())
    reduced_top1 = dist.all_reduce(top1.clone(), div=True)
    reduced_top5 = dist.all_reduce(top5.clone(), div=True)

    loss_meter.update(reduced_loss.item())
    top1_meter.update(reduced_top1.item())
    top5_meter.update(reduced_top5.item())

    loss.backward()
    dist.average_gradient(self.model.parameters())
    optimizer.step()
def reduce_and_flush_meters(meters, method='avg'):
    """Sync and flush meters."""
    if not FLAGS.use_distributed:
        results = flush_scalar_meters(meters)
    else:
        results = {}
        assert isinstance(meters, dict), "meters should be a dict."
        # NOTE: Ensure same order, otherwise may deadlock
        for name in sorted(meters.keys()):
            meter = meters[name]
            if not isinstance(meter, ScalarMeter):
                continue
            if method == 'avg':
                method_fun = torch.mean
            elif method == 'sum':
                method_fun = torch.sum
            elif method == 'max':
                method_fun = torch.max
            elif method == 'min':
                method_fun = torch.min
            else:
                raise NotImplementedError(
                    'flush method: {} is not yet implemented.'.format(method))
            tensor = torch.tensor(meter.values).cuda()
            gather_tensors = [
                torch.ones_like(tensor)
                for _ in range(udist.get_world_size())
            ]
            dist.all_gather(gather_tensors, tensor)
            value = method_fun(torch.cat(gather_tensors))
            meter.flush(value)
            results[name] = value
    return results
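# Usage sketch (hypothetical meter names; ScalarMeter's constructor signature
# is assumed here and may differ in this repo). Because every rank iterates
# sorted(meters.keys()), the all_gather collectives are issued in the same
# order on all processes, which is what the deadlock NOTE above relies on.
meters = {'top1': ScalarMeter(), 'loss': ScalarMeter()}
# ... update meters over an epoch, then synchronize across ranks:
results = reduce_and_flush_meters(meters, method='avg')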
def reduce_tensor(inp):
    """Reduce the loss from all processes so that the process with rank 0
    has the averaged result."""
    world_size = dist.get_world_size()
    if world_size < 2:
        return inp
    with torch.no_grad():
        reduced_inp = inp
        # After reduce, only rank 0 is guaranteed to hold the summed value.
        torch.distributed.reduce(reduced_inp, dst=0)
    return reduced_inp / world_size
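# Minimal runnable sketch of reduce_tensor, assuming a single-process gloo
# group (so the world_size < 2 early return is exercised); with two or more
# ranks, only rank 0 holds the averaged value after the reduce.
import os
import torch
import torch.distributed as torch_dist

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
torch_dist.init_process_group('gloo', rank=0, world_size=1)
t = torch.tensor([4.0])
assert reduce_tensor(t) is t  # world_size == 1: returned unchanged
torch_dist.destroy_process_group()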
def main():
    """Entry."""
    # init distributed
    global is_root_rank
    if FLAGS.use_distributed:
        udist.init_dist()
        FLAGS.batch_size = udist.get_world_size() * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.per_gpu_batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = \
                FLAGS.bn_calibration_per_gpu_batch_size
        FLAGS.data_loader_workers = round(
            FLAGS.data_loader_workers / udist.get_local_size())
        is_root_rank = udist.is_master()
    else:
        count = torch.cuda.device_count()
        FLAGS.batch_size = count * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = \
                FLAGS.bn_calibration_per_gpu_batch_size * count
        is_root_rank = True
    FLAGS.lr = FLAGS.base_lr * (FLAGS.batch_size / FLAGS.base_total_batch)
    # NOTE: the last batch is not dropped, so use ceil; otherwise the
    # learning rate can go negative at the end of the schedule.
    FLAGS._steps_per_epoch = int(np.ceil(NUM_IMAGENET_TRAIN /
                                         FLAGS.batch_size))

    if is_root_rank:
        FLAGS.log_dir = '{}/{}'.format(FLAGS.log_dir,
                                       time.strftime("%Y%m%d-%H%M%S"))
        create_exp_dir(
            FLAGS.log_dir,
            FLAGS.config_path,
            blacklist_dirs=[
                'exp', '.git', 'pretrained', 'tmp', 'deprecated', 'bak',
            ],
        )
        setup_logging(FLAGS.log_dir)
        for k, v in _ENV_EXPAND.items():
            logging.info('Env var expand: {} to {}'.format(k, v))
        logging.info(FLAGS)

    set_random_seed(FLAGS.get('random_seed', 0))
    with SummaryWriterManager():
        train_val_test()
def build_data_loader():
    logger.info("build train dataset")
    # train_dataset
    train_dataset = TrainDataset()
    logger.info("build dataset done")

    train_sampler = None
    if get_world_size() > 1:
        train_sampler = DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=cfg.TRAIN.BATCH_SIZE,
                                  num_workers=cfg.TRAIN.NUM_WORKERS,
                                  pin_memory=True,
                                  sampler=train_sampler)
    return train_dataloader
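# Usage sketch (hypothetical training loop; cfg and DistributedSampler as
# above). torch's DistributedSampler only reshuffles each epoch when
# set_epoch is called, so a typical driver loop looks like:
train_dataloader = build_data_loader()
for epoch in range(cfg.TRAIN.EPOCHS):
    if isinstance(train_dataloader.sampler, DistributedSampler):
        train_dataloader.sampler.set_epoch(epoch)  # reshuffle across ranks
    for batch in train_dataloader:
        pass  # forward/backward step goes here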
def KineticsSounds(cfg, split):
    if split == 'train':
        max_idx = 19
    elif split == 'val':
        max_idx = 1
    elif split == 'test':
        max_idx = 2
    dataset_root = cfg.DATASET_ROOT
    if dataset_root.endswith('/'):
        dataset_root = dataset_root[:-1]
    url = (f"{dataset_root}/KineticsSounds/shards-{split}/"
           f"shard-{{000000..{max_idx:06d}}}.tar")
    if cfg.STORAGE_SAS_KEY:
        url += cfg.STORAGE_SAS_KEY
    _decoder = Decoder(cfg, "KineticsSounds", split)

    if split == 'train':
        batch_size = int(cfg.TRAIN.BATCH_SIZE /
                         cfg.SOLVER.GRADIENT_ACCUMULATION_STEPS)
        batch_size = int(batch_size / du.get_world_size())
        length = int(cfg.TRAIN.DATASET_SIZE / du.get_world_size())
        nominal = int(length / batch_size)
    elif split == 'val':
        batch_size = int(cfg.TRAIN.BATCH_SIZE / du.get_world_size())
        length = int(cfg.VAL.DATASET_SIZE / du.get_world_size())
        nominal = int(length / batch_size)
    elif split == 'test':
        batch_size = int(cfg.TEST.BATCH_SIZE / du.get_world_size())
        length = math.ceil(cfg.TEST.DATASET_SIZE / du.get_world_size())
        nominal = math.ceil(length / batch_size)

    # Swap webdataset's batched filter for one that collates with the
    # kinetics collate function.
    wds.filters.batched = wds.filters.Curried(
        partial(wds.filters.batched_, collation_fn=COLLATE_FN["kinetics"]))
    dataset = wds.Dataset(
        url,
        handler=wds.warn_and_continue,
        shard_selection=du.shard_selection,
        length=length,
    )
    if split == 'train':
        dataset = dataset.shuffle(100)
    dataset = (dataset.map_dict(
        handler=wds.warn_and_continue,
        mp4=_decoder.mp4decode,
        json=_decoder.jsondecode,
    ))
    if cfg.DATA_LOADER.NUM_WORKERS > 0:
        length = nominal
    else:
        nominal = length
    dataset = wds.ResizedDataset(
        dataset,
        length=length,
        nominal=nominal,
    )
    return dataset
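# Worked example of the shard URL pattern above (the root path here is
# illustrative, not a real storage location):
_root, _split, _max_idx = '/data', 'train', 19
_url = (f"{_root}/KineticsSounds/shards-{_split}/"
        f"shard-{{000000..{_max_idx:06d}}}.tar")
assert _url == '/data/KineticsSounds/shards-train/shard-{000000..000019}.tar'
# webdataset expands the {000000..000019} brace range to the 20 shard files
# shard-000000.tar ... shard-000019.tar.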
def check_dist_init(config, logger):
    # check distributed initialization
    if config.distributed.enable:
        import os

        # for slurm
        try:
            node_id = int(os.environ['SLURM_NODEID'])
        except KeyError:
            return
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        gpu_id = dist.gpu_id
        logger.info('World: {}/Node: {}/Rank: {}/GpuId: {} initialized.'
                    .format(world_size, node_id, rank, gpu_id))
def _train_one_batch(self, x, y, optimizer, lr_scheduler, meters, criterions,
                     end):
    top1_meter, top5_meter, loss_meter, data_time = meters
    criterion, distill_loss = criterions
    world_size = dist.get_world_size()
    max_width = self.config.training.sandwich.max_width

    lr_scheduler.step(self.cur_step)
    self.cur_step += 1
    data_time.update(time.time() - end)

    self.model.zero_grad()
    max_pred = None
    for idx in range(self.config.training.sandwich.num_sample):
        # sandwich rule
        top1_m, top5_m, loss_m = self._set_width(idx, top1_meter,
                                                 top5_meter, loss_meter)
        out = self.model(x)

        if self.config.training.distillation.enable:
            if idx == 0:
                # The max-width output serves as the soft target for the
                # narrower widths.
                max_pred = out.detach()
                loss = criterion(out, y)
            else:
                loss = self.config.training.distillation.loss_weight * \
                    distill_loss(out, max_pred)
                if self.config.training.distillation.hard_label:
                    loss += criterion(out, y)
        else:
            loss = criterion(out, y)
        loss /= world_size

        top1, top5 = accuracy(out, y, top_k=(1, 5))
        reduced_loss = dist.all_reduce(loss.clone())
        reduced_top1 = dist.all_reduce(top1.clone(), div=True)
        reduced_top5 = dist.all_reduce(top5.clone(), div=True)

        loss_m.update(reduced_loss.item())
        top1_m.update(reduced_top1.item())
        top5_m.update(reduced_top5.item())
        # Gradients accumulate across the sampled widths.
        loss.backward()

    dist.average_gradient(self.model.parameters())
    optimizer.step()
def _train_one_batch(self, x, y, optimizer, lr_scheduler, meters, criterions,
                     end):
    lr_scheduler, arch_lr_scheduler = lr_scheduler
    optimizer, arch_optimizer = optimizer
    top1_meter, top5_meter, loss_meter, arch_loss_meter, \
        floss_meter, eflops_meter, arch_top1_meter, data_time = meters
    criterion, _ = criterions

    self.model.module.set_alpha_training(False)
    super(DMCPRunner, self)._train_one_batch(
        x, y, optimizer, lr_scheduler,
        [top1_meter, top5_meter, loss_meter, data_time], criterions, end)

    arch_lr_scheduler.step(self.cur_step)
    world_size = dist.get_world_size()
    # train architecture params
    if self.cur_step >= self.config.arch.start_train \
            and self.cur_step % self.config.arch.train_freq == 0:
        self._set_width(0, top1_meter, top5_meter, loss_meter)
        self.model.module.set_alpha_training(True)

        self.model.zero_grad()
        arch_out = self.model(x)
        arch_loss = criterion(arch_out, y)
        arch_loss /= world_size
        floss, eflops = flop_loss(self.config, self.model)
        floss /= world_size

        arch_top1 = accuracy(arch_out, y, top_k=(1,))[0]
        reduced_arch_loss = dist.all_reduce(arch_loss.clone())
        reduced_floss = dist.all_reduce(floss.clone())
        reduced_eflops = dist.all_reduce(eflops.clone(), div=True)
        reduced_arch_top1 = dist.all_reduce(arch_top1.clone(), div=True)

        arch_loss_meter.update(reduced_arch_loss.item())
        floss_meter.update(reduced_floss.item())
        eflops_meter.update(reduced_eflops.item())
        arch_top1_meter.update(reduced_arch_top1.item())

        floss.backward()
        arch_loss.backward()
        dist.average_gradient(self.model.module.arch_parameters())
        arch_optimizer.step()
def setup_logging(output_dir=None):
    """
    Sets up the logging for multiple processes. Only enable the logging for
    the master process, and suppress logging for the non-master processes.
    """
    # Set up logging format.
    _FORMAT = "[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s"
    if du.is_master_proc():
        # Enable logging for the master process.
        logging.root.handlers = []
        logging.basicConfig(level=logging.INFO, format=_FORMAT,
                            stream=sys.stdout)
    else:
        # Suppress logging for non-master processes.
        _suppress_print()

    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    logger.propagate = False
    plain_formatter = logging.Formatter(
        "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s",
        datefmt="%m/%d %H:%M:%S",
    )

    if du.is_master_proc():
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        ch.setFormatter(plain_formatter)
        logger.addHandler(ch)

    if output_dir is not None and du.is_master_proc(du.get_world_size()):
        filename = os.path.join(output_dir, "stdout.log")
        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)
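# Usage sketch (du and _cached_log_stream come from this repo; the output
# directory is illustrative). On the master process this logs to stdout and
# to <output_dir>/stdout.log, while non-master processes stay silent.
setup_logging(output_dir='./exp/run0')
logging.getLogger(__name__).info("training started")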
def train(train_dataloader, model, optimizer, lr_scheduler):
    def is_valid_number(x):
        return not (math.isnan(x) or math.isinf(x) or x > 1e4)

    logger.info("model\n{}".format(describe(model.module)))
    tb_writer = SummaryWriter(cfg.TRAIN.LOG_DIR)
    average_meter = AverageMeter()
    start_epoch = cfg.TRAIN.START_EPOCH
    world_size = get_world_size()
    num_per_epoch = len(train_dataloader.dataset) // \
        (cfg.TRAIN.BATCH_SIZE * world_size)
    iter = 0
    if not os.path.exists(cfg.TRAIN.SNAPSHOT_DIR) and get_rank() == 0:
        os.makedirs(cfg.TRAIN.SNAPSHOT_DIR)

    for epoch in range(cfg.TRAIN.START_EPOCH, cfg.TRAIN.EPOCHS):
        if cfg.BACKBONE.TRAIN_EPOCH == epoch:
            logger.info('begin to train backbone!')
            optimizer, lr_scheduler = build_optimizer_lr(model.module, epoch)
            logger.info("model\n{}".format(describe(model.module)))
        train_dataloader.dataset.shuffle()
        lr_scheduler.step(epoch)
        # log for lr
        if get_rank() == 0:
            for idx, pg in enumerate(optimizer.param_groups):
                tb_writer.add_scalar('lr/group{}'.format(idx + 1),
                                     pg['lr'], iter)
        cur_lr = lr_scheduler.get_cur_lr()

        for data in train_dataloader:
            begin = time.time()
            examplar_img = data['examplar_img'].cuda()
            search_img = data['search_img'].cuda()
            gt_cls = data['gt_cls'].cuda()
            gt_delta = data['gt_delta'].cuda()
            delta_weight = data['delta_weight'].cuda()
            data_time = time.time() - begin

            losses = model.forward(examplar_img, search_img, gt_cls,
                                   gt_delta, delta_weight)
            cls_loss = losses['cls_loss']
            loc_loss = losses['loc_loss']
            loss = losses['total_loss']

            if is_valid_number(loss.item()):
                optimizer.zero_grad()
                loss.backward()
                reduce_gradients(model)
                if get_rank() == 0 and cfg.TRAIN.LOG_GRAD:
                    log_grads(model.module, tb_writer, iter)
                clip_grad_norm_(model.parameters(), cfg.TRAIN.GRAD_CLIP)
                optimizer.step()

            batch_time = time.time() - begin
            batch_info = {}
            batch_info['data_time'] = average_reduce(data_time)
            batch_info['batch_time'] = average_reduce(batch_time)
            for k, v in losses.items():
                batch_info[k] = average_reduce(v)
            average_meter.update(**batch_info)

            if get_rank() == 0:
                for k, v in batch_info.items():
                    tb_writer.add_scalar(k, v, iter)
                if iter % cfg.TRAIN.PRINT_EVERY == 0:
                    logger.info(
                        'epoch: {}, iter: {}, cur_lr: {}, cls_loss: {}, '
                        'loc_loss: {}, loss: {}'.format(
                            epoch + 1, iter, cur_lr, cls_loss.item(),
                            loc_loss.item(), loss.item()))
                    print_speed(iter + 1 + start_epoch * num_per_epoch,
                                average_meter.batch_time.avg,
                                cfg.TRAIN.EPOCHS * num_per_epoch)
            iter += 1

        # save model
        if get_rank() == 0:
            state = {
                'model': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch + 1
            }
            logger.info('save snapshot to {}/checkpoint_e{}.pth'.format(
                cfg.TRAIN.SNAPSHOT_DIR, epoch + 1))
            torch.save(
                state,
                '{}/checkpoint_e{}.pth'.format(cfg.TRAIN.SNAPSHOT_DIR,
                                               epoch + 1))
def data_loader(train_set, val_set, test_set):
    """Get the train/val/test data loaders."""
    train_loader = None
    val_loader = None
    test_loader = None

    # infer batch size
    if getattr(FLAGS, 'batch_size', False):
        if getattr(FLAGS, 'batch_size_per_gpu', False):
            assert FLAGS.batch_size == (
                FLAGS.batch_size_per_gpu * FLAGS.num_gpus_per_job)
        else:
            assert FLAGS.batch_size % FLAGS.num_gpus_per_job == 0
            FLAGS.batch_size_per_gpu = (
                FLAGS.batch_size // FLAGS.num_gpus_per_job)
    elif getattr(FLAGS, 'batch_size_per_gpu', False):
        FLAGS.batch_size = FLAGS.batch_size_per_gpu * FLAGS.num_gpus_per_job
    else:
        raise ValueError('batch size (per gpu) is not defined')
    batch_size = int(FLAGS.batch_size / get_world_size())

    if FLAGS.data_loader == 'imagenet1k_basic':
        if getattr(FLAGS, 'distributed', False):
            if FLAGS.test_only:
                train_sampler = None
            else:
                train_sampler = DistributedSampler(train_set)
            val_sampler = DistributedSampler(val_set)
        else:
            train_sampler = None
            val_sampler = None
        if not FLAGS.test_only:
            train_loader = torch.utils.data.DataLoader(
                train_set,
                batch_size=batch_size,
                shuffle=(train_sampler is None),
                sampler=train_sampler,
                pin_memory=True,
                num_workers=FLAGS.data_loader_workers,
                drop_last=getattr(FLAGS, 'drop_last', False))
        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=batch_size,
            shuffle=False,
            sampler=val_sampler,
            pin_memory=True,
            num_workers=FLAGS.data_loader_workers,
            drop_last=getattr(FLAGS, 'drop_last', False))
        test_loader = val_loader
    else:
        try:
            data_loader_lib = importlib.import_module(FLAGS.data_loader)
            return data_loader_lib.data_loader(train_set, val_set, test_set)
        except ImportError:
            raise NotImplementedError(
                'Data loader {} is not yet implemented.'.format(
                    FLAGS.data_loader))

    if train_loader is not None:
        FLAGS.data_size_train = len(train_loader.dataset)
    if val_loader is not None:
        FLAGS.data_size_val = len(val_loader.dataset)
    if test_loader is not None:
        FLAGS.data_size_test = len(test_loader.dataset)
    return train_loader, val_loader, test_loader
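# Worked example of the batch-size inference above; the numbers are
# illustrative, not from any real config.
_batch_size, _num_gpus = 256, 8
assert _batch_size % _num_gpus == 0
_per_gpu = _batch_size // _num_gpus        # inferred batch_size_per_gpu
assert _per_gpu == 32
# With world_size == num GPUs, each process's DataLoader sees 32 samples.
assert int(_batch_size / _num_gpus) == 32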