示例#1
0
    def __init__(self, model_name):
        """Build the U^2-Net variant ``model_name`` and load its pretrained weights.

        The weights file is the value of the model's environment variable
        (``U2NET_PATH`` for every model) or, by default,
        ``~/.u2net/<model_name>.pth``. If that file is missing or its MD5 does
        not match the expected digest, it is (re)downloaded from Google Drive
        first. The loaded network is moved to ``DEVICE`` as float32, switched
        to eval mode and stored as ``self.net``.

        Args:
            model_name: ``"u2netp"``, ``"u2net"`` or ``"u2net_human_seg"``.

        Raises:
            KeyError: if ``model_name`` is not one of the names above.
            requests.HTTPError: if the download request returns an error status.
        """
        super(Net, self).__init__()
        hasher = Hasher()

        # name -> (architecture, expected MD5 of the weights file, Google Drive
        # file id, environment variable that overrides the weights path).
        # NOTE(review): the u2net / u2net_human_seg (md5, id) pairs were swapped
        # here; they now match the pairs used by the standalone loaders
        # (u2net -> 1ao1ov..., human_seg -> 1-Yg0cx...) — confirm against the
        # upstream U-2-Net release links.
        # NOTE(review): every model reads the same U2NET_PATH, so setting it
        # makes all models share one file; kept for compatibility.
        model, hash_val, drive_target, env_var = {
            'u2netp': (u2net.U2NETP, 'e4f636406ca4e2af789941e7f139ee2e',
                       '1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy', 'U2NET_PATH'),
            'u2net': (u2net.U2NET, '347c3d51b01528e5c6c071e3cff1cb55',
                      '1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ', 'U2NET_PATH'),
            'u2net_human_seg':
            (u2net.U2NET, '09fb4e49b7f785c9f855baf94916840a',
             '1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P', 'U2NET_PATH')
        }[model_name]
        path = os.environ.get(
            env_var,
            os.path.expanduser(os.path.join("~", ".u2net",
                                            model_name + ".pth")))
        net = model(3, 1)
        if not os.path.exists(path) or hasher.md5(path) != hash_val:
            head, tail = os.path.split(path)
            # A bare file name (no directory part) yields head == "", and
            # os.makedirs("") raises FileNotFoundError.
            if head:
                os.makedirs(head, exist_ok=True)

            URL = "https://docs.google.com/uc?export=download"
            # Stream into a sibling file and move it into place only after the
            # transfer completed, so a failed download never leaves a
            # truncated file at ``path``.
            partial_path = path + ".part"

            with requests.Session() as session:
                response = session.get(URL,
                                       params={"id": drive_target},
                                       stream=True)

                # If the first response sets a ``download_warning*`` cookie,
                # repeat the request with its value as ``confirm``
                # (presumably Drive's large-file confirmation step).
                token = None
                for key, value in response.cookies.items():
                    if key.startswith("download_warning"):
                        token = value
                        break

                if token:
                    response.close()
                    params = {"id": drive_target, "confirm": token}
                    response = session.get(URL, params=params, stream=True)

                # Fail loudly instead of saving an HTTP error page as weights.
                response.raise_for_status()
                total = int(response.headers.get("content-length", 0))

                with open(partial_path, "wb") as file, tqdm(
                        desc=f"Downloading {tail} to {head}",
                        total=total,
                        unit="iB",
                        unit_scale=True,
                        unit_divisor=1024,
                ) as bar:
                    for data in response.iter_content(chunk_size=1024):
                        size = file.write(data)
                        bar.update(size)
            os.replace(partial_path, path)
        net.load_state_dict(torch.load(path,
                                       map_location=torch.device(DEVICE)))
        net.to(device=DEVICE, dtype=torch.float32, non_blocking=True)
        net.eval()
        self.net = net
