Example #1
0
def load_checkpoint(path):
    """Load a checkpoint, first fetching it locally if *path* is a ``gs://`` URL.

    For ``gs://`` URLs the checkpoint is resolved to one of two local copies:

    * a mirror under ``/content/gdrive/My Drive/samples/`` when that mirror's
      parent directory exists, or
    * a cache under ``~/.cache`` otherwise.

    Only processes with ``dist.get_rank() % 8 == 0`` perform downloads
    (presumably one downloader per 8-process node -- confirm). Every process
    then waits at ``dist.barrier()`` and loads the same resolved file.

    Args:
        path: Local filesystem path or ``gs://bucket/...`` URL.

    Returns:
        The object returned by ``t.load(restore, map_location=t.device('cpu'))``.
    """
    print("Loading checkpoint...")
    restore = path
    if restore.startswith('gs://'):
        gs_path = restore
        rel_path = gs_path[len('gs://'):]
        local_path = os.path.join(os.path.expanduser("~/.cache"), rel_path)
        gdrive_path = os.path.join("/content/gdrive/My Drive/samples/", rel_path)
        print(f'local path: {local_path}')
        print(f'gdrive path: {gdrive_path}')
        # Decide the destination on EVERY rank. Non-downloading ranks must also
        # switch away from the gs:// URL, or t.load below cannot open it.
        # The decision depends only on whether the Drive directory exists. The
        # download never creates that directory, so all ranks reach the same answer.
        use_gdrive = os.path.exists(os.path.dirname(gdrive_path))
        restore = gdrive_path if use_gdrive else local_path
        if dist.get_rank() % 8 == 0:
            if os.path.exists(gdrive_path):
                print("Using priors on Google Drive")
            elif use_gdrive:
                print("Downloading priors to Google Drive")
                download(gs_path, gdrive_path)
            else:
                print("Downloading from gce")
                # exist_ok avoids a check-then-create race between downloaders
                # that share a filesystem.
                os.makedirs(os.path.dirname(local_path), exist_ok=True)
                if not os.path.exists(local_path):
                    download(gs_path, local_path)
    # Non-downloading ranks block here until the downloaders have finished.
    dist.barrier()
    checkpoint = t.load(restore, map_location=t.device('cpu'))
    print("RS // Restored from {}".format(restore))
    return checkpoint
Example #2
0
def load_checkpoint(path):
    """Load a checkpoint, first caching it under ``~/.cache`` if it is a ``gs://`` URL.

    Only processes with ``dist.get_rank() % 8 == 0`` create the cache
    directory and download (presumably one downloader per 8-process node --
    confirm). Every process waits at ``dist.barrier()`` and then loads the
    same cached file.

    Args:
        path: Local filesystem path or ``gs://bucket/...`` URL.

    Returns:
        The object returned by ``t.load(restore, map_location=t.device('cpu'))``.
    """
    restore = path
    if restore.startswith('gs://'):
        gs_path = restore
        local_path = os.path.join(os.path.expanduser("~/.cache"), gs_path[len('gs://'):])
        if dist.get_rank() % 8 == 0:
            # exist_ok avoids a check-then-create race between downloaders
            # that share a filesystem.
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            if not os.path.exists(local_path):
                # Log only when a download actually happens; a cache hit is silent.
                print("Downloading from gce")
                download(gs_path, local_path)
        # All ranks switch to the local copy, downloaders or not.
        restore = local_path
    # Non-downloading ranks block here until the downloaders have finished.
    dist.barrier()
    checkpoint = t.load(restore, map_location=t.device('cpu'))
    print("Restored from {}".format(restore))
    return checkpoint
Example #3
0
def load_checkpoint(path):
    """Load a checkpoint, first caching it locally if it is a ``gs://`` URL.

    The cache root is ``~/data/.cache`` when ``/data`` exists, otherwise
    ``~/.cache``. Only processes with ``dist.get_rank() % 8 == 0`` create the
    cache directory and download (presumably one downloader per 8-process
    node -- confirm). Every process waits at ``dist.barrier()`` and then
    loads the same cached file.

    Args:
        path: Local filesystem path or ``gs://bucket/...`` URL.

    Returns:
        The object returned by ``t.load(_restore, map_location=t.device('cpu'))``.
    """
    _restore = path
    if _restore.startswith('gs://'):
        gs_path = _restore
        # NOTE(review): the existence check is on the absolute path '/data',
        # but the cache lives under '~/data/.cache'. If '~/data' was the
        # intended check, change it here -- confirm.
        cache = "~/data/.cache" if os.path.exists('/data') else "~/.cache"
        local_path = os.path.join(os.path.expanduser(cache), gs_path[len('gs://'):])
        if dist.get_rank() % 8 == 0:
            # exist_ok avoids a check-then-create race between downloaders
            # that share a filesystem.
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            if not os.path.exists(local_path):
                print("Downloading from gce")
                download(gs_path, local_path)
        # All ranks switch to the local copy, downloaders or not.
        _restore = local_path
    # Non-downloading ranks block here until the downloaders have finished.
    dist.barrier()
    checkpoint = t.load(_restore, map_location=t.device('cpu'))
    print("Restored from {}".format(_restore))
    return checkpoint