def computeAPr(self):
    logging_rank('Evaluating AP^r')
    instance_par_pred_dir = os.path.join(self.pred_dir, 'instance_parsing')
    instance_par_gt_dir = self.gt_dir.replace('Images', 'Instance_ids')
    assert os.path.exists(instance_par_pred_dir)
    assert os.path.exists(instance_par_gt_dir)

    # img_name_list = [x[:-4] for x in os.listdir(instance_par_pred_dir) if x[-3:] == 'txt']
    # Widen the glob pattern one directory level at a time until the
    # per-image txt result files are found.
    tmp_instance_par_pred_dir = instance_par_pred_dir
    img_name_list = []
    while len(img_name_list) == 0:
        img_name_list = [
            x.replace(instance_par_pred_dir + '/', '')[:-4]
            for x in glob.glob(tmp_instance_par_pred_dir) if x[-3:] == 'txt'
        ]
        tmp_instance_par_pred_dir += '/*'

    APr = np.zeros((self.num_parsing - 1, len(self.par_thresholds)))
    with tqdm(total=self.num_parsing - 1) as pbar:
        pbar.set_description('Calculating AP^r ..')
        for class_id in range(1, self.num_parsing):
            APr[class_id - 1, :] = self._compute_class_apr(
                instance_par_gt_dir, instance_par_pred_dir, img_name_list, class_id)
            pbar.update(1)

    # Mean AP^r over classes, under each IoU threshold.
    mAPr = np.nanmean(APr, axis=0)
    all_APr = {}
    for i, thre in enumerate(self.par_thresholds):
        all_APr[thre] = mAPr[i]
    return all_APr

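# A toy demo (hypothetical directory layout) of the widening glob search in
# computeAPr above: each pass descends one more directory level until txt
# result files turn up.
import glob

pattern = 'pred/instance_parsing'
for _ in range(3):
    matches = [p for p in glob.glob(pattern) if p.endswith('txt')]
    print(pattern, '->', matches)
    pattern += '/*'
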
def convert_conv1_rgb2bgr(self, weights_dict):
    """Support caffe-trained models, including resnet50/101/152 and vgg16."""
    conv1_name = 'features1.0.weight' if 'vgg16_reducedfc' in self.weights_path else 'conv1.weight'
    # Swap the input-channel order of the first conv from RGB to BGR.
    weights_dict[conv1_name] = weights_dict[conv1_name][:, [2, 1, 0], :, :]
    logging_rank('Convert {} from RGB to BGR of {}'.format(conv1_name, weights_dict[conv1_name].shape))
    return weights_dict

def evaluate(self):
    logging_rank('Evaluating Semantic Segmentation predictions')
    hist = np.zeros((self.num_classes, self.num_classes))
    for i in tqdm(self.ids, desc='Calculating IoU ..'):
        image_name = self.dataset.coco.imgs[i]['file_name'].replace(self.name_trans[0], self.name_trans[1])
        if not (os.path.exists(os.path.join(self.gt_dir, image_name))
                and os.path.exists(os.path.join(self.pre_dir, image_name))):
            continue
        pre_png = cv2.imread(os.path.join(self.pre_dir, image_name), 0)
        gt_png = self.generate_gt_png(i, image_name, pre_png.shape)
        assert gt_png.shape == pre_png.shape, '{} VS {}'.format(str(gt_png.shape), str(pre_png.shape))
        gt = gt_png.flatten()
        pre = pre_png.flatten()
        hist += self.fast_hist(gt, pre)

    def mean_iou(overall_h):
        iu = np.diag(overall_h) / (overall_h.sum(1) + overall_h.sum(0) - np.diag(overall_h) + 1e-10)
        return iu, np.nanmean(iu)

    def per_class_acc(overall_h):
        acc = np.diag(overall_h) / (overall_h.sum(1) + 1e-10)
        return np.nanmean(acc)

    def pixel_wise_acc(overall_h):
        return np.diag(overall_h).sum() / overall_h.sum()

    iou, miou = mean_iou(hist)
    mean_acc = per_class_acc(hist)
    pixel_acc = pixel_wise_acc(hist)
    self.stats.update(dict(IoU=iou, mIoU=miou, MeanACC=mean_acc, PixelACC=pixel_acc))

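# fast_hist is called above but not shown; a minimal sketch, assuming integer
# label maps with values in [0, num_classes) (written standalone here; in the
# evaluator it would be a method reading self.num_classes):
import numpy as np

def fast_hist(gt, pre, num_classes):
    # Accumulate the (gt, pred) confusion matrix with a single bincount over
    # the combined index gt * num_classes + pred.
    k = (gt >= 0) & (gt < num_classes)
    return np.bincount(
        num_classes * gt[k].astype(int) + pre[k],
        minlength=num_classes ** 2,
    ).reshape(num_classes, num_classes)
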
def get_lr(self):
    """Update learning rate"""
    warmup_factor = self.get_warmup_factor(self.warmup_method, self.iteration, self.warmup_iters, self.warmup_factor)
    if self.solver.LR_POLICY == 'STEP':
        lr_factor = self.get_step_factor(warmup_factor)
    elif self.solver.LR_POLICY == 'COSINE':
        lr_factor = self.get_cosine_factor(warmup_factor)
    elif self.solver.LR_POLICY == 'STEP_COSINE':
        # Step decay until the last milestone, cosine decay afterwards.
        if self.iteration < self.milestones[-1]:
            lr_factor = self.get_step_factor(warmup_factor)
        else:
            lr_factor = self.get_cosine_factor(warmup_factor)
    elif self.solver.LR_POLICY == 'POLY':
        lr_factor = self.get_poly_factor(warmup_factor)
    else:
        raise KeyError('Unknown SOLVER.LR_POLICY: {}'.format(self.solver.LR_POLICY))

    ratio = _get_lr_change_ratio(lr_factor, self.lr_factor)
    if self.lr_factor != lr_factor and ratio > self.solver.LOG_LR_CHANGE_THRESHOLD:
        if lr_factor * self.solver.BASE_LR > 1e-7 and self.iteration > 1:
            logging_rank('Changing learning rate {:.6f} -> {:.6f}'.format(
                self.lr_factor * self.solver.BASE_LR, lr_factor * self.solver.BASE_LR))
    self.lr_factor = lr_factor
    self.iteration += 1
    return [lr_factor * base_lr for base_lr in self.base_lrs]

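# get_cosine_factor is not shown here; a minimal standalone sketch of the
# usual cosine annealing (the max_iter schedule length is an assumption about
# what the method reads from self.solver):
import math

def cosine_factor(iteration, max_iter, warmup_factor=1.0):
    # Anneal the LR factor from 1 to 0 over the run; during warmup the
    # warmup_factor (< 1) scales it further.
    progress = iteration / max_iter
    return warmup_factor * 0.5 * (1.0 + math.cos(math.pi * progress))
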
def align_and_update_state_dicts(model_state_dict, weights_dict, use_weights_once=False):
    """
    This function is taken from the maskrcnn_benchmark repo. It can be seen here:
    https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/utils/model_serialization.py

    Strategy: suppose that the models that we will create will have prefixes appended to each of
    its keys, for example due to an extra level of nesting that the original pre-trained weights
    from ImageNet won't contain. For example, model.state_dict() might return
    backbone[0].body.res2.conv1.weight, while the pre-trained model contains res2.conv1.weight.
    We thus want to match both parameters together. For that, we look for each model weight, look
    among all loaded keys if there is one that is a suffix of the current weight name, and use it
    if that's the case. If multiple matches exist, take the one with longest size of the
    corresponding name. For example, for the same model as before, the pretrained weight file can
    contain both res2.conv1.weight, as well as conv1.weight. In this case, we want to match
    backbone[0].body.conv1.weight to conv1.weight, and backbone[0].body.res2.conv1.weight to
    res2.conv1.weight.
    """
    model_keys = sorted(list(model_state_dict.keys()))
    weights_keys = sorted(list(weights_dict.keys()))
    # get a matrix of string matches, where each (i, j) entry corresponds to the size of the
    # loaded_key string, if it matches
    match_matrix = [len(j) if i.endswith(j) else 0 for i in model_keys for j in weights_keys]
    match_matrix = torch.as_tensor(match_matrix).view(len(model_keys), len(weights_keys))
    max_match_size, idxs = match_matrix.max(1)
    # remove indices that correspond to no-match
    idxs[max_match_size == 0] = -1

    # used for logging
    max_size_model = max([len(key) for key in model_keys]) if model_keys else 1
    max_size_weights = max([len(key) for key in weights_keys]) if weights_keys else 1
    match_keys = set()
    if use_weights_once:
        # Keep only the first model key matched to each checkpoint key.
        idx_model_and_weights = zip(*np.unique(idxs.numpy(), return_index=True)[::-1])
    else:
        idx_model_and_weights = enumerate(idxs.tolist())
    for idx_model, idx_weights in idx_model_and_weights:
        if idx_weights == -1:
            continue
        key_model = model_keys[idx_model]
        key_weights = weights_keys[idx_weights]
        ori_value = model_state_dict[key_model]
        if ori_value.shape != weights_dict[key_weights].shape:
            continue
        model_state_dict[key_model] = weights_dict[key_weights]
        match_keys.add(key_model)
        logging_rank('{: <{}} loaded from {: <{}} of shape {}'.format(
            key_model, max_size_model, key_weights, max_size_weights,
            tuple(weights_dict[key_weights].shape)))
    mismatch_keys = set(model_keys) - match_keys
    return model_state_dict, mismatch_keys

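# A toy check of the suffix matching (assumes align_and_update_state_dicts is
# in scope): the checkpoint key 'conv1.weight' lands on the model key that
# ends with it.
import torch

model_sd = {'backbone.body.conv1.weight': torch.zeros(4, 3, 3, 3)}
ckpt = {'conv1.weight': torch.ones(4, 3, 3, 3)}
new_sd, mismatched = align_and_update_state_dicts(model_sd, ckpt)
assert torch.equal(new_sd['backbone.body.conv1.weight'], ckpt['conv1.weight'])
assert not mismatched
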
def forward(self, pred, target):
    b, c, h, w = pred.size()
    target = target.view(-1)
    valid_mask = target.ne(self.ignore_label)
    target = target * valid_mask.long()
    num_valid = valid_mask.sum()

    prob = F.softmax(pred, dim=1)
    prob = (prob.transpose(0, 1)).reshape(c, -1)

    if self.min_kept > num_valid:
        logging_rank('Labels: {}'.format(num_valid))
    elif num_valid > 0:
        # Online hard example mining: keep the min_kept pixels with the
        # lowest predicted probability on their ground-truth class.
        prob = prob.masked_fill_((1 - valid_mask.long()).bool(), 1)
        mask_prob = prob[target, torch.arange(len(target), dtype=torch.long)]
        threshold = self.thresh
        if self.min_kept > 0:
            _, index = torch.sort(mask_prob)
            threshold_index = index[min(len(index), self.min_kept) - 1]
            if mask_prob[threshold_index] > self.thresh:
                threshold = mask_prob[threshold_index]
            kept_mask = mask_prob.le(threshold)
            target *= kept_mask.long()
            valid_mask *= kept_mask
            # logging_rank('Valid Mask: {}'.format(valid_mask.sum()))

    target = target.masked_fill_((1 - valid_mask.long()).bool(), self.ignore_label)
    target = target.view(b, h, w)
    return self.criterion(pred, target)

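# forward() above reads several attributes set in a constructor that is not
# shown; a minimal sketch consistent with those names (the class name and the
# default values are assumptions, not the repo's):
import torch.nn as nn

class OhemCrossEntropy(nn.Module):
    def __init__(self, ignore_label=255, thresh=0.7, min_kept=100000):
        super().__init__()
        self.ignore_label = ignore_label   # label id excluded from the loss
        self.thresh = thresh               # prob threshold for "hard" pixels
        self.min_kept = min_kept           # always keep at least this many pixels
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_label)
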
def save_best(self, model, optimizer=None, scheduler=None, remove_old=True, infix='epoch'):
    if scheduler.info['cur_acc'] < scheduler.info['best_acc']:
        return False
    old_name = 'model_{}{}-{:4.2f}.pth'.format(infix, scheduler.info['best_epoch'], scheduler.info['best_acc'])
    new_name = 'model_{}{}-{:4.2f}.pth'.format(infix, scheduler.info['cur_epoch'], scheduler.info['cur_acc'])
    if os.path.exists(os.path.join(self.ckpt, old_name)) and remove_old:
        os.remove(os.path.join(self.ckpt, old_name))
    scheduler.info['best_acc'] = scheduler.info['cur_acc']
    scheduler.info['best_epoch'] = scheduler.info['cur_epoch']
    save_dict = {'model': model.state_dict()}
    if optimizer is not None:
        save_dict['optimizer'] = optimizer.state_dict()
    if scheduler is not None:
        save_dict['scheduler'] = scheduler.state_dict()
    torch.save(save_dict, os.path.join(self.ckpt, new_name))
    shutil.copyfile(os.path.join(self.ckpt, new_name), os.path.join(self.ckpt, 'model_latest.pth'))
    logging_rank('Saving best checkpoint done: {}.'.format(new_name))
    return True

def load_optimizer(self, optimizer):
    if self.resume:
        optimizer.load_state_dict(self.checkpoint.pop('optimizer'))
        logging_rank('Loading optimizer done.')
    else:
        logging_rank('Initializing optimizer done.')
    return optimizer

def main():
    # `args` is expected to be defined at script level by the argument parser.
    cfg = get_cfg()
    cfg.merge_from_file(args.cfg_file)
    cfg.merge_from_list(args.opts)
    cfg = infer_cfg(cfg)
    cfg.freeze()

    # Calculate Params & FLOPs & Activations
    if cfg.MODEL_ANALYSE:
        model = Generalized_CNN(cfg)
        model.eval()
        analyser = Analyser(cfg, model, param_details=False)
        n_params = analyser.get_params()[1]
        conv_flops, model_flops = analyser.get_flops_activs(args.size[0], args.size[1], mode='flops')
        conv_activs, model_activs = analyser.get_flops_activs(args.size[0], args.size[1], mode='activations')
        logging_rank('-----------------------------------')
        logging_rank('Params: {}'.format(n_params))
        logging_rank('FLOPs: {:.4f} M / Conv_FLOPs: {:.4f} M'.format(model_flops, conv_flops))
        logging_rank('ACTIVATIONs: {:.4f} M / Conv_ACTIVATIONs: {:.4f} M'.format(model_activs, conv_activs))
        logging_rank('-----------------------------------')
        del model

def log_stats(self, cur_idx, start_ind, end_ind, total_num_images, ims_per_gpu=1, suffix=None, log_all=False):
    """Log the tracked statistics."""
    if (cur_idx + 1) % self.logperiod == 0 or cur_idx == end_ind - 1:
        eta_seconds = self.timers['iter'].average_time / ims_per_gpu * (end_ind - cur_idx - 1)
        eta = str(datetime.timedelta(seconds=int(eta_seconds)))
        lines = [
            '[Testing][range:{}-{} of {}][{}/{}]'.format(
                start_ind + 1, end_ind, total_num_images, cur_idx + 1, end_ind),
            '[{:.3f}s = {:.3f}s + {:.3f}s + {:.3f}s][eta: {}]'.format(
                *[self.timers[name].average_time / ims_per_gpu for name in self.default_timers], eta),
        ]
        if log_all:
            lines.append('\n|')
            for name, timer in self.timers.items():
                if name not in self.default_timers:
                    lines.append('{}: {:.3f}s|'.format(name, timer.average_time / ims_per_gpu))
        if suffix is not None:
            lines.append(suffix)
        logging_rank(''.join(lines))

def flatten_to_tuple(outputs):
    result = []
    if isinstance(outputs, torch.Tensor):
        result.append(outputs)
    elif isinstance(outputs, (list, tuple)):
        for v in outputs:
            result.extend(flatten_to_tuple(v))
    elif isinstance(outputs, dict):
        for v in outputs.values():
            result.extend(flatten_to_tuple(v))
    elif isinstance(outputs, BoxList):
        result.extend(flatten_to_tuple(outputs.bbox))
        if outputs.has_field('grid'):
            result.extend(flatten_to_tuple(outputs.get_field('grid')))
        if outputs.has_field('mask'):
            result.extend(flatten_to_tuple(outputs.get_field('mask')))
            result.extend(flatten_to_tuple(outputs.get_field('mask_scores')))
        if outputs.has_field('keypoints'):
            result.extend(flatten_to_tuple(outputs.get_field('keypoints')))
        if outputs.has_field('parsing'):
            result.extend(flatten_to_tuple(outputs.get_field('parsing')))
            result.extend(flatten_to_tuple(outputs.get_field('parsing_scores')))
        if outputs.has_field('uv'):
            result.extend(flatten_to_tuple(outputs.get_field('uv')))
        if outputs.has_field('hier'):
            result.extend(flatten_to_tuple(outputs.get_field('hier')))
    else:
        logging_rank('Output of type {} not included in flops/activations count.'.format(type(outputs)))
    return tuple(result)

def update_stats(self):
    """
    Update the model with precise statistics. Users can manually call this method.
    """
    if self.disabled:
        return
    if self.data_iter is None:
        self.data_iter = iter(self.data_loader)

    def data_loader():
        for num_iter in itertools.count(1):
            if num_iter % 100 == 0:
                logging_rank('Running precise-BN ... {}/{} iterations.'.format(num_iter, self.num_iter))
            # This way we can reuse the same iterator
            yield next(self.data_iter)

    with EventStorage():
        logging_rank(
            'Running precise-BN for {} iterations... '.format(self.num_iter)
            + 'Note that this could produce different statistics every time.')
        update_bn_stats(self.model, data_loader(), self.device, self.num_iter)

def load_scheduler(self, scheduler):
    if self.resume:
        scheduler.iteration = self.checkpoint['scheduler']['iteration']
        scheduler.info = self.checkpoint['scheduler']['info']
        logging_rank('Loading scheduler done.')
    else:
        logging_rank('Initializing scheduler done.')
    return scheduler

def main(args):
    cfg = get_cfg()
    cfg.merge_from_file(args.cfg_file)
    cfg.merge_from_list(args.opts)
    cfg = infer_cfg(cfg)
    cfg.freeze()
    # logging_rank(cfg)

    if not os.path.isdir(cfg.CKPT):
        mkdir_p(cfg.CKPT)
    setup_logging(cfg.CKPT)

    # Calculate Params & FLOPs & Activations
    n_params, conv_flops, model_flops, conv_activs, model_activs = 0, 0, 0, 0, 0
    if is_main_process() and cfg.MODEL_ANALYSE:
        model = Generalized_CNN(cfg)
        model.eval()
        analyser = Analyser(cfg, model, param_details=False)
        n_params = analyser.get_params()[1]
        conv_flops, model_flops = analyser.get_flops_activs(cfg.TEST.SCALE[0], cfg.TEST.SCALE[1], mode='flops')
        conv_activs, model_activs = analyser.get_flops_activs(cfg.TEST.SCALE[0], cfg.TEST.SCALE[1], mode='activations')
        del model
    synchronize()

    # Create model
    model = Generalized_CNN(cfg)
    logging_rank(model)

    # Load model
    test_weights = get_weights(cfg.CKPT, cfg.TEST.WEIGHTS)
    load_weights(model, test_weights)
    logging_rank('Params: {} | FLOPs: {:.4f}M / Conv_FLOPs: {:.4f}M | '
                 'ACTIVATIONs: {:.4f}M / Conv_ACTIVATIONs: {:.4f}M'.format(
                     n_params, model_flops, conv_flops, model_activs, conv_activs))
    model.eval()
    model.to(torch.device(cfg.DEVICE))

    # Create testing dataset and loader
    datasets = build_dataset(cfg, is_train=False)
    test_loader = make_test_data_loader(cfg, datasets)
    synchronize()

    # Build hooks
    all_hooks = build_test_hooks(args.cfg_file.split('/')[-1], log_period=1, num_warmup=0)

    # Build test engine
    test_engine = TestEngine(cfg, model)

    # Test
    test(cfg, test_engine, test_loader, datasets, all_hooks)

def evaluate(self):
    logging_rank('Evaluating Parsing predictions')
    if 'APp' in self.metrics or 'ap^p' in self.metrics:
        APp, PCP = self.computeAPp()
        self.stats.update(dict(APp=APp, PCP=PCP))
    if 'APr' in self.metrics or 'ap^r' in self.metrics:
        APr = self.computeAPr()
        self.stats.update(dict(APr=APr))
    if 'APh' in self.metrics or 'ap^h' in self.metrics:
        APh = self.computeAPh()
        self.stats.update(dict(APh=APh))

def get_params(self, max_depth=6):
    """
    Format the parameter count of the model (and its submodules or parameters)
    in a nice table.

    Args:
        max_depth (int): maximum depth to recursively print submodules or parameters

    Returns:
        tuple: the ('model', total parameter count) row of the table
    """
    count = self.compute_params()
    param_shape = {k: tuple(v.shape) for k, v in self.model.named_parameters()}
    table = []

    def format_size(x):
        if x > 1e5:
            return '{:.2f}M'.format(x / 1e6)
        if x > 1e2:
            return '{:.2f}K'.format(x / 1e3)
        return str(x)

    def fill(lvl, prefix):
        if lvl >= max_depth:
            return
        for name, v in count.items():
            if name.count('.') == lvl and name.startswith(prefix):
                indent = ' ' * (lvl + 1)
                if name in param_shape:
                    table.append((indent + name, indent + str(param_shape[name])))
                else:
                    table.append((indent + name, indent + format_size(v)))
                    fill(lvl + 1, name + '.')

    table.append(('model', format_size(count.pop(''))))
    fill(0, '')

    old_ws = tabulate.PRESERVE_WHITESPACE
    tabulate.PRESERVE_WHITESPACE = True
    tab = tabulate.tabulate(table, headers=['name', '#elements or shape'], tablefmt='pipe')
    tabulate.PRESERVE_WHITESPACE = old_ws
    if self.param_details:
        logging_rank(tab)
    return table[0]

def _get_repeat_factors(self, dataset_dicts):
    """
    Compute (fractional) per-image repeat factors.

    Args:
        dataset_dicts (list): per-image annotations

    Returns:
        torch.Tensor: the i-th element is the repeat factor for the
            dataset_dicts image at index i.
    """
    # 1. For each category c, compute the fraction of images that contain it: f(c)
    category_freq = defaultdict(int)
    for dataset_dict in dataset_dicts:  # For each image (without repeats)
        cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
        for cat_id in cat_ids:
            category_freq[cat_id] += 1
    num_images = len(dataset_dicts)
    for k, v in category_freq.items():
        category_freq[k] = v / num_images

    # 2. For each category c, compute the category-level repeat factor:
    #    lvis paper: r(c) = max(1, sqrt(t / f(c)))
    #    common:     r(c) = max(i, min(a, pow(t / f(c), alpha)))
    # category_rep = {
    #     cat_id: max(self.config.MIN_REPEAT_TIMES, min(self.config.MAX_REPEAT_TIMES, math.pow(
    #         (self.config.REPEAT_THRESHOLD / cat_freq), self.config.POW)))
    #     for cat_id, cat_freq in category_freq.items()
    # }
    category_rep = {
        cat_id: max(self.config.MIN_REPEAT_TIMES,
                    math.pow(self.config.REPEAT_THRESHOLD / cat_freq, self.config.POW))
        for cat_id, cat_freq in category_freq.items()
    }

    # 3. For each image I, compute the image-level repeat factor:
    #    r(I) = max_{c in I} r(c)
    rep_factors = []
    for dataset_dict in dataset_dicts:
        cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
        rep_factor = max({category_rep[cat_id] for cat_id in cat_ids})
        rep_factors.append(rep_factor)

    logging_rank('max(rep_factors): {} , min(rep_factors): {} , len(rep_factors): {}'.format(
        max(rep_factors), min(rep_factors), len(rep_factors)))
    return torch.tensor(rep_factors, dtype=torch.float32)

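# A quick worked example of step 2 above, with assumed config values
# t = REPEAT_THRESHOLD = 0.001, alpha = POW = 0.5, MIN_REPEAT_TIMES = 1:
import math

t, alpha, cat_freq = 0.001, 0.5, 0.0001  # category seen in 0.01% of images
r_c = max(1.0, math.pow(t / cat_freq, alpha))
print(r_c)  # ~3.162: images whose rarest category is this one repeat ~3x per epoch
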
def load_weights(model, weights_path, use_weights_once=False):
    try:
        weights_dict = torch.load(weights_path, map_location=torch.device('cpu'))['model']
    except Exception:
        # The checkpoint stores the raw state dict rather than {'model': ...}.
        weights_dict = torch.load(weights_path, map_location=torch.device('cpu'))
    weights_dict = strip_prefix_if_present(weights_dict, prefix='module.')
    model_state_dict = model.state_dict()
    model_state_dict, mismatch_keys = align_and_update_state_dicts(model_state_dict, weights_dict, use_weights_once)
    model.load_state_dict(model_state_dict)
    logging_rank('The mismatch keys: {}.'.format(list(mismatch_params_filter(sorted(mismatch_keys)))))
    logging_rank('Loading from weights: {}.'.format(weights_path))

def __init__(self, precise_bn_args, period, num_iter, max_iter):
    if len(get_bn_modules(precise_bn_args[1])) == 0:
        logging_rank('PreciseBN is disabled because model does not contain BN layers in training mode.')
        self.disabled = True
        return
    self.data_loader = precise_bn_args[0]
    self.model = precise_bn_args[1]
    self.device = precise_bn_args[2]
    self.num_iter = num_iter
    self.period = period
    self.max_iter = max_iter
    self.disabled = False
    self.data_iter = None

def log(self, logperiod=10):
    """
    Log the tracked statistics.
    E.g.: | timer1: xxxs | timer2: xxxms | timer3: xxxms |
    """
    self.calls += 1
    if self.calls % logperiod == 0 and self.timers:
        lines = ['']
        for name, timer in self.timers.items():
            avg_time = timer.average_time
            suffix = 's'
            if avg_time < 0.01:
                avg_time *= 1000
                suffix = 'ms'
            lines.append(' {}: {:.3f}{} '.format(name, avg_time, suffix))
        lines.append('')
        logging_rank('|'.join(lines))

def __init__(self, parsingGt=None, parsingPred=None, gt_dir=None, pred_dir=None,
             score_thresh=0.001, num_parsing=20, metrics=['mIoU', 'APp']):
    """
    Initialize ParsingEvaluator

    :param parsingGt:
    :param parsingPred:
    :return: None
    """
    self.parsingGt = parsingGt
    self.parsingPred = parsingPred
    self.params = Params(iouType='iou')  # evaluation parameters
    self.par_thresholds = self.params.pariouThrs
    self.mask_thresholds = self.params.maskiouThrs
    self.gt_dir = gt_dir
    self.pred_dir = pred_dir
    self.score_thresh = score_thresh
    self.num_parsing = num_parsing
    self.metrics = metrics
    self.stats = dict()  # result summarization

    if 'mIoU' in self.metrics or 'miou' in self.metrics:
        self.global_parsing_dir = os.path.join(self.pred_dir, 'global_parsing')
        assert os.path.exists(self.global_parsing_dir)
        logging_rank('The Global Parsing Images: {}'.format(len(parsingGt)))
        self.semseg_eval = SemSegEvaluator(
            parsingGt, self.gt_dir, self.global_parsing_dir, self.num_parsing,
            gt_dir=self.gt_dir.replace('Images', 'Category_ids'))
        self.semseg_eval.evaluate()
        self.semseg_eval.accumulate()
        self.semseg_eval.summarize()
        self.stats.update(self.semseg_eval.stats)
        print('=' * 80)

def flop_count(self, model, inputs, supported_ops=None):
    """
    Given a model and an input to the model, compute the per-operator flop counts of the
    given model.

    Args:
        model (nn.Module): The model to compute flop counts.
        inputs (tuple): Inputs that are passed to `model` to count flops. Inputs need to
            be in a tuple.
        supported_ops (dict(str, Callable) or None): provide additional handlers for extra
            ops, or overwrite the existing handlers for convolution, matmul and einsum. The
            key is the operator name and the value is a function that takes (inputs, outputs)
            of the op. We count one Multiply-Add as one FLOP.

    Returns:
        tuple[defaultdict, Counter]: A dictionary that records the number of mega-flops for
            each operation and a Counter that records the number of skipped operations.
    """
    assert isinstance(inputs, tuple), 'Inputs need to be in a tuple.'
    supported_ops = {**_FLOPS_DEFAULT_SUPPORTED_OPS, **(supported_ops or {})}

    # Run flop count.
    total_flop_counter, skipped_ops = get_jit_model_analysis(model, inputs, supported_ops)

    # Log for skipped operations.
    if len(skipped_ops) > 0:
        for op, freq in skipped_ops.items():
            logging_rank('Skipped operation {} {} time(s)'.format(op, freq))

    # Convert flop count to mega count (the division by 1e6 yields M, not G).
    final_count = defaultdict(float)
    for op in total_flop_counter:
        final_count[op] = total_flop_counter[op] / 1e6
    return final_count, skipped_ops

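# The one-multiply-add-per-FLOP convention can be sanity-checked by hand for
# a convolution: a 3->8 channel 3x3 conv producing a 64x64 output costs
# out_c * out_h * out_w * (in_c * k * k) multiply-adds.
flops = 8 * 64 * 64 * (3 * 3 * 3)
print(flops / 1e6, 'M multiply-adds')  # ~0.885 M
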
def load_model(self, model, convert_conv1=False, use_weights_once=False):
    if self.resume:
        weights_dict = self.checkpoint.pop('model')
        weights_dict = strip_prefix_if_present(weights_dict, prefix='module.')
        model_state_dict = model.state_dict()
        model_state_dict, self.mismatch_keys = align_and_update_state_dicts(
            model_state_dict, weights_dict, use_weights_once)
        model.load_state_dict(model_state_dict)
        logging_rank('Resuming from weights: {}.'.format(self.weights_path))
    else:
        if self.weights_path:
            weights_dict = self.checkpoint if not self.retrain else self.checkpoint.pop('model')
            weights_dict = strip_prefix_if_present(weights_dict, prefix='module.')
            weights_dict = self.weight_mapping(weights_dict)  # only for pre-training
            if convert_conv1:  # only for pre-training
                weights_dict = self.convert_conv1_rgb2bgr(weights_dict)
            model_state_dict = model.state_dict()
            model_state_dict, self.mismatch_keys = align_and_update_state_dicts(
                model_state_dict, weights_dict, use_weights_once)
            model.load_state_dict(model_state_dict)
            logging_rank('Pre-training on weights: {}.'.format(self.weights_path))
        else:
            logging_rank('Training from scratch.')
    return model

def activation_count(self, model, inputs, supported_ops=None):
    """
    Given a model and an input to the model, compute the total number of activations of
    the model.

    Args:
        model (nn.Module): The model to compute activation counts.
        inputs (tuple): Inputs that are passed to `model` to count activations. Inputs
            need to be in a tuple.
        supported_ops (dict(str, Callable) or None): provide additional handlers for extra
            ops, or overwrite the existing handlers for convolution and matmul. The key is
            the operator name and the value is a function that takes (inputs, outputs) of
            the op.

    Returns:
        tuple[defaultdict, Counter]: A dictionary that records the number of activations
            (mega) for each operation and a Counter that records the number of skipped
            operations.
    """
    assert isinstance(inputs, tuple), 'Inputs need to be in a tuple.'
    supported_ops = {**_ACTIVS_DEFAULT_SUPPORTED_OPS, **(supported_ops or {})}

    # Run activation count.
    total_activation_count, skipped_ops = get_jit_model_analysis(model, inputs, supported_ops)

    # Log for skipped operations.
    if len(skipped_ops) > 0:
        for op, freq in skipped_ops.items():
            logging_rank('Skipped operation {} {} time(s)'.format(op, freq))

    # Convert activation count to mega count.
    final_count = defaultdict(float)
    for op in total_activation_count:
        final_count[op] = total_activation_count[op] / 1e6
    return final_count, skipped_ops

def test(cfg, test_engine, loader, datasets, all_hooks):
    total_timer = Timer()
    total_timer.tic()
    all_results = [[] for _ in range(4)]
    evaluator = Evaluation(cfg)
    with torch.no_grad():
        loader = iter(loader)
        for i in range(len(loader)):
            all_hooks.iter_tic()
            all_hooks.data_tic()
            inputs, targets, idx = next(loader)
            all_hooks.data_toc()

            all_hooks.infer_tic()
            result = test_engine(inputs, targets)
            all_hooks.infer_toc()

            all_hooks.post_tic()
            eval_results = evaluator.post_processing(result, targets, idx, datasets)
            all_results = [results + eva for results, eva in zip(all_results, eval_results)]
            all_hooks.post_toc()

            all_hooks.iter_toc()
            if is_main_process():
                all_hooks.log_stats(i, 0, len(loader), len(datasets))

    all_results = list(zip(*all_gather(all_results)))
    all_results = [[item for sublist in results for item in sublist] for results in all_results]
    if is_main_process():
        total_timer.toc(average=False)
        logging_rank('Total inference time: {:.3f}s'.format(total_timer.average_time))
        evaluator.evaluation(datasets, all_results)

def save(self, model, optimizer=None, scheduler=None, copy_latest=True, infix='epoch'):
    save_dict = {'model': model.state_dict()}
    if optimizer is not None:
        save_dict['optimizer'] = optimizer.state_dict()
    if scheduler is not None:
        save_dict['scheduler'] = scheduler.state_dict()
    torch.save(save_dict, os.path.join(self.ckpt, 'model_latest.pth'))
    log_str = 'Saving checkpoint done.'
    if copy_latest and scheduler:
        shutil.copyfile(
            os.path.join(self.ckpt, 'model_latest.pth'),
            os.path.join(self.ckpt, 'model_{}{}.pth'.format(infix, str(scheduler.iteration))))
        log_str += ' And copy "model_latest.pth" to "model_{}{}.pth".'.format(infix, str(scheduler.iteration))
    logging_rank(log_str)

def __init__(self, ann_file, root, bbox_file, image_thresh, ann_types, transforms=None, extra_fields={}):
    self.root = root
    self.coco = COCO(ann_file)
    self.bbox_file = bbox_file
    self.image_thresh = image_thresh
    self.ann_types = ann_types
    self.transforms = transforms
    self.extra_fields = extra_fields

    ids = sorted(self.coco.imgs.keys())
    self.ids = []
    self.ann_ids = []
    for img_id in ids:
        ann_ids_per_image = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
        ann = self.coco.loadAnns(ann_ids_per_image)
        entry = self.coco.loadImgs(img_id)[0]
        for obj in ann:
            if has_valid_annotation(obj, entry):
                if 'keypoints' in self.ann_types or 'parsing' in self.ann_types or 'uv' in self.ann_types:
                    if not has_valid_person(obj):
                        continue
                    if 'keypoints' in self.ann_types and not has_valid_keypoint(obj):
                        continue
                    if 'uv' in self.ann_types and not has_valid_densepose(obj):
                        continue
                self.ids.append(img_id)
                self.ann_ids.append(obj['id'])
    logging_rank('Load {} samples'.format(len(ids)))

    self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
    category_ids = sorted(self.coco.getCatIds())
    self.json_category_id_to_contiguous_id = {v: i for i, v in enumerate(category_ids)}
    self.contiguous_category_id_to_json_id = {v: k for k, v in self.json_category_id_to_contiguous_id.items()}
    category_ids = [c['name'] for c in self.coco.loadCats(category_ids)]
    self.classes = category_ids

    if 'parsing' in self.ann_types:
        Parsing.FLIP_MAP = self.extra_fields['flip_map'] if 'flip_map' in self.extra_fields else ()

def after_train(self, storage, **kwargs):
    cur_iter = storage.iter
    total_time = time.perf_counter() - self.start_time
    total_time_minus_hooks = self.total_timer.seconds()
    hook_time = total_time - total_time_minus_hooks

    num_iter = cur_iter + 1 - self.start_iter - self.warmup_iter
    if num_iter > 0 and total_time_minus_hooks > 0:
        # Speed is meaningful only after warmup
        # NOTE this format is parsed by grep in some scripts
        logging_rank(
            'Overall training speed: {} iterations in {} ({:.4f} s / it)'.format(
                num_iter,
                str(datetime.timedelta(seconds=int(total_time_minus_hooks))),
                total_time_minus_hooks / num_iter,
            ))
    logging_rank('Total training time: {} ({} on hooks)'.format(
        str(datetime.timedelta(seconds=int(total_time))),
        str(datetime.timedelta(seconds=int(hook_time))),
    ))

def __iter__(self):
    if self.shuffle:
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = self._get_epoch_indices(g)
        randperm = torch.randperm(len(indices), generator=g).tolist()
        indices = indices[randperm]
    else:
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = self._get_epoch_indices(g)
        # indices = torch.arange(len(self.dataset)).tolist()

    # With repeat-factor balancing, len(indices) can differ from the dataset image count.
    self.total_size = len(indices)
    logging_rank('balance sample total_size: {}'.format(self.total_size))

    # subsample
    self.num_samples = int(len(indices) / self.num_replicas)
    offset = self.num_samples * self.rank
    indices = indices[offset:offset + self.num_samples]
    assert len(indices) == self.num_samples
    return iter(indices)

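# A toy illustration of the subsampling step above: ranks take contiguous
# shards of the permuted index list, and any remainder is dropped for the
# epoch (sizes here are hypothetical).
indices = list(range(10))
num_replicas, rank = 4, 1
num_samples = int(len(indices) / num_replicas)  # 2
offset = num_samples * rank
print(indices[offset:offset + num_samples])  # [2, 3]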