def create_model(name, model_spec, checkpoint_folder, write_spec=True):
    """Build a model from an Artificery spec and optionally restore weights.

    Parses ``model_spec["spec_path"]`` (with ``~`` expanded), names the
    resulting module *name*, and — when ``write_spec`` is true — copies the
    spec file plus every spec file Artificery consulted into
    ``checkpoint_folder`` (preserving their relative layout) so the model
    can later be rebuilt from the checkpoint folder alone.

    If ``<checkpoint_folder>/<name>.state.pth.tar`` exists, its weights are
    copied into the freshly-parsed model, EXCEPT for layers belonging to
    pyramid levels listed in ``reinit_downmodules`` / ``reinit_upmodules``
    (those are deliberately left at their fresh initialization) and layers
    absent from the new model's own state dict.

    Args:
        name: checkpoint/model name, also assigned to ``my_p.name``.
        model_spec: dict with at least a ``"spec_path"`` key.
        checkpoint_folder: folder holding checkpoints and copied specs.
        write_spec: copy spec files into ``checkpoint_folder`` when True.

    Returns:
        The parsed (and possibly weight-loaded) Artificery module.
    """
    import shutil  # stdlib; replaces the shell "cp" calls below

    checkpoint_path = os.path.join(
        checkpoint_folder, "{}.state.pth.tar".format(name))
    print(checkpoint_path)

    # If a checkpoint exists we restore weights explicitly below, so skip
    # Artificery's own checkpoint init; otherwise let Artificery initialize.
    a = artificery.Artificery(
        checkpoint_init=not os.path.isfile(checkpoint_path))

    spec_path = os.path.expanduser(model_spec["spec_path"])
    spec_dir = os.path.dirname(spec_path)
    # Parse the *expanded* path (the original parsed the raw value, which
    # fails when spec_path begins with "~").
    my_p = a.parse(spec_path)
    my_p.name = name

    if write_spec:
        model_spec_dst = os.path.join(checkpoint_folder, "model_spec.json")
        # shutil.copy instead of shell "cp": synchronous (no copy race with
        # later readers), portable, and safe for paths with shell
        # metacharacters. The debug "ls" subprocess was removed.
        shutil.copy(spec_path, model_spec_dst)
        for sf in a.used_specfiles:
            sf = os.path.expanduser(sf)
            rel_folder = os.path.dirname(os.path.relpath(sf, spec_dir))
            dst_folder = os.path.join(checkpoint_folder, rel_folder)
            os.makedirs(dst_folder, exist_ok=True)
            shutil.copy(sf, dst_folder)

    if os.path.isfile(checkpoint_path):
        state_dict = torch.load(checkpoint_path)
        own_state = my_p.state_dict()
        # Pyramid levels whose weights should NOT be restored from the
        # checkpoint (kept at fresh initialization).
        reinit_downmodules = [7]
        reinit_upmodules = []

        def _module_level(layer_name, prefix):
            # Layer names look like "<prefix><level>.<rest>" — assumes the
            # prefix starts the level segment, as the original code did.
            # Parses the full integer; the original read only one character
            # and so mis-handled levels >= 10.
            return int(layer_name[len(prefix):].split(".")[0])

        for layer_name, param in state_dict.items():
            load_weights = True
            if "level_downmodules" in layer_name:
                if _module_level(layer_name,
                                 "level_downmodules.") in reinit_downmodules:
                    load_weights = False
            elif "level_upmodules" in layer_name:
                if _module_level(layer_name,
                                 "level_upmodules.") in reinit_upmodules:
                    load_weights = False
            if layer_name not in own_state:
                load_weights = False
            if load_weights:
                if isinstance(param, Parameter):
                    # backwards compatibility for serialized parameters
                    param = param.data
                own_state[layer_name].copy_(param)
    return my_p
