# --- Checkpoint cleaning script: strip a training checkpoint down to a bare
# state_dict and re-save it under a hash-tagged filename. ---
import argparse
import hashlib
import os
import shutil

import torch

from timm.models import load_state_dict  # import path may vary across timm versions

_TEMP_NAME = './_checkpoint.pth'  # scratch file used before the hash-tagged rename


def main():
    args = parser.parse_args()

    if os.path.exists(args.output):
        print("Error: Output filename ({}) already exists.".format(args.output))
        exit(1)

    # Load an existing checkpoint to CPU, strip everything but the state_dict, and re-save
    if args.checkpoint and os.path.isfile(args.checkpoint):
        print("=> Loading checkpoint '{}'".format(args.checkpoint))
        state_dict = load_state_dict(args.checkpoint, use_ema=args.use_ema)
        new_state_dict = {}
        for k, v in state_dict.items():
            if args.clean_aux_bn and 'aux_bn' in k:
                # If all aux_bn keys are removed, the SplitBN layers will end up as normal
                # BatchNorm2d and load with the unmodified model.
                continue
            # Strip the 'module.' prefix left by (Distributed)DataParallel wrappers.
            name = k[7:] if k.startswith('module.') else k
            new_state_dict[name] = v
        print("=> Loaded state_dict from '{}'".format(args.checkpoint))

        try:
            # Save in the legacy (non-zipfile) format for wider compatibility.
            torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
        except TypeError:
            # Older torch versions don't accept the serialization kwarg.
            torch.save(new_state_dict, _TEMP_NAME)

        # Hash the saved file and embed the first 8 hex digits in the final filename.
        with open(_TEMP_NAME, 'rb') as f:
            sha_hash = hashlib.sha256(f.read()).hexdigest()

        if args.output:
            checkpoint_root, checkpoint_base = os.path.split(args.output)
            checkpoint_base = os.path.splitext(checkpoint_base)[0]
        else:
            checkpoint_root = ''
            checkpoint_base = os.path.splitext(args.checkpoint)[0]
        final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + '.pth'
        shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename))
        print("=> Saved state_dict to '{}', SHA256: {}".format(final_filename, sha_hash))
    else:
        print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint))
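
# A minimal argparse sketch so the script above runs end to end. The argument
# names are reconstructed from the attribute accesses in main(); the original
# script's exact flags, defaults, and help text may differ.
parser = argparse.ArgumentParser(description='Checkpoint cleaner')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to the checkpoint to clean')
parser.add_argument('--output', default='', type=str, metavar='PATH',
                    help='output path (default: name derived from the input checkpoint)')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
                    help='use the EMA version of the weights if present')
parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
                    help='remove auxiliary batch norm (aux_bn) keys left by SplitBN training')


if __name__ == '__main__':
    main()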
# --- Checkpoint averaging script: select the top-n checkpoints by their stored
# eval metric and average their weights into a single state_dict. ---
import argparse
import glob
import hashlib
import os

import torch

from timm.models import load_state_dict  # import path may vary across timm versions


def main():
    args = parser.parse_args()
    # by default use the EMA weights (if present)
    args.use_ema = not args.no_use_ema
    # by default sort by checkpoint metric (if present) and avg top n checkpoints
    args.sort = not args.no_sort

    if os.path.exists(args.output):
        print("Error: Output filename ({}) already exists.".format(args.output))
        exit(1)

    # Build a (recursive) glob pattern from the input folder and filename filter.
    pattern = args.input
    if not args.input.endswith(os.path.sep) and not args.filter.startswith(os.path.sep):
        pattern += os.path.sep
    pattern += args.filter
    checkpoints = glob.glob(pattern, recursive=True)
    if not checkpoints:
        print("Error: No checkpoints to average.")
        exit(1)

    if args.sort:
        # Keep only checkpoints that carry a metric, sort them, and take the top n.
        checkpoint_metrics = []
        for c in checkpoints:
            metric = checkpoint_metric(c)
            if metric is not None:
                checkpoint_metrics.append((metric, c))
        checkpoint_metrics = list(sorted(checkpoint_metrics, reverse=not args.descending))
        checkpoint_metrics = checkpoint_metrics[:args.n]
        print("Selected checkpoints:")
        for m, c in checkpoint_metrics:
            print(m, c)
        avg_checkpoints = [c for m, c in checkpoint_metrics]
    else:
        avg_checkpoints = checkpoints
        print("Selected checkpoints:")
        for c in checkpoints:
            print(c)

    # Accumulate each tensor in float64 to limit rounding error, then average.
    avg_state_dict = {}
    avg_counts = {}
    for c in avg_checkpoints:
        new_state_dict = load_state_dict(c, args.use_ema)
        if not new_state_dict:
            print("Error: Checkpoint ({}) doesn't exist".format(c))
            continue
        for k, v in new_state_dict.items():
            if k not in avg_state_dict:
                avg_state_dict[k] = v.clone().to(dtype=torch.float64)
                avg_counts[k] = 1
            else:
                avg_state_dict[k] += v.to(dtype=torch.float64)
                avg_counts[k] += 1

    for k, v in avg_state_dict.items():
        v.div_(avg_counts[k])

    # float32 overflow seems unlikely based on weights seen to date, but who knows
    float32_info = torch.finfo(torch.float32)
    final_state_dict = {}
    for k, v in avg_state_dict.items():
        v = v.clamp(float32_info.min, float32_info.max)
        final_state_dict[k] = v.to(dtype=torch.float32)

    torch.save(final_state_dict, args.output)
    with open(args.output, 'rb') as f:
        sha_hash = hashlib.sha256(f.read()).hexdigest()
    print("=> Saved state_dict to '{}', SHA256: {}".format(args.output, sha_hash))
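
# A minimal sketch of the pieces main() assumes, so this script runs end to end.
# `checkpoint_metric` here is a hypothetical stand-in: it only looks for a
# top-level 'metric' entry, while the original helper may probe other keys.
# The argument names are reconstructed from the attribute accesses in main();
# exact flags, defaults, and help text in the original may differ.
def checkpoint_metric(checkpoint_path):
    # Return the eval metric stored in a training checkpoint, or None if absent.
    if not checkpoint_path or not os.path.isfile(checkpoint_path):
        return None
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if isinstance(checkpoint, dict):
        return checkpoint.get('metric', None)
    return None


parser = argparse.ArgumentParser(description='Checkpoint averager')
parser.add_argument('--input', default='./output', type=str, metavar='PATH',
                    help='folder containing the checkpoints to average')
parser.add_argument('--filter', default='*.pth.tar', type=str, metavar='WILDCARD',
                    help='checkpoint filename filter (glob pattern)')
parser.add_argument('--output', default='./averaged.pth', type=str, metavar='PATH',
                    help='output filename')
parser.add_argument('--no-use-ema', action='store_true',
                    help='do not use the EMA weights even if present')
parser.add_argument('--no-sort', action='store_true',
                    help='average all matched checkpoints instead of the metric-sorted top n')
parser.add_argument('--descending', action='store_true',
                    help='flip the metric sort order used to select the top n')
parser.add_argument('-n', type=int, default=10, metavar='N',
                    help='number of checkpoints to average')


if __name__ == '__main__':
    main()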