Ejemplo n.º 1
0
def create_model(name, model_spec, checkpoint_folder, write_spec=True,
                 reinit_downmodules=(7,), reinit_upmodules=()):
    """Parse a model spec and warm-start it from an existing checkpoint.

    The model is built by Artificery from ``model_spec["spec_path"]``. If
    ``<checkpoint_folder>/<name>.state.pth.tar`` exists, every tensor in it
    that the new model also has is copied into the model in place. Two kinds
    of tensor are skipped: those under ``level_downmodules.<k>`` whose level
    ``k`` is in ``reinit_downmodules``, and those under ``level_upmodules.<k>``
    whose level ``k`` is in ``reinit_upmodules``. Skipped tensors keep the
    values the freshly parsed model was created with.

    Args:
        name: Model name; also the stem of the checkpoint file name.
        model_spec: Mapping with a ``"spec_path"`` entry (``~`` is expanded
            for the file-copy step).
        checkpoint_folder: Folder that holds the checkpoint and, when
            ``write_spec`` is true, receives copies of the spec files.
        write_spec: If true, copy the top-level spec to
            ``<checkpoint_folder>/model_spec.json``. Also copy every file in
            Artificery's ``used_specfiles`` into ``checkpoint_folder``,
            mirroring their layout relative to the top-level spec's folder.
        reinit_downmodules: Levels of ``level_downmodules`` that are NOT
            loaded from the checkpoint. Defaults to ``(7,)``, the value that
            was previously hard-coded.
        reinit_upmodules: Levels of ``level_upmodules`` that are NOT loaded
            from the checkpoint.

    Returns:
        The parsed model with ``name`` assigned.
    """
    checkpoint_path = os.path.join(checkpoint_folder,
                                   "{}.state.pth.tar".format(name))
    print(checkpoint_path)
    has_checkpoint = os.path.isfile(checkpoint_path)

    # checkpoint_init is enabled only when there is no saved state to load
    # (what the flag initializes is defined inside artificery).
    a = artificery.Artificery(checkpoint_init=not has_checkpoint)

    spec_path = os.path.expanduser(model_spec["spec_path"])
    my_p = a.parse(model_spec["spec_path"])
    my_p.name = name

    if write_spec:
        _copy_spec_files(spec_path, a.used_specfiles, checkpoint_folder)

    if has_checkpoint:
        _load_matching_weights(my_p, torch.load(checkpoint_path),
                               reinit_downmodules, reinit_upmodules)

    return my_p


def _copy_spec_files(spec_path, used_specfiles, checkpoint_folder):
    """Copy the top-level spec and every used spec file into checkpoint_folder."""
    spec_dir = os.path.dirname(spec_path)
    # Create the destination first so model_spec.json can always be written.
    os.makedirs(checkpoint_folder or os.curdir, exist_ok=True)
    _copy_file(spec_path, os.path.join(checkpoint_folder, "model_spec.json"))
    for spec_file in used_specfiles:
        spec_file = os.path.expanduser(spec_file)
        # Mirror each file's location relative to the top-level spec folder.
        rel_folder = os.path.dirname(os.path.relpath(spec_file, spec_dir))
        dst_folder = os.path.join(checkpoint_folder, rel_folder) or os.curdir
        os.makedirs(dst_folder, exist_ok=True)
        _copy_file(spec_file, dst_folder)


def _copy_file(src, dst):
    """Copy ``src`` to ``dst`` (a file path or a directory), skipping self-copies."""
    import shutil

    try:
        shutil.copy(src, dst)
    except shutil.SameFileError:
        # The spec already lives at the destination (e.g. the spec was read
        # from checkpoint_folder itself); there is nothing to copy.
        pass


def _module_level(layer_name, prefix):
    """Return the integer index following ``prefix`` in a dotted state-dict key.

    For example, ``"level_downmodules.12.conv.weight"`` with prefix
    ``"level_downmodules"`` gives 12. Returns None when ``prefix`` is not a
    component of the key or is not followed by an integer component.
    """
    parts = layer_name.split(".")
    for i, part in enumerate(parts[:-1]):
        if part == prefix and parts[i + 1].isdecimal():
            return int(parts[i + 1])
    return None