示例#2
0
def load_model(model_name: str = "u2net"):
    """Load a pretrained U^2-Net model, downloading its weights if needed.

    The weights path is read from ``U2NETP_PATH`` (for ``"u2netp"``) or
    ``U2NET_PATH`` (for ``"u2net"``). It defaults to ``~/.u2net/<model_name>``.
    If the file is absent or its MD5 does not match the expected digest, the
    weights are fetched with ``download_file_from_google_drive``. The network
    is moved to CUDA when available; otherwise its weights are loaded onto
    the CPU.

    Args:
        model_name: ``"u2net"`` (default) or ``"u2netp"``.

    Returns:
        The network with its state dict loaded, in eval mode.

    Raises:
        ValueError: if ``model_name`` is not a supported model.
        FileNotFoundError: if the weights file is missing when loading.
    """
    # name -> (env var overriding the path, expected MD5, Drive file id,
    #          file name handed to the downloader)
    specs = {
        "u2netp": ("U2NETP_PATH", "e4f636406ca4e2af789941e7f139ee2e",
                   "1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy", "u2netp.pth"),
        "u2net": ("U2NET_PATH", "347c3d51b01528e5c6c071e3cff1cb55",
                  "1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ", "u2net.pth"),
    }
    # Reject unknown names up front; otherwise ``path`` would be unbound below.
    if model_name not in specs:
        raise ValueError(
            f"Unknown model {model_name!r}: choose between u2net or u2netp")
    env_var, expected_md5, drive_id, file_name = specs[model_name]

    os.makedirs(os.path.expanduser(os.path.join("~", ".u2net")), exist_ok=True)
    hasher = Hasher()

    path = os.environ.get(
        env_var,
        os.path.expanduser(os.path.join("~", ".u2net", model_name)),
    )
    if not os.path.exists(path) or hasher.md5(path) != expected_md5:
        download_file_from_google_drive(drive_id, file_name, path)

    # The architecture must match the weights. Loading U2NET weights into a
    # U2NETP module fails with mismatched state-dict keys.
    net = u2net.U2NETP(3, 1) if model_name == "u2netp" else u2net.U2NET(3, 1)

    try:
        if torch.cuda.is_available():
            net.load_state_dict(torch.load(path))
            net.to(torch.device("cuda"))
        else:
            net.load_state_dict(torch.load(path, map_location="cpu"))
    except FileNotFoundError as err:
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
                                path) from err

    net.eval()

    return net
示例#3
0
 def test_sha512_vs_standard(self):
     """The SHA-512 digest of the first fixture file equals the stored value."""
     digest = Hasher().sha512(self.testfile1_path)
     self.assertEqual(digest, self.hashtest_sha512_hash)
示例#4
0
 def test_md5_vs_standard(self):
     """The MD5 digest of the first fixture file equals the stored value."""
     digest = Hasher().md5(self.testfile1_path)
     self.assertEqual(digest, self.hashtest_md5_hash)
示例#5
0
def load_model(model_name: str = "u2net"):
    """Load a pretrained U^2-Net model, downloading its weights if needed.

    Supported names are ``"u2net"`` (default), ``"u2netp"`` and
    ``"u2net_human_seg"``.

    The weights path is read from ``U2NETP_PATH`` (u2netp) or ``U2NET_PATH``
    (the other two). It defaults to ``~/.u2net/<model_name>.pth``. A missing
    file or an MD5 mismatch triggers ``download_file_from_google_drive``.
    Weights are loaded onto CUDA when it is available, otherwise onto the CPU.

    Args:
        model_name: name of the model to load (see above).

    Returns:
        The network with its state dict loaded, in eval mode.

    Raises:
        ValueError: if ``model_name`` is not a supported model.
        FileNotFoundError: if the weights file is missing when loading.
    """
    # name -> (env var overriding the path, expected MD5, Drive file id,
    #          file name handed to the downloader)
    # NOTE(review): u2net and u2net_human_seg share U2NET_PATH. With it set,
    # requesting the other model fails the MD5 check and presumably downloads
    # over the configured file. Kept for compatibility.
    specs = {
        "u2netp": ("U2NETP_PATH", "e4f636406ca4e2af789941e7f139ee2e",
                   "1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy", "u2netp.pth"),
        "u2net_human_seg": ("U2NET_PATH", "09fb4e49b7f785c9f855baf94916840a",
                            "1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P",
                            "u2net_human.pth"),
        "u2net": ("U2NET_PATH", "347c3d51b01528e5c6c071e3cff1cb55",
                  "1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ", "u2net.pth"),
    }
    # Reject unknown names up front; otherwise ``net``/``path`` stay unbound.
    if model_name not in specs:
        raise ValueError(
            f"Unknown model {model_name!r}: choose between u2net, u2netp "
            "or u2net_human_seg")
    env_var, expected_md5, drive_id, file_name = specs[model_name]

    hasher = Hasher()
    # u2netp is the small architecture; both other weight sets use U2NET.
    net = u2net.U2NETP(3, 1) if model_name == "u2netp" else u2net.U2NET(3, 1)

    path = os.environ.get(
        env_var,
        os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
    )
    if not os.path.exists(path) or hasher.md5(path) != expected_md5:
        download_file_from_google_drive(drive_id, file_name, path)

    try:
        if torch.cuda.is_available():
            net.load_state_dict(torch.load(path))
            net.to(torch.device("cuda"))
        else:
            # Remap any CUDA-saved tensors so the weights load on CPU-only hosts.
            net.load_state_dict(torch.load(path, map_location="cpu"))
    except FileNotFoundError as err:
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
                                path) from err

    net.eval()

    return net