def open_model(name, checkpoint_folder):
    """Rebuild a model from the spec stored in *checkpoint_folder*.

    Parses ``model_spec.json``, restores the weights from
    ``<name>.state.pth.tar`` when that file exists, tags the module with
    *name*, and returns it moved to the GPU.
    """
    builder = artificery.Artificery()
    model = builder.parse(os.path.join(checkpoint_folder, "model_spec.json"))
    state_file = os.path.join(
        checkpoint_folder, "{}.state.pth.tar".format(name))
    if os.path.isfile(state_file):
        model.load_state_dict(torch.load(state_file))
    model.name = name
    return model.cuda()
def create_model(name, checkpoint_folder, checkpoint_init=False):
    """Parse the spec in *checkpoint_folder* and optionally restore weights.

    Unlike ``open_model``, restoration goes through ``load_my_state_dict``,
    which applies the project's selective weight-loading rules, and
    Artificery's own checkpoint initialization can be toggled via
    *checkpoint_init*. Returns the module on the GPU.
    """
    builder = artificery.Artificery(checkpoint_init=checkpoint_init)
    model = builder.parse(os.path.join(checkpoint_folder, "model_spec.json"))
    state_file = os.path.join(
        checkpoint_folder, "{}.state.pth.tar".format(name))
    if os.path.isfile(state_file):
        # Selective restore — see load_my_state_dict for the skip rules.
        load_my_state_dict(model, torch.load(state_file))
    model.name = name
    return model.cuda()
def create_model(checkpoint_folder, device):
    """Load the model described by *checkpoint_folder* onto *device*.

    Parses ``model_spec.json`` and restores weights from
    ``checkpoint.state.pth.tar`` via ``load_my_state_dict``, mapping the
    tensors to *device*.

    Raises:
        Exception: if the checkpoint file does not exist (this variant
            never creates a fresh model).
    """
    a = artificery.Artificery()
    spec_path = os.path.join(checkpoint_folder, "model_spec.json")
    my_p = a.parse(spec_path)
    checkpoint_path = os.path.join(
        checkpoint_folder, "checkpoint.state.pth.tar")
    if not os.path.isfile(checkpoint_path):
        # Fixed message typos ("Checkput ... no found").
        raise Exception(f"Checkpoint path {checkpoint_path} not found!")
    load_my_state_dict(
        my_p, torch.load(checkpoint_path, map_location=torch.device(device)))
    return my_p
def create_model(checkpoint_folder, device='cpu', checkpoint_name="checkpoint"):
    """Load (or freshly create) a model from *checkpoint_folder*.

    Parses ``model_spec.json``; if ``<checkpoint_name>.state.pth.tar``
    exists its weights are restored through ``load_my_state_dict`` with
    tensors mapped to *device*, otherwise the freshly-parsed model is
    returned as-is.

    Args:
        checkpoint_folder: folder containing the spec and checkpoint files.
        device: torch device string for ``map_location`` (default "cpu").
        checkpoint_name: basename of the checkpoint file, also assigned to
            ``my_p.name`` when weights are loaded.

    Returns:
        The parsed Artificery module (weights restored when available).
    """
    a = artificery.Artificery()
    spec_path = os.path.join(checkpoint_folder, "model_spec.json")
    my_p = a.parse(spec_path)
    checkpoint_path = os.path.join(
        checkpoint_folder, f"{checkpoint_name}.state.pth.tar")
    if not os.path.isfile(checkpoint_path):
        # Fixed message typo ("checkpiont").
        print("creating new checkpoint...")
        return my_p
    load_my_state_dict(
        my_p, torch.load(checkpoint_path, map_location=torch.device(device)))
    my_p.name = checkpoint_name
    return my_p
# Ad-hoc debugging script: build the pyramid network from its spec file
# and drop straight into the interactive debugger to inspect it.
import artificery

a = artificery.Artificery()
# NOTE(review): relative path — assumes this is run from a sibling of
# "params/"; confirm the working directory before relying on it.
net = a.parse("../params/pyramid_m5m6m7_fm24.json")

# Deliberate breakpoint — this script exists to poke at `net` by hand.
import pdb
pdb.set_trace()