def _load_matching_weights(model, state_dict, reinit_downmodules,
                           reinit_upmodules):
    """Copy tensors from ``state_dict`` into ``model`` in place.

    A tensor is skipped when its key is not in ``model.state_dict()`` or when
    it belongs to a down/up module level selected for reinitialization.
    """
    own_state = model.state_dict()
    for layer_name, param in state_dict.items():
        if layer_name not in own_state:
            continue
        if _module_level(layer_name, "level_downmodules") in reinit_downmodules:
            continue
        if _module_level(layer_name, "level_upmodules") in reinit_upmodules:
            continue
        if isinstance(param, Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        own_state[layer_name].copy_(param)
Ejemplo n.º 2
0
def open_model(name, checkpoint_folder):
    """Rebuild a saved model from its checkpoint folder and move it to the GPU.

    The architecture is parsed from ``<checkpoint_folder>/model_spec.json``.
    When ``<checkpoint_folder>/<name>.state.pth.tar`` exists, that state dict
    is loaded with ``load_state_dict``. The model's ``name`` is set and the
    result of ``.cuda()`` is returned.
    """
    parser = artificery.Artificery()
    model = parser.parse(os.path.join(checkpoint_folder, "model_spec.json"))

    state_file = os.path.join(checkpoint_folder,
                              "{}.state.pth.tar".format(name))
    if os.path.isfile(state_file):
        saved_state = torch.load(state_file)
        model.load_state_dict(saved_state)

    model.name = name
    return model.cuda()
Ejemplo n.º 3
0
def create_model(name, checkpoint_folder, checkpoint_init=False):
    """Parse the model stored in ``checkpoint_folder`` and restore its weights.

    Args:
        name: Model name; also the stem of the checkpoint file
            ``<name>.state.pth.tar``.
        checkpoint_folder: Folder containing ``model_spec.json`` and,
            optionally, the checkpoint file.
        checkpoint_init: Forwarded to ``artificery.Artificery``.

    Returns:
        The model (moved with ``.cuda()``) with ``name`` set. Weights are
        applied through ``load_my_state_dict`` only when the checkpoint file
        exists.
    """
    parser = artificery.Artificery(checkpoint_init=checkpoint_init)
    net = parser.parse(os.path.join(checkpoint_folder, "model_spec.json"))

    state_file = os.path.join(checkpoint_folder,
                              "{}.state.pth.tar".format(name))
    if os.path.isfile(state_file):
        saved_state = torch.load(state_file)
        load_my_state_dict(net, saved_state)

    net.name = name
    return net.cuda()
Ejemplo n.º 4
0
def create_model(checkpoint_folder, device):
    """Rebuild a model from ``checkpoint_folder`` and load its saved weights.

    Args:
        checkpoint_folder: Folder containing ``model_spec.json`` and
            ``checkpoint.state.pth.tar``.
        device: Passed to ``torch.device`` and used as ``map_location`` when
            loading the checkpoint. This function does not move the model
            itself to ``device``.

    Returns:
        The parsed model after ``load_my_state_dict`` has applied the
        checkpoint.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist. This
            subclasses ``Exception``, so callers catching ``Exception``
            still handle it.
    """
    a = artificery.Artificery()

    spec_path = os.path.join(checkpoint_folder, "model_spec.json")
    my_p = a.parse(spec_path)

    checkpoint_path = os.path.join(checkpoint_folder,
                                   "checkpoint.state.pth.tar")
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint path {checkpoint_path} not found!")

    state_dict = torch.load(checkpoint_path, map_location=torch.device(device))
    load_my_state_dict(my_p, state_dict)

    return my_p
Ejemplo n.º 5
0
def create_model(checkpoint_folder,
                 device='cpu',
                 checkpoint_name="checkpoint"):
    """Rebuild a model from ``checkpoint_folder``, loading weights when present.

    The architecture is parsed from ``<checkpoint_folder>/model_spec.json``.
    If ``<checkpoint_folder>/<checkpoint_name>.state.pth.tar`` exists, it is
    loaded with ``map_location=torch.device(device)`` and applied via
    ``load_my_state_dict``. Otherwise the freshly parsed model is returned.

    Args:
        checkpoint_folder: Folder containing the spec and optional checkpoint.
        device: Where loaded tensors are mapped. The model itself is not
            moved by this function.
        checkpoint_name: Stem of the checkpoint file; also assigned to the
            model's ``name``.

    Returns:
        The model, with ``name`` set to ``checkpoint_name`` in both cases
        (previously only the loaded-checkpoint path set it).
    """
    a = artificery.Artificery()

    spec_path = os.path.join(checkpoint_folder, "model_spec.json")
    my_p = a.parse(spec_path)

    checkpoint_path = os.path.join(checkpoint_folder,
                                   f"{checkpoint_name}.state.pth.tar")

    if os.path.isfile(checkpoint_path):
        state_dict = torch.load(checkpoint_path,
                                map_location=torch.device(device))
        load_my_state_dict(my_p, state_dict)
    else:
        print("creating new checkpoint...")

    my_p.name = checkpoint_name
    return my_p
Ejemplo n.º 6
0
import artificery

# Interactive debugging aid: parse the pyramid network spec and stop in pdb
# so that `net` can be inspected. The spec path is resolved relative to the
# current working directory.
SPEC_PATH = "../params/pyramid_m5m6m7_fm24.json"

a = artificery.Artificery()
net = a.parse(SPEC_PATH)

import pdb
pdb.set_trace()