def load_state(dirname, device, model, optim, use_amp=False):
    """Load a model state dict from disk.

    Scans `dirname` for checkpoints named ``weights_<epoch>.tar``, loads the
    one with the highest epoch number into `model`, and returns that epoch.
    Returns 0 if no checkpoint is found.

    NOTE(review): `optim` and `use_amp` are accepted for interface
    compatibility with sibling variants but are not used here.

    :param dirname: directory to scan for checkpoints.
    :param device: device passed to ``model.to`` and ``torch.load``.
    :param model: model whose parameters will be populated in place.
    :returns: epoch number of the loaded checkpoint, or 0.
    """
    model.to(device)

    weight_no = None
    weight_files = glob(os.path.join(dirname, "weights_*.tar"))
    if weight_files:
        # Resume from the checkpoint with the highest epoch number.
        weight_no = max(
            int(re.sub(r".*_([0-9]+).tar", r"\1", w)) for w in weight_files
        )

    # Compare against None (not truthiness) so an epoch-0 checkpoint
    # ("weights_0.tar") is still picked up instead of being skipped.
    if weight_no is not None:
        print("[picking up from epoch %s]" % weight_no)
        state_dict = torch.load(
            os.path.join(dirname, 'weights_%s.tar' % weight_no),
            map_location=device
        )
        # Remap checkpoint parameter names onto the current model's names.
        state_dict = {
            k2: state_dict[k1]
            for k1, k2 in match_names(state_dict, model).items()
        }
        # Strip any 'module.' prefix left over from DataParallel wrapping.
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            new_state_dict[k.replace('module.', '')] = v

        model.load_state_dict(new_state_dict)
        epoch = weight_no
    else:
        epoch = 0

    return epoch
def load_state(dirname, device, model, optim=None):
    """Load a model (and optionally optimizer) state dict from disk.

    Scans `dirname` for ``weights_<epoch>.tar`` and ``optim_<epoch>.tar``
    checkpoints. When `optim` is given, only an epoch that has *both* files
    is eligible for resume; otherwise the latest weights checkpoint is used.
    Returns the resumed epoch number, or 0 if nothing was loaded.

    :param dirname: directory to scan for checkpoints.
    :param device: device passed to ``model.to`` and ``torch.load``.
    :param model: model whose parameters will be populated in place.
    :param optim: optional optimizer whose state will also be restored.
    :returns: epoch number of the loaded checkpoint, or 0.
    """
    model.to(device)
    # Unwrap DataParallel/DistributedDataParallel so state dict keys match.
    if hasattr(model, "module"):
        model = model.module

    weight_no = optim_no = None

    optim_files = glob(os.path.join(dirname, "optim_*.tar"))
    optim_nos = {int(re.sub(r".*_([0-9]+).tar", r"\1", w)) for w in optim_files}

    weight_files = glob(os.path.join(dirname, "weights_*.tar"))
    weight_nos = {int(re.sub(r".*_([0-9]+).tar", r"\1", w)) for w in weight_files}

    if optim is not None:
        # Only resume from an epoch that has both model and optimizer state.
        weight_no = optim_no = max(optim_nos & weight_nos, default=None)
    else:
        weight_no = max(weight_nos, default=None)

    to_load = []
    # Compare against None (not truthiness) so an epoch-0 checkpoint
    # ("weights_0.tar" / "optim_0.tar") is still picked up.
    if weight_no is not None:
        to_load.append(("weights", model))
    if optim_no is not None:
        to_load.append(("optim", optim))

    if to_load:
        print("[picking up %s state from epoch %s]" % (
            ', '.join([n for n, _ in to_load]), weight_no))
        for name, obj in to_load:
            state_dict = torch.load(
                os.path.join(dirname, '%s_%s.tar' % (name, weight_no)),
                map_location=device
            )
            if name == "weights":
                # Remap checkpoint parameter names onto the current model.
                state_dict = {
                    k2: state_dict[k1]
                    for k1, k2 in match_names(state_dict, obj).items()
                }
                # Strip any 'module.' prefix left by DataParallel. Use a
                # dedicated variable (`stripped`) instead of reusing `name`,
                # which previously shadowed the outer loop variable.
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    stripped = k.replace('module.', '')
                    new_state_dict[stripped] = v
                state_dict = new_state_dict
            obj.load_state_dict(state_dict)
        epoch = weight_no
    else:
        epoch = 0

    return epoch
def load_state(dirname, device, model, optim, use_amp=False):
    """Load a model and optimizer state dict from disk.

    Optionally wraps the model/optimizer with NVIDIA Apex AMP (opt level O1),
    then scans `dirname` for ``weights_<epoch>.tar`` checkpoints and loads the
    latest into `model`. Returns the resumed epoch number, or 0 if no
    checkpoint was found.

    :param dirname: directory to scan for checkpoints.
    :param device: device passed to ``model.to`` and ``torch.load``.
    :param model: model whose parameters will be populated in place.
    :param optim: optimizer, patched in place by AMP when `use_amp` is True.
    :param use_amp: enable Apex mixed precision; exits if Apex is missing.
    :returns: epoch number of the loaded checkpoint, or 0.
    """
    model.to(device)

    if use_amp:
        try:
            # Rebind BOTH returned objects; the original assigned the patched
            # optimizer to an unused name `optimizer`, discarding it.
            model, optim = amp.initialize(model, optim,
                                          opt_level="O1",
                                          verbosity=0)
        except NameError:
            # `amp` is only defined when Apex imported successfully elsewhere.
            print(
                "[error]: Cannot use AMP: Apex package needs to be installed manually, See https://github.com/NVIDIA/apex"
            )
            exit(1)

    weight_no = None
    weight_files = glob(os.path.join(dirname, "weights_*.tar"))
    if weight_files:
        # Resume from the checkpoint with the highest epoch number.
        weight_no = max(
            int(re.sub(r".*_([0-9]+).tar", r"\1", w)) for w in weight_files
        )

    # Compare against None (not truthiness) so an epoch-0 checkpoint
    # ("weights_0.tar") is still picked up instead of being skipped.
    if weight_no is not None:
        print("[picking up from epoch %s]" % weight_no)
        state_dict = torch.load(
            os.path.join(dirname, 'weights_%s.tar' % weight_no),
            map_location=device
        )
        # Remap checkpoint parameter names onto the current model's names.
        state_dict = {
            k2: state_dict[k1]
            for k1, k2 in match_names(state_dict, model).items()
        }
        # Strip any 'module.' prefix left over from DataParallel wrapping.
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            new_state_dict[k.replace('module.', '')] = v

        model.load_state_dict(new_state_dict)
        epoch = weight_no
    else:
        epoch = 0

    return epoch