def validate(epoch, calib_loader, val_loader, criterion, val_meters,
             model_wrapper, ema, phase):
    """Calibrate and validate."""
    assert phase in ['test', 'val']
    model_eval_wrapper = get_ema_model(ema, model_wrapper)

    # bn_calibration
    if FLAGS.get('bn_calibration', False):
        if not FLAGS.use_distributed:
            logging.warning(
                'Only GPU0 is used for calibration when using DataParallel')
        with torch.no_grad():
            _ = run_one_epoch(epoch,
                              calib_loader,
                              model_eval_wrapper,
                              criterion,
                              None,
                              None,
                              None,
                              val_meters,
                              max_iter=FLAGS.bn_calibration_steps,
                              phase='bn_calibration')
        if FLAGS.use_distributed:
            udist.allreduce_bn(model_eval_wrapper)

    # val
    with torch.no_grad():
        results = run_one_epoch(epoch,
                                val_loader,
                                model_eval_wrapper,
                                criterion,
                                None,
                                None,
                                None,
                                val_meters,
                                phase=phase)
    summary_bn(model_eval_wrapper, phase)
    return results, model_eval_wrapper
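# `run_one_epoch` applies a `bn_calibration` hook via `model.apply(...)` in the
# 'bn_calibration' phase. The hook is defined elsewhere in this repo; a minimal
# sketch of what such a hook typically does (assuming standard torch.nn
# BatchNorm layers) is:
import torch.nn as nn


def bn_calibration_sketch(m):
    """Hypothetical stand-in for the repo's bn_calibration hook."""
    if isinstance(m, nn.modules.batchnorm._BatchNorm):
        # Forget the statistics accumulated during training ...
        m.reset_running_stats()
        # ... and put only the BN layer in training mode so the calibration
        # forward passes re-estimate running_mean/running_var while the rest
        # of the model stays in eval mode. Setting `m.momentum = None` instead
        # would accumulate a cumulative average over the calibration batches.
        m.train()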
def run_one_epoch(epoch,
                  loader,
                  model,
                  criterion,
                  optimizer,
                  lr_scheduler,
                  ema,
                  meters,
                  max_iter=None,
                  phase='train'):
    """Run one epoch."""
    assert phase in ['train', 'val', 'test', 'bn_calibration'
                    ], "phase must be one of train/val/test/bn_calibration."
    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()
        if phase == 'bn_calibration':
            model.apply(bn_calibration)

    if FLAGS.use_distributed:
        loader.sampler.set_epoch(epoch)

    results = None
    data_iterator = iter(loader)
    if FLAGS.use_distributed:
        data_fetcher = dataflow.DataPrefetcher(data_iterator)
    else:
        # TODO(meijieru): prefetch for non distributed
        logging.warning('Not using prefetcher')
        data_fetcher = data_iterator
    for batch_idx, (input, target) in enumerate(data_fetcher):
        # used for bn calibration
        if max_iter is not None:
            assert phase == 'bn_calibration'
            if batch_idx >= max_iter:
                break

        target = target.cuda(non_blocking=True)
        if train:
            optimizer.zero_grad()
            loss = mc.forward_loss(model, criterion, input, target, meters)
            loss_l2 = optim.cal_l2_loss(model, FLAGS.weight_decay,
                                        FLAGS.weight_decay_method)
            loss = loss + loss_l2
            loss.backward()
            if FLAGS.use_distributed:
                udist.allreduce_grads(model)

            if FLAGS._global_step % FLAGS.log_interval == 0:
                results = mc.reduce_and_flush_meters(meters)
                if udist.is_master():
                    logging.info('Epoch {}/{} Iter {}/{} {}: '.format(
                        epoch, FLAGS.num_epochs, batch_idx, len(loader), phase)
                                 + ', '.join('{}: {:.4f}'.format(k, v)
                                             for k, v in results.items()))
                    for k, v in results.items():
                        mc.summary_writer.add_scalar('{}/{}'.format(phase, k),
                                                     v, FLAGS._global_step)

            if udist.is_master(
            ) and FLAGS._global_step % FLAGS.log_interval == 0:
                mc.summary_writer.add_scalar('train/learning_rate',
                                             optimizer.param_groups[0]['lr'],
                                             FLAGS._global_step)
                mc.summary_writer.add_scalar('train/l2_regularize_loss',
                                             extract_item(loss_l2),
                                             FLAGS._global_step)
                mc.summary_writer.add_scalar(
                    'train/current_epoch',
                    FLAGS._global_step / FLAGS._steps_per_epoch,
                    FLAGS._global_step)
                if FLAGS.data_loader_workers > 0:
                    mc.summary_writer.add_scalar(
                        'data/train/prefetch_size',
                        get_data_queue_size(data_iterator), FLAGS._global_step)

            optimizer.step()
            lr_scheduler.step()
            if FLAGS.use_distributed and FLAGS.allreduce_bn:
                udist.allreduce_bn(model)
            FLAGS._global_step += 1

            # NOTE: after steps count update
            if ema is not None:
                model_unwrap = mc.unwrap_model(model)
                ema_names = ema.average_names()
                params = get_params_by_name(model_unwrap, ema_names)
                for name, param in zip(ema_names, params):
                    ema(name, param, FLAGS._global_step)
        else:
            mc.forward_loss(model, criterion, input, target, meters)

    if not train:
        results = mc.reduce_and_flush_meters(meters)
        if udist.is_master():
            logging.info(
                'Epoch {}/{} {}: '.format(epoch, FLAGS.num_epochs, phase)
                + ', '.join('{}: {:.4f}'.format(k, v)
                            for k, v in results.items()))
            for k, v in results.items():
                mc.summary_writer.add_scalar('{}/{}'.format(phase, k), v,
                                             FLAGS._global_step)
    return results
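# For context, a hypothetical driver loop over the two functions above; the
# loaders, meters, optimizer, and `last_epoch` are assumed to come from the
# surrounding training script.
for epoch in range(last_epoch + 1, FLAGS.num_epochs):
    # One full training pass.
    run_one_epoch(epoch, train_loader, model_wrapper, criterion, optimizer,
                  lr_scheduler, ema, train_meters, phase='train')
    # Periodic evaluation, with BN re-calibration handled inside validate().
    results, _ = validate(epoch, calib_loader, val_loader, criterion,
                          val_meters, model_wrapper, ema, phase='val')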
def validate(epoch, calib_loader, val_loader, criterion, val_meters,
             model_wrapper, ema, phase, segval=None, val_set=None):
    """Calibrate and validate."""
    assert phase in ['test', 'val']
    model_eval_wrapper = mc.get_ema_model(ema, model_wrapper)

    # bn_calibration
    if FLAGS.prune_params['method'] is not None:
        if FLAGS.get('bn_calibration', False):
            if not FLAGS.use_distributed:
                logging.warning(
                    'Only GPU0 is used for calibration when using'
                    ' DataParallel')
            with torch.no_grad():
                _ = run_one_epoch(epoch,
                                  calib_loader,
                                  model_eval_wrapper,
                                  criterion,
                                  None,
                                  None,
                                  None,
                                  None,
                                  val_meters,
                                  max_iter=FLAGS.bn_calibration_steps,
                                  phase='bn_calibration')
            if FLAGS.use_distributed:
                udist.allreduce_bn(model_eval_wrapper)

    # val
    with torch.no_grad():
        if FLAGS.model_kwparams.task == 'segmentation':
            if FLAGS.dataset == 'coco':
                results = 0
                if udist.is_master():
                    results = keypoint_val(val_set, val_loader,
                                           model_eval_wrapper.module,
                                           criterion)
            else:
                assert segval is not None
                results = segval.run(
                    epoch, val_loader, model_eval_wrapper.module
                    if FLAGS.single_gpu_test else model_eval_wrapper, FLAGS)
        else:
            results = run_one_epoch(epoch,
                                    val_loader,
                                    model_eval_wrapper,
                                    criterion,
                                    None,
                                    None,
                                    None,
                                    None,
                                    val_meters,
                                    phase=phase)
    summary_bn(model_eval_wrapper, phase)
    return results, model_eval_wrapper
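# Both versions of `run_one_epoch` push weights into the EMA via
# `ema(name, param, step)` after each optimizer step, and `validate` evaluates
# the averaged copy through `get_ema_model`. A minimal sketch of an EMA
# accumulator with that calling convention (the repo's actual class may
# differ, e.g. in how it warms up the decay):
import torch


class ExponentialMovingAverageSketch:
    """Hypothetical EMA accumulator matching the ema(name, param, step) call."""

    def __init__(self, decay=0.999):
        self.decay = decay
        self.shadow = {}

    def register(self, name, param):
        self.shadow[name] = param.detach().clone()

    def average_names(self):
        return list(self.shadow.keys())

    def average(self, name):
        return self.shadow[name]

    def __call__(self, name, param, step):
        # Ramp the decay up early in training, a common stabilization trick.
        decay = min(self.decay, (1 + step) / (10 + step))
        with torch.no_grad():
            self.shadow[name].mul_(decay).add_(param.detach(),
                                               alpha=1 - decay)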
def run_one_epoch(epoch,
                  loader,
                  model,
                  criterion,
                  optimizer,
                  lr_scheduler,
                  ema,
                  rho_scheduler,
                  meters,
                  max_iter=None,
                  phase='train'):
    """Run one epoch."""
    assert phase in [
        'train', 'val', 'test', 'bn_calibration'
    ] or phase.startswith(
        'prune'), "phase must be one of train/val/test/bn_calibration/prune*."
    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()
        if phase == 'bn_calibration':
            model.apply(bn_calibration)

    if not FLAGS.use_hdfs:
        if FLAGS.use_distributed:
            loader.sampler.set_epoch(epoch)

    results = None
    data_iterator = iter(loader)
    if not FLAGS.use_hdfs:
        if FLAGS.use_distributed:
            if FLAGS.dataset == 'coco':
                data_fetcher = dataflow.DataPrefetcherKeypoint(data_iterator)
            else:
                data_fetcher = dataflow.DataPrefetcher(data_iterator)
        else:
            logging.warning('Not using prefetcher')
            data_fetcher = data_iterator

    for batch_idx, data in enumerate(data_fetcher):
        if FLAGS.dataset == 'coco':
            # e.g. input (4, 3, 384, 288), target (4, 17, 96, 72),
            # target_weight (4, 17, 1)
            input, target, target_weight, meta = data
        else:
            input, target = data
        # used for bn calibration
        if max_iter is not None:
            assert phase == 'bn_calibration'
            if batch_idx >= max_iter:
                break

        target = target.cuda(non_blocking=True)
        if train:
            optimizer.zero_grad()
            rho = rho_scheduler(FLAGS._global_step)

            if FLAGS.dataset == 'coco':
                outputs = model(input)
                if isinstance(outputs, list):
                    loss = criterion(outputs[0], target, target_weight)
                    for output in outputs[1:]:
                        loss += criterion(output, target, target_weight)
                else:
                    output = outputs
                    loss = criterion(output, target, target_weight)
                _, avg_acc, cnt, pred = accuracy_keypoint(
                    output.detach().cpu().numpy(),
                    target.detach().cpu().numpy())  # cnt=17
                meters['acc'].cache(avg_acc)
                meters['loss'].cache(loss)
            else:
                loss = mc.forward_loss(model,
                                       criterion,
                                       input,
                                       target,
                                       meters,
                                       task=FLAGS.model_kwparams.task,
                                       distill=FLAGS.distill)

            if FLAGS.prune_params['method'] is not None:
                loss_l2 = optim.cal_l2_loss(
                    model, FLAGS.weight_decay,
                    FLAGS.weight_decay_method)  # manual weight decay
                loss_bn_l1 = prune.cal_bn_l1_loss(get_prune_weights(model),
                                                  FLAGS._bn_to_prune.penalty,
                                                  rho)
                if FLAGS.prune_params.use_transformer:
                    transformer_weights = get_prune_weights(model, True)
                    loss_bn_l1 += prune.cal_bn_l1_loss(
                        transformer_weights,
                        FLAGS._bn_to_prune_transformer.penalty, rho)

                    transformer_dict = []
                    for name, weight in zip(
                            FLAGS._bn_to_prune_transformer.weight,
                            transformer_weights):
                        transformer_dict.append(
                            sum(weight > FLAGS.model_shrink_threshold).item())
                    FLAGS._bn_to_prune_transformer.add_info_list(
                        'channels', transformer_dict)
                    FLAGS._bn_to_prune_transformer.update_penalty()
                    if udist.is_master(
                    ) and FLAGS._global_step % FLAGS.log_interval == 0:
                        logging.info(transformer_dict)
                        # logging.info(FLAGS._bn_to_prune_transformer.penalty)

                meters['loss_l2'].cache(loss_l2)
                meters['loss_bn_l1'].cache(loss_bn_l1)
                loss = loss + loss_l2 + loss_bn_l1
            loss.backward()
            if FLAGS.use_distributed:
                udist.allreduce_grads(model)

            if FLAGS._global_step % FLAGS.log_interval == 0:
                results = mc.reduce_and_flush_meters(meters)
                if udist.is_master():
                    logging.info('Epoch {}/{} Iter {}/{} Lr: {} {}: '.format(
                        epoch, FLAGS.num_epochs, batch_idx, len(loader),
                        optimizer.param_groups[0]['lr'], phase)
                                 + ', '.join('{}: {:.4f}'.format(k, v)
                                             for k, v in results.items()))
                    for k, v in results.items():
                        mc.summary_writer.add_scalar('{}/{}'.format(phase, k),
                                                     v, FLAGS._global_step)

            if udist.is_master(
            ) and FLAGS._global_step % FLAGS.log_interval == 0:
                mc.summary_writer.add_scalar('train/learning_rate',
                                             optimizer.param_groups[0]['lr'],
                                             FLAGS._global_step)
                if FLAGS.prune_params['method'] is not None:
                    mc.summary_writer.add_scalar('train/l2_regularize_loss',
                                                 extract_item(loss_l2),
                                                 FLAGS._global_step)
                    mc.summary_writer.add_scalar('train/bn_l1_loss',
                                                 extract_item(loss_bn_l1),
                                                 FLAGS._global_step)
                mc.summary_writer.add_scalar('prune/rho', rho,
                                             FLAGS._global_step)
                mc.summary_writer.add_scalar(
                    'train/current_epoch',
                    FLAGS._global_step / FLAGS._steps_per_epoch,
                    FLAGS._global_step)
                if FLAGS.data_loader_workers > 0:
                    mc.summary_writer.add_scalar(
                        'data/train/prefetch_size',
                        get_data_queue_size(data_iterator), FLAGS._global_step)

            if udist.is_master(
            ) and FLAGS._global_step % FLAGS.log_interval_detail == 0:
                summary_bn(model, 'train')

            optimizer.step()
            if FLAGS.lr_scheduler == 'poly':
                optim.poly_learning_rate(
                    optimizer, FLAGS.lr,
                    epoch * FLAGS._steps_per_epoch + batch_idx + 1,
                    FLAGS.num_epochs * FLAGS._steps_per_epoch)
            else:
                lr_scheduler.step()
            if FLAGS.use_distributed and FLAGS.allreduce_bn:
                udist.allreduce_bn(model)
            FLAGS._global_step += 1

            # NOTE: after steps count update
            if ema is not None:
                model_unwrap = mc.unwrap_model(model)
                ema_names = ema.average_names()
                params = get_params_by_name(model_unwrap, ema_names)
                for name, param in zip(ema_names, params):
                    ema(name, param, FLAGS._global_step)
        else:
            if FLAGS.dataset == 'coco':
                outputs = model(input)
                if isinstance(outputs, list):
                    loss = criterion(outputs[0], target, target_weight)
                    for output in outputs[1:]:
                        loss += criterion(output, target, target_weight)
                else:
                    output = outputs
                    loss = criterion(output, target, target_weight)
                _, avg_acc, cnt, pred = accuracy_keypoint(
                    output.detach().cpu().numpy(),
                    target.detach().cpu().numpy())  # cnt=17
                meters['acc'].cache(avg_acc)
                meters['loss'].cache(loss)
            else:
                mc.forward_loss(model,
                                criterion,
                                input,
                                target,
                                meters,
                                task=FLAGS.model_kwparams.task,
                                distill=False)

    if not train:
        results = mc.reduce_and_flush_meters(meters)
        if udist.is_master():
            logging.info(
                'Epoch {}/{} {}: '.format(epoch, FLAGS.num_epochs, phase)
                + ', '.join('{}: {:.4f}'.format(k, v)
                            for k, v in results.items()))
            for k, v in results.items():
                mc.summary_writer.add_scalar('{}/{}'.format(phase, k), v,
                                             FLAGS._global_step)
    return results
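# When FLAGS.lr_scheduler == 'poly', the loop above bypasses the PyTorch
# scheduler and sets the rate manually through `optim.poly_learning_rate`.
# That helper is not shown in this section; polynomial decay is conventionally
# lr = base_lr * (1 - curr_iter / max_iter) ** power, with power 0.9 the usual
# choice in segmentation. A sketch under that assumption:
def poly_learning_rate_sketch(optimizer, base_lr, curr_iter, max_iter,
                              power=0.9):
    """Hypothetical stand-in for optim.poly_learning_rate."""
    lr = base_lr * (1 - curr_iter / max_iter)**power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr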