示例#6
0
def _print_file_digests(algorithm, paths):
    """Print the ``algorithm`` digest of every existing file in ``paths``.

    ``algorithm`` names a ``Hasher`` method (``"sha1"`` ... ``"md5"``). Its
    upper-cased form labels each result, e.g. ``SHA256 (file) :``. Paths that
    ``file_exists`` rejects are reported on stderr and skipped.

    Returns:
        True if every path was hashed, False if at least one was missing.
    """
    digest_of = getattr(Hasher(), algorithm)
    label = algorithm.upper()
    all_found = True
    for path in paths:
        if not file_exists(path):
            sys.stderr.write(
                path + " does not appear to be an existing file path.\n")
            all_found = False
            continue
        digest = digest_of(path)
        print(label + " (" + path + ") :")
        print(digest)
    return all_found


def main():
    """Run the ``hsh`` command line: print file digests or compare hashes.

    Sub-commands ``sha1``, ``sha224``, ``sha256``, ``sha384``, ``sha512`` and
    ``md5`` hash every following file argument. ``check`` compares exactly two
    arguments with ``HashChecker``. Without a recognised sub-command, one
    argument is hashed with SHA-256 and two arguments are compared.

    Exits with status 1 on usage errors or when a requested file is missing.
    """
    c = Command()

    if c.does_not_validate_missing_args():
        print(hsh_usage)
        sys.exit(1)

    if c.is_help_request():  # User requested hsh help information
        print(hsh_help)
        sys.exit(0)
    elif c.is_usage_request():  # User requested hsh usage information
        print(hsh_usage)
        sys.exit(0)
    elif c.is_version_request():  # User requested hsh version information
        print(f"{app_name} {major_version}.{minor_version}.{patch_version}")
        sys.exit(0)

    primary_command = c.subcmd.lower()  # make the subcommand case-independent
    # Each name is both the sub-command and the Hasher method that implements it.
    digest_commands = ("sha1", "sha224", "sha256", "sha384", "sha512", "md5")

    if primary_command in digest_commands:
        if c.argc <= 1:
            sys.stderr.write(
                "You did not include a file in your command.  Please try again.\n"
            )
            sys.exit(1)
        # Like the single-file default below, a missing file is an error exit
        # rather than a silent success.
        if not _print_file_digests(primary_command, c.argv[1:]):
            sys.exit(1)
    elif primary_command == "check":
        if c.argc == 3:  # primary command + 2 arguments
            # pass the argument list excluding the primary command
            HashChecker().compare(c.argv[1:])
        elif c.argc < 3:
            sys.stderr.write(
                "You did not include a file or hash digest for comparison.  Please try again.\n"
            )
            sys.exit(1)
        else:
            sys.stderr.write(
                "Too many arguments.  Please include two arguments for comparison.\n"
            )
            sys.exit(1)
    elif c.argc == 1:  # single file hash digest request with default SHA256 settings
        path = c.arg0
        if file_exists(path):
            digest = Hasher().sha256(path)
            print("SHA256 (" + path + ") :")
            print(digest)
        else:
            sys.stderr.write(
                c.arg0 +
                " does not appear to be an existing file path. Please try again.\n"
            )
            sys.exit(1)
    elif c.argc == 2:  # exactly two arguments, perform comparison between them by default
        # pass the entire argument list because there is no primary command
        HashChecker().compare(c.argv)
    else:
        print(
            "Could not complete the command that you entered.  Please try again."
        )
        sys.exit(1)