Example #1
 def runjobs(self, verbose = None):
     """ Run the jobs """
     # The verbose flag
     if verbose is not None:
         self.verbose = verbose
     # Find all 'job_' methods
     for (exc_name, exc_value) in inspect.getmembers(self, lambda x: inspect.ismethod(x)):
         # The method's name must start with uppercase
         if exc_name[0].isupper():
             lock = FileLock(os.path.join('/', 'tmp', 'TranPy-%s-%s.lock' % (self.job_class, exc_name)))
             # Try to get the lock
             if lock.acquire():
                 if self.verbose:
                     print >>sys.stderr, 'Running %s %s' % (self.job_class, exc_name)
                 # Run the job
                 try:
                     exc_value()
                 except:
                     if self.verbose:
                         traceback.print_exc(file = sys.stderr)
                 finally:
                     # Release the lock
                     lock.release()
             else:
                 if self.verbose:
                     print >>sys.stderr, 'Locked %s %s' % (self.job_class, exc_name)
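
A minimal sketch of the same skip-if-already-running idea, written against the PyPI filelock package (whose acquire() raises Timeout rather than returning a boolean like the FileLock used above). run_exclusive, job_name and job are hypothetical names, not part of the example:

import os
import sys
import tempfile

from filelock import FileLock, Timeout


def run_exclusive(job_name, job, verbose=False):
    # One lock file per job, kept in the system temp directory.
    lock = FileLock(os.path.join(tempfile.gettempdir(), '%s.lock' % job_name))
    try:
        # timeout=0 makes exactly one attempt and raises Timeout if the lock is held.
        with lock.acquire(timeout=0):
            job()
    except Timeout:
        if verbose:
            print('Locked, skipping %s' % job_name, file=sys.stderr)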
Example #2
File: db.py Project: fparrel/regepe
def DbBuildInvert(dbtype,ele,invfunc):
    if dbtype not in DBTYPES:
        raise Exception('Invalid database type')
    # Check ele
    if ele not in ELELIST[dbtype]:
        raise Exception('Invalid element')
    #print '<!-- DbBuildInvert -->\n'
    # Target inv db
    dbfileinv = ele.upper()+'_INV.db'
    # Lock and open inv db
    lock = FileLock(dbfileinv,5)
    lock.acquire()
    #Log('DbBuildInvert open db c %s\n'%dbfileinv)
    dbinv = anydbm.open(dbfileinv,'c')
    # Clear inv db
    dbinv.clear()
    # List dir
    for dbfile in os.listdir(dbtype):
        id = dbfile[:-3]
        #Log('DbBuildInvert open db r %s/%s\n'%(dbtype,dbfile))
        db = anydbm.open('%s/%s' % (dbtype,dbfile),'r')
        if db.has_key(ele):
            value = db[ele]
            for word in invfunc(value):
                if dbinv.has_key(word):
                    dbinv[word] = dbinv[word] + (',%s' % id)
                else:
                    dbinv[word] = '%s' % id
        db.close()
        #Log('DbBuildInvert close db r %s/%s\n'%(dbtype,dbfile))
    dbinv.close()
    #Log('DbBuildInvert close db c %s\n'%dbfileinv)
    lock.release()
    # Rebuild is no longer needed
    RearmRebuild(ele)
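
A sketch of the same rebuild using Python 3's dbm in place of anydbm: the inverted index is collected in memory first and only the publish step holds the lock, so readers never see a half-built database. The sidecar .lock name and the in-memory pass are assumptions, and the DBTYPES/ELELIST checks are omitted:

import dbm
import os

from filelock import FileLock


def build_invert(dbtype, ele, invfunc):
    inverted = {}
    for dbfile in os.listdir(dbtype):
        entry_id = dbfile[:-3]  # strip the '.db' suffix
        with dbm.open(os.path.join(dbtype, dbfile), 'r') as db:
            try:
                value = db[ele]
            except KeyError:
                continue
            for word in invfunc(value.decode()):
                inverted.setdefault(word, []).append(entry_id)
    # Only the publish step needs the lock.
    with FileLock(ele.upper() + '_INV.db.lock', timeout=5):
        with dbm.open(ele.upper() + '_INV.db', 'n') as dbinv:  # 'n' creates a new, empty db
            for word, ids in inverted.items():
                dbinv[word] = ','.join(ids)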
Example #3
File: db.py Project: fparrel/regepe
def DbBuildInvertOld(ele,invfunc):
    raise Exception('Deprecated')
    # Check ele
    if ele not in ELELIST['maps']:
        raise Exception('Invalid element')
    # Target inv db
    dbfileinv = ele.upper()+'_INV.db'
    # Lock and open inv db
    lock = FileLock(dbfileinv,5)
    lock.acquire()
    dbinv = anydbm.open(dbfileinv,'c')
    # Clear inv db
    dbinv.clear()
    # List dir
    for mapdbfile in os.listdir('maps'):
        mapid = mapdbfile[:-3]
        dbmap = anydbm.open('maps/%s' % mapdbfile,'r')
        if dbmap.has_key(ele):
            value = dbmap[ele]
            for word in invfunc(value):
                if dbinv.has_key(word):
                    dbinv[word] = dbinv[word] + (',%s' % mapid)
                else:
                    dbinv[word] = '%s' % mapid
        dbmap.close()
    dbinv.close()
    lock.release()
    # Rebuild is no longer needed
    RearmRebuild(ele)
Example #4
        def run(self):
                while True:
                        lock = FileLock("/var/lock/baseDaemon.lock")

                        #Just to be safe; pulling data shouldn't take more than 2h
                        lock.acquire()

                        wikiDatesToMongo(False)
                        time.sleep(2*60*60)

                        lock.release()
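
If wikiDatesToMongo() ever raises, the release() above is never reached and the lock stays held. A sketch of the same loop with the lock used as a context manager (assuming the PyPI filelock package; pull_data stands in for wikiDatesToMongo):

import time

from filelock import FileLock


def run(pull_data):
    lock = FileLock("/var/lock/baseDaemon.lock")
    while True:
        # Acquired on entry, released on exit, even if pull_data() raises.
        with lock:
            pull_data()
            time.sleep(2 * 60 * 60)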
Example #5
File: db.py Project: fparrel/regepe
def DbAddComment(mapid,user,comment):
    mapfile = 'maps/%s.db' % mapid
    if not os.access(mapfile,os.F_OK):
        raise Exception('Invalid map id %s' % mapid)
    d = getCurrentDate()
    lock = FileLock(mapfile,5)
    lock.acquire()
    #Log('DbAddComment open db r %s\n' % mapfile)
    db = anydbm.open(mapfile,'r')
    if db.has_key('last_comment_id'):
        last_comment_id = int(db['last_comment_id'])
    else:
        last_comment_id = 0
    db.close()
    #Log('DbAddComment close db r %s\n' % mapfile)
    last_comment_id += 1
    if last_comment_id>99999:
        lock.release()
        raise Exception('Max comments reached')
    #Log('DbAddComment open db c %s\n' % mapfile)
    db = anydbm.open(mapfile,'c')
    db['last_comment_id'] = str(last_comment_id)
    db['comment%.5d'%last_comment_id] = '%s,%s,%s' % (d,user,comment)
    db.close()
    #Log('DbAddComment close db c %s\n' % mapfile)
    lock.release()
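
The heart of DbAddComment is a read-increment-write of a shared counter that must not interleave across processes. A compact sketch of that idiom on Python 3 (dbm replacing anydbm, the PyPI filelock package with a sidecar .lock file, and the 99999 cap taken from the example):

import dbm

from filelock import FileLock


def next_comment_id(mapfile):
    with FileLock(mapfile + '.lock', timeout=5):
        with dbm.open(mapfile, 'c') as db:
            try:
                last_comment_id = int(db['last_comment_id'])
            except KeyError:
                last_comment_id = 0
            last_comment_id += 1
            if last_comment_id > 99999:
                raise RuntimeError('Max comments reached')
            db['last_comment_id'] = str(last_comment_id)
            return last_comment_id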
Example #6
 def update(self, lock=True):
     if lock:
         flock = FileLock(self.inifile)
         flock.acquire()
     
     try:
         inifp = open(self.inifile, 'w')
         self.cfg.write(inifp)
         inifp.close()
         if lock: flock.release()
         return True
     except:
         if lock: flock.release()
         return False
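
The duplicated flock.release() calls can be collapsed into a single finally block. A sketch of the same write-under-optional-lock, assuming the PyPI filelock package and a sidecar .lock file rather than locking the ini file itself:

from filelock import FileLock


def update(inifile, cfg, lock=True):
    flock = FileLock(inifile + '.lock') if lock else None
    if flock is not None:
        flock.acquire()
    try:
        with open(inifile, 'w') as inifp:
            cfg.write(inifp)  # cfg is a ConfigParser-like object, as in the example
        return True
    except OSError:
        return False
    finally:
        if flock is not None:
            flock.release()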
Example #7
    def __init__(self, log_file, debug=False):
        """Initialize."""

        self.format = "%(message)s"
        self.file_name = log_file
        self.file_lock = FileLock(self.file_name + '.lock')
        self.debug = debug

        try:
            with self.file_lock.acquire(1):
                with codecs.open(self.file_name, "w", "utf-8") as f:
                    f.write("")
        except Exception:
            self.file_name = None

        wx.LogGui.__init__(self)
Example #8
    def __init__(self, log_file, no_redirect):
        """Initialize."""

        self.format = "%(message)s"
        self.file_name = log_file
        self.file_lock = FileLock(self.file_name + '.lock')

        try:
            with self.file_lock.acquire(1):
                with codecs.open(self.file_name, "w", "utf-8") as f:
                    f.write("")
        except Exception:
            self.file_name = None

        self.no_redirect = no_redirect

        wx.Log.__init__(self)
Example #9
    def __init__(self, db_name, schema, log_level=1):
        """
        schema example = (
            (name, type, len),
            ...
        )

        int
        str
        """
        self._log_lvl = log_level
        db_name = os.path.join('db', db_name)

        self.lock = FileLock(db_name)

        if not self.is_valid_schema(schema):
            self.error("Invalid DB schema")
            raise Exception('Schema is bad!')

        all = self.init_mem(db_name, schema)

        self._memory = all
        self._db_name = db_name
        self._schema = schema
Example #10
PORT = 8485

s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
print('Socket created')

s.bind((HOST, PORT))
print('Socket bind complete')
s.listen(10)
print('Socket now listening')

conn, addr = s.accept()

data = b""
payload_size = struct.calcsize(">L")
print("payload_size: {}".format(payload_size))
lock = FileLock(
    '../deep_lab_v3_material_detection/bridge_images/test.jpeg.lock')
while True:
    while len(data) < payload_size:
        print("Recv: {}".format(len(data)))
        data += conn.recv(4096)

    print("Done Recv: {}".format(len(data)))
    packed_msg_size = data[:payload_size]
    data = data[payload_size:]
    msg_size = struct.unpack(">L", packed_msg_size)[0]
    print("msg_size: {}".format(msg_size))
    while len(data) < msg_size:
        data += conn.recv(4096)
    frame_data = data[:msg_size]
    data = data[msg_size:]
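
The snippet ends before the lock is used; presumably the decoded frame is written to the test.jpeg file that the lock path points at. A hedged sketch of just that step, with the target path inferred from the lock name and the lock assumed to behave like the PyPI filelock package:

def store_frame(lock, frame_data,
                path='../deep_lab_v3_material_detection/bridge_images/test.jpeg'):
    # Hold the lock while writing so readers never see a partial image.
    with lock:
        with open(path, 'wb') as f:
            f.write(frame_data)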
Example #11
import sys, urllib, json, sqlite3
from filelock import FileLock

url = 'https://coinut.com/api/tick/BTCUSD'

conn = sqlite3.connect('coinut.db')
c = conn.cursor()

c.execute('''CREATE TABLE IF NOT EXISTS ticker (timestamp integer, price real)''') 

conn.close()

prev_price = -1
while(True):
	try:
		lock = FileLock('coinut.lock')
		
		response = urllib.urlopen(url);
		data = json.loads(response.read())
	
		timestamp = long(data['timestamp'])
		price = float(data['tick'])
	
		if price != prev_price:
			if lock.acquire():
				conn = sqlite3.connect('coinut.db')
				c = conn.cursor()
			
				with conn:
					values = (timestamp, price,)
					c.execute("INSERT INTO ticker VALUES (?,?)", values)
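
This snippet is also cut off mid-loop. A compact Python 3 sketch of one poll-and-insert cycle, keeping the URL, field names, table layout and lock file from the example; the function boundaries and everything else are assumptions:

import json
import sqlite3
import urllib.request

from filelock import FileLock


def poll_once(lock, url='https://coinut.com/api/tick/BTCUSD'):
    with urllib.request.urlopen(url) as response:
        data = json.loads(response.read())
    timestamp, price = int(data['timestamp']), float(data['tick'])
    # `with conn:` commits the transaction, `with lock:` serializes writers.
    with lock, sqlite3.connect('coinut.db') as conn:
        conn.execute('INSERT INTO ticker VALUES (?, ?)', (timestamp, price))
    return price


if __name__ == '__main__':
    poll_once(FileLock('coinut.lock'))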
Example #12
def main():
    ret = 0

    parser = argparse.ArgumentParser(description='project mirroring',
                                     parents=[get_baseparser(
                                         tool_version=__version__)
                                     ])

    parser.add_argument('project')
    parser.add_argument('-c', '--config',
                        help='config file in JSON/YAML format')
    parser.add_argument('-U', '--uri', default='http://localhost:8080/source',
                        help='uri of the webapp with context path')
    parser.add_argument('-b', '--batch', action='store_true',
                        help='batch mode - will log into a file')
    parser.add_argument('-B', '--backupcount', default=8,
                        help='how many log files to keep around in batch mode')
    parser.add_argument('-I', '--incoming', action='store_true',
                        help='Check for incoming changes, terminate the '
                             'processing if not found.')
    try:
        args = parser.parse_args()
    except ValueError as e:
        print_exc_exit(e)

    logger = get_console_logger(get_class_basename(), args.loglevel)

    if args.config:
        config = read_config(logger, args.config)
        if config is None:
            logger.error("Cannot read config file from {}".format(args.config))
            sys.exit(1)
    else:
        config = {}

    GLOBAL_TUNABLES = [HOOKDIR_PROPERTY, PROXY_PROPERTY, LOGDIR_PROPERTY,
                       COMMANDS_PROPERTY, PROJECTS_PROPERTY,
                       HOOK_TIMEOUT_PROPERTY, CMD_TIMEOUT_PROPERTY]
    diff = diff_list(config.keys(), GLOBAL_TUNABLES)
    if diff:
        logger.error("unknown global configuration option(s): '{}'"
                     .format(diff))
        sys.exit(1)

    # Make sure the log directory exists.
    logdir = config.get(LOGDIR_PROPERTY)
    if logdir:
        check_create_dir(logger, logdir)

    uri = args.uri
    if not is_web_uri(uri):
        logger.error("Not a URI: {}".format(uri))
        sys.exit(1)
    logger.debug("web application URI = {}".format(uri))

    source_root = get_config_value(logger, 'sourceRoot', uri)
    if not source_root:
        sys.exit(1)

    logger.debug("Source root = {}".format(source_root))

    project_config = None
    projects = config.get(PROJECTS_PROPERTY)
    if projects:
        if projects.get(args.project):
            project_config = projects.get(args.project)
        else:
            for proj in projects.keys():
                try:
                    pattern = re.compile(proj)
                except re.error:
                    logger.error("Not a valid regular expression: {}".
                                 format(proj))
                    continue

                if pattern.match(args.project):
                    logger.debug("Project '{}' matched pattern '{}'".
                                 format(args.project, proj))
                    project_config = projects.get(proj)
                    break

    hookdir = config.get(HOOKDIR_PROPERTY)
    if hookdir:
        logger.debug("Hook directory = {}".format(hookdir))

    command_timeout = get_int(logger, "command timeout",
                              config.get(CMD_TIMEOUT_PROPERTY))
    if command_timeout:
        logger.debug("Global command timeout = {}".format(command_timeout))

    hook_timeout = get_int(logger, "hook timeout",
                           config.get(HOOK_TIMEOUT_PROPERTY))
    if hook_timeout:
        logger.debug("Global hook timeout = {}".format(hook_timeout))

    prehook = None
    posthook = None
    use_proxy = False
    ignored_repos = None
    if project_config:
        logger.debug("Project '{}' has specific (non-default) config".
                     format(args.project))

        # Quick sanity check.
        KNOWN_PROJECT_TUNABLES = [DISABLED_PROPERTY, CMD_TIMEOUT_PROPERTY,
                                  HOOK_TIMEOUT_PROPERTY, PROXY_PROPERTY,
                                  IGNORED_REPOS_PROPERTY, HOOKS_PROPERTY]
        diff = diff_list(project_config.keys(), KNOWN_PROJECT_TUNABLES)
        if diff:
            logger.error("unknown project configuration option(s) '{}' "
                         "for project {}".format(diff, args.project))
            sys.exit(1)

        project_command_timeout = get_int(logger, "command timeout for "
                                                  "project {}".
                                          format(args.project),
                                          project_config.
                                          get(CMD_TIMEOUT_PROPERTY))
        if project_command_timeout:
            command_timeout = project_command_timeout
            logger.debug("Project command timeout = {}".
                         format(command_timeout))

        project_hook_timeout = get_int(logger, "hook timeout for "
                                               "project {}".
                                       format(args.project),
                                       project_config.
                                       get(HOOK_TIMEOUT_PROPERTY))
        if project_hook_timeout:
            hook_timeout = project_hook_timeout
            logger.debug("Project hook timeout = {}".
                         format(hook_timeout))

        ignored_repos = project_config.get(IGNORED_REPOS_PROPERTY)
        if ignored_repos:
            if not isinstance(ignored_repos, list):
                logger.error("{} for project {} is not a list".
                             format(IGNORED_REPOS_PROPERTY, args.project))
                sys.exit(1)
            logger.debug("has ignored repositories: {}".
                         format(ignored_repos))

        hooks = project_config.get(HOOKS_PROPERTY)
        if hooks:
            if not hookdir:
                logger.error("Need to have '{}' in the configuration "
                             "to run hooks".format(HOOKDIR_PROPERTY))
                sys.exit(1)

            if not os.path.isdir(hookdir):
                logger.error("Not a directory: {}".format(hookdir))
                sys.exit(1)

            for hookname in hooks:
                if hookname == "pre":
                    prehook = hookpath = os.path.join(hookdir, hooks['pre'])
                    logger.debug("pre-hook = {}".format(prehook))
                elif hookname == "post":
                    posthook = hookpath = os.path.join(hookdir, hooks['post'])
                    logger.debug("post-hook = {}".format(posthook))
                else:
                    logger.error("Unknown hook name {} for project {}".
                                 format(hookname, args.project))
                    sys.exit(1)

                if not is_exe(hookpath):
                    logger.error("hook file {} does not exist or not "
                                 "executable".format(hookpath))
                    sys.exit(1)

        if project_config.get(PROXY_PROPERTY):
            if not config.get(PROXY_PROPERTY):
                logger.error("global proxy setting is needed in order to "
                             "have per-project proxy")
                sys.exit(1)

            logger.debug("will use proxy")
            use_proxy = True

    if not ignored_repos:
        ignored_repos = []

    # Log messages to dedicated log file if running in batch mode.
    if args.batch:
        if not logdir:
            logger.error("The logdir property is required in batch mode")
            sys.exit(1)

        logfile = os.path.join(logdir, args.project + ".log")
        logger.debug("Switching logging to the {} file".
                     format(logfile))

        logger = logger.getChild("rotating")
        logger.setLevel(args.loglevel)
        logger.propagate = False
        handler = RotatingFileHandler(logfile, maxBytes=0, mode='a',
                                      backupCount=args.backupcount)
        formatter = logging.Formatter("%(asctime)s - %(levelname)s: "
                                      "%(message)s", '%m/%d/%Y %I:%M:%S %p')
        handler.setFormatter(formatter)
        handler.doRollover()
        logger.addHandler(handler)

    # We want this to be logged to the log file (if any).
    if project_config:
        if project_config.get(DISABLED_PROPERTY):
            logger.info("Project {} disabled, exiting".
                        format(args.project))
            sys.exit(CONTINUE_EXITVAL)

    lock = FileLock(os.path.join(tempfile.gettempdir(),
                                 args.project + "-mirror.lock"))
    try:
        with lock.acquire(timeout=0):
            proxy = config.get(PROXY_PROPERTY) if use_proxy else None

            #
            # Cache the repositories first. This way it will be known that
            # something is not right, avoiding any needless pre-hook run.
            #
            repos = []
            try:
                repos = get_repos_for_project(logger, args.project,
                                              ignored_repos,
                                              commands=config.
                                              get(COMMANDS_PROPERTY),
                                              proxy=proxy,
                                              command_timeout=command_timeout,
                                              source_root=source_root,
                                              uri=uri)
            except RepositoryException as ex:
                logger.error('failed to get repositories for project {}: {}'.
                             format(args.project, ex))
                sys.exit(1)

            if not repos:
                logger.info("No repositories for project {}".
                            format(args.project))
                sys.exit(CONTINUE_EXITVAL)

            # Check if any of the repositories contains incoming changes.
            if args.incoming:
                got_incoming = False
                for repo in repos:
                    try:
                        if repo.incoming():
                            logger.debug('Repository {} has incoming changes'.
                                         format(repo))
                            got_incoming = True
                            break
                    except RepositoryException:
                        logger.error('Cannot determine incoming changes for '
                                     'repository {}'.format(repo))
                        sys.exit(1)

                if not got_incoming:
                    logger.info('No incoming changes for repositories in '
                                'project {}'.
                                format(args.project))
                    sys.exit(CONTINUE_EXITVAL)

            if prehook:
                logger.info("Running pre hook")
                if run_hook(logger, prehook,
                            os.path.join(source_root, args.project), proxy,
                            hook_timeout) != 0:
                    logger.error("pre hook failed for project {}".
                                 format(args.project))
                    logging.shutdown()
                    sys.exit(1)

            #
            # If one of the repositories fails to sync, the whole project sync
            # is treated as failed, i.e. the program will return 1.
            #
            for repo in repos:
                logger.info("Synchronizing repository {}".
                            format(repo.path))
                if repo.sync() != 0:
                    logger.error("failed to sync repository {}".
                                 format(repo.path))
                    ret = 1

            if posthook:
                logger.info("Running post hook")
                if run_hook(logger, posthook,
                            os.path.join(source_root, args.project), proxy,
                            hook_timeout) != 0:
                    logger.error("post hook failed for project {}".
                                 format(args.project))
                    logging.shutdown()
                    sys.exit(1)
    except Timeout:
        logger.warning("Already running, exiting.")
        sys.exit(1)

    logging.shutdown()
    sys.exit(ret)
Example #13
class Metric(MetricInfoMixin):
    """A Metric is the base class and common API for all metrics.

    Args:
        config_name (``str``): This is used to define a hash specific to a metric computation script and prevents the metric's data
            from being overridden when the metric loading script is modified.
        keep_in_memory (``bool``): keep all predictions and references in memory. Not possible in distributed settings.
        cache_dir (``str``): Path to a directory in which temporary prediction/references data will be stored.
            The data directory should be located on a shared file-system in distributed setups.
        num_process (``int``): specify the total number of nodes in a distributed settings.
            This is useful to compute metrics in distributed setups (in particular non-additive metrics like F1).
        process_id (``int``): specify the id of the current process in a distributed setup (between 0 and num_process-1)
            This is useful to compute metrics in distributed setups (in particular non-additive metrics like F1).
        seed (Optional ``int``): If specified, this will temporarily set numpy's random seed when :func:`datasets.Metric.compute` is run.
        experiment_id (``str``): A specific experiment id. This is used if several distributed evaluations share the same file system.
            This is useful to compute metrics in distributed setups (in particular non-additive metrics like F1).
        max_concurrent_cache_files (``int``): Max number of concurrent metrics cache files (default 10000).
        timeout (``Union[int, float]``): Timeout in second for distributed setting synchronization.
    """

    def __init__(
        self,
        config_name: Optional[str] = None,
        keep_in_memory: bool = False,
        cache_dir: Optional[str] = None,
        num_process: int = 1,
        process_id: int = 0,
        seed: Optional[int] = None,
        experiment_id: Optional[str] = None,
        max_concurrent_cache_files: int = 10000,
        timeout: Union[int, float] = 100,
        **kwargs,
    ):
        # prepare info
        self.config_name = config_name or "default"
        info = self._info()
        info.metric_name = camelcase_to_snakecase(self.__class__.__name__)
        info.config_name = self.config_name
        info.experiment_id = experiment_id or "default_experiment"
        MetricInfoMixin.__init__(self, info)  # For easy access on low level

        # Safety checks on num_process and process_id
        assert isinstance(process_id, int) and process_id >= 0, "'process_id' should be a number greater than or equal to 0"
        assert (
            isinstance(num_process, int) and num_process > process_id
        ), "'num_process' should be a number greater than process_id"
        assert (
            num_process == 1 or not keep_in_memory
        ), "Using 'keep_in_memory' is not possible in distributed setting (num_process > 1)."
        self.num_process = num_process
        self.process_id = process_id
        self.max_concurrent_cache_files = max_concurrent_cache_files

        self.keep_in_memory = keep_in_memory
        self._data_dir_root = os.path.expanduser(cache_dir or HF_METRICS_CACHE)
        self.data_dir = self._build_data_dir()
        self.seed: int = seed or np.random.get_state()[1][0]
        self.timeout: Union[int, float] = timeout

        # Update 'compute' and 'add' docstring
        # methods need to be copied otherwise it changes the docstrings of every instance
        self.compute = types.MethodType(copyfunc(self.compute), self)
        self.add_batch = types.MethodType(copyfunc(self.add_batch), self)
        self.add = types.MethodType(copyfunc(self.add), self)
        self.compute.__func__.__doc__ += self.info.inputs_description
        self.add_batch.__func__.__doc__ += self.info.inputs_description
        self.add.__func__.__doc__ += self.info.inputs_description

        # self.arrow_schema = pa.schema(field for field in self.info.features.type)
        self.buf_writer = None
        self.writer = None
        self.writer_batch_size = None
        self.data = None

        # This is the cache file we store our predictions/references in
        # Keep it None for now so we can (cloud)pickle the object
        self.cache_file_name = None
        self.filelock = None
        self.rendez_vous_lock = None

        # This is all the cache files on which we have a lock when we are in a distributed setting
        self.file_paths = None
        self.filelocks = None

    def __len__(self):
        """Return the number of examples (predictions or predictions/references pair)
        currently stored in the metric's cache.
        """
        return 0 if self.writer is None else len(self.writer)

    def __repr__(self):
        return (
            f'Metric(name: "{self.name}", features: {self.features}, '
            f'usage: """{self.inputs_description}""", '
            f"stored examples: {len(self)})"
        )

    def _build_data_dir(self):
        """Path of this metric in cache_dir:
        Will be:
            self._data_dir_root/self.name/self.config_name/self.hash (if not none)/
        If any of these element is missing or if ``with_version=False`` the corresponding subfolders are dropped.
        """
        builder_data_dir = self._data_dir_root
        builder_data_dir = os.path.join(builder_data_dir, self.name, self.config_name)
        os.makedirs(builder_data_dir, exist_ok=True)
        return builder_data_dir

    def _create_cache_file(self, timeout=1) -> Tuple[str, FileLock]:
        """ Create a new cache file. If the default cache file is used, we generate a new hash. """
        file_path = os.path.join(self.data_dir, f"{self.experiment_id}-{self.num_process}-{self.process_id}.arrow")
        filelock = None
        for i in range(self.max_concurrent_cache_files):
            filelock = FileLock(file_path + ".lock")
            try:
                filelock.acquire(timeout=timeout)
            except Timeout:
                # If we have reached the max number of attempts or we are not allowed to find a free name (distributed setup),
                # we raise an error.
                if self.num_process != 1:
                    raise ValueError(
                        f"Error in _create_cache_file: another metric instance is already using the local cache file at {file_path}. "
                        f"Please specify an experiment_id (currently: {self.experiment_id}) to avoid collision "
                        f"between distributed metric instances."
                    )
                if i == self.max_concurrent_cache_files - 1:
                    raise ValueError(
                        f"Cannot acquire lock, too many metric instances are operating concurrently on this file system. "
                        f"You should set a larger value of max_concurrent_cache_files when creating the metric "
                        f"(current value is {self.max_concurrent_cache_files})."
                    )
                # In other cases (allow to find new file name + not yet at max num of attempts) we can try to sample a new hashing name.
                file_uuid = str(uuid.uuid4())
                file_path = os.path.join(
                    self.data_dir, f"{self.experiment_id}-{file_uuid}-{self.num_process}-{self.process_id}.arrow"
                )
            else:
                break

        return file_path, filelock

    def _get_all_cache_files(self) -> Tuple[List[str], List[FileLock]]:
        """Get a lock on all the cache files in a distributed setup.
        We wait for timeout seconds to let all the distributed nodes finish their tasks (default is 100 seconds).
        """
        if self.num_process == 1:
            if self.cache_file_name is None:
                raise ValueError(
                    "Metric cache file doesn't exist. Please make sure that you call `add` or `add_batch` "
                    "at least once before calling `compute`."
                )
            file_paths = [self.cache_file_name]
        else:
            file_paths = [
                os.path.join(self.data_dir, f"{self.experiment_id}-{self.num_process}-{process_id}.arrow")
                for process_id in range(self.num_process)
            ]

        # Let's acquire a lock on each process files to be sure they are finished writing
        filelocks = []
        for process_id, file_path in enumerate(file_paths):
            filelock = FileLock(file_path + ".lock")
            try:
                filelock.acquire(timeout=self.timeout)
            except Timeout:
                raise ValueError(f"Cannot acquire lock on cached file {file_path} for process {process_id}.")
            else:
                filelocks.append(filelock)

        return file_paths, filelocks

    def _check_all_processes_locks(self):
        expected_lock_file_names = [
            os.path.join(self.data_dir, f"{self.experiment_id}-{self.num_process}-{process_id}.arrow.lock")
            for process_id in range(self.num_process)
        ]
        for expected_lock_file_name in expected_lock_file_names:
            nofilelock = FileFreeLock(expected_lock_file_name)
            try:
                nofilelock.acquire(timeout=self.timeout)
            except Timeout:
                raise ValueError(
                    f"Expected to find locked file {expected_lock_file_name} from process {self.process_id} but it doesn't exist."
                )
            else:
                nofilelock.release()

    def _check_rendez_vous(self):
        expected_lock_file_name = os.path.join(self.data_dir, f"{self.experiment_id}-{self.num_process}-0.arrow.lock")
        nofilelock = FileFreeLock(expected_lock_file_name)
        try:
            nofilelock.acquire(timeout=self.timeout)
        except Timeout:
            raise ValueError(
                f"Expected to find locked file {expected_lock_file_name} from process {self.process_id} but it doesn't exist."
            )
        else:
            nofilelock.release()
        lock_file_name = os.path.join(self.data_dir, f"{self.experiment_id}-{self.num_process}-rdv.lock")
        rendez_vous_lock = FileLock(lock_file_name)
        try:
            rendez_vous_lock.acquire(timeout=self.timeout)
        except Timeout:
            raise ValueError(f"Couldn't acquire lock on {lock_file_name} from process {self.process_id}.")
        else:
            rendez_vous_lock.release()

    def _finalize(self):
        """Close all the writing process and load/gather the data
        from all the nodes if main node or all_process is True.
        """
        if self.writer is not None:
            self.writer.finalize()
        self.writer = None
        if self.filelock is not None:
            self.filelock.release()

        if self.keep_in_memory:
            # Read the predictions and references
            reader = ArrowReader(path=self.data_dir, info=DatasetInfo(features=self.features))
            self.data = Dataset.from_buffer(self.buf_writer.getvalue())

        elif self.process_id == 0:
            # Let's acquire a lock on each node files to be sure they are finished writing
            file_paths, filelocks = self._get_all_cache_files()

            # Read the predictions and references
            try:
                reader = ArrowReader(path="", info=DatasetInfo(features=self.features))
                self.data = Dataset(**reader.read_files([{"filename": f} for f in file_paths]))
            except FileNotFoundError:
                raise ValueError(
                    "Error in finalize: another metric instance is already using the local cache file. "
                    "Please specify an experiment_id to avoid collision between distributed metric instances."
                )

            # Store file paths and locks and we will release/delete them after the computation.
            self.file_paths = file_paths
            self.filelocks = filelocks

    def compute(self, *args, **kwargs) -> Optional[dict]:
        """Compute the metrics.

        Args:
            We disallow the usage of positional arguments to prevent mistakes
            `predictions` (Optional list/array/tensor): predictions
            `references` (Optional list/array/tensor): references
            `**kwargs` (Optional other kwargs): will be forwarded to the metrics :func:`_compute` method (see details in the docstring)

        Return:
            Dictionary with the metrics if this metric is run on the main process (process_id == 0)
            None if the metric is not run on the main process (process_id != 0)
        """
        if args:
            raise ValueError("Please call `compute` using keyword arguments.")

        predictions = kwargs.pop("predictions", None)
        references = kwargs.pop("references", None)

        if predictions is not None:
            self.add_batch(predictions=predictions, references=references)
        self._finalize()

        self.cache_file_name = None
        self.filelock = None

        if self.process_id == 0:
            self.data.set_format(type=self.info.format)

            predictions = self.data["predictions"]
            references = self.data["references"]
            with temp_seed(self.seed):
                output = self._compute(predictions=predictions, references=references, **kwargs)

            if self.buf_writer is not None:
                self.buf_writer = None
                del self.data
                self.data = None
            else:
                # Release locks and delete all the cache files
                for filelock, file_path in zip(self.filelocks, self.file_paths):
                    logger.info(f"Removing {file_path}")
                    del self.data
                    self.data = None
                    del self.writer
                    self.writer = None
                    os.remove(file_path)
                    filelock.release()

            return output
        else:
            return None

    def add_batch(self, *, predictions=None, references=None):
        """
        Add a batch of predictions and references for the metric's stack.
        """
        batch = {"predictions": predictions, "references": references}
        batch = self.info.features.encode_batch(batch)
        if self.writer is None:
            self._init_writer()
        try:
            self.writer.write_batch(batch)
        except pa.ArrowInvalid:
            raise ValueError(
                f"Predictions and/or references don't match the expected format.\n"
                f"Expected format: {self.features},\n"
                f"Input predictions: {predictions},\n"
                f"Input references: {references}"
            )

    def add(self, *, prediction=None, reference=None):
        """Add one prediction and reference for the metric's stack."""
        example = {"predictions": prediction, "references": reference}
        example = self.info.features.encode_example(example)
        if self.writer is None:
            self._init_writer()
        try:
            self.writer.write(example)
        except pa.ArrowInvalid:
            raise ValueError(
                f"Prediction and/or reference don't match the expected format.\n"
                f"Expected format: {self.features},\n"
                f"Input predictions: {prediction},\n"
                f"Input references: {reference}"
            )

    def _init_writer(self, timeout=1):
        if self.num_process > 1:
            if self.process_id == 0:
                file_path = os.path.join(self.data_dir, f"{self.experiment_id}-{self.num_process}-rdv.lock")
                self.rendez_vous_lock = FileLock(file_path)
                try:
                    self.rendez_vous_lock.acquire(timeout=timeout)
                except TimeoutError:
                    raise ValueError(
                        f"Error in _init_writer: another metric instance is already using the local cache file at {file_path}. "
                        f"Please specify an experiment_id (currently: {self.experiment_id}) to avoid collision "
                        f"between distributed metric instances."
                    )

        if self.keep_in_memory:
            self.buf_writer = pa.BufferOutputStream()
            self.writer = ArrowWriter(
                features=self.info.features, stream=self.buf_writer, writer_batch_size=self.writer_batch_size
            )
        else:
            self.buf_writer = None

            # Get cache file name and lock it
            if self.cache_file_name is None or self.filelock is None:
                cache_file_name, filelock = self._create_cache_file()  # get ready
                self.cache_file_name = cache_file_name
                self.filelock = filelock

            self.writer = ArrowWriter(
                features=self.info.features, path=self.cache_file_name, writer_batch_size=self.writer_batch_size
            )
        # Setup rendez-vous here if
        if self.num_process > 1:
            if self.process_id == 0:
                self._check_all_processes_locks()  # wait for everyone to be ready
                self.rendez_vous_lock.release()  # let everyone go
            else:
                self._check_rendez_vous()  # wait for master to be ready and to let everyone go

    def _info(self) -> MetricInfo:
        """Construct the MetricInfo object. See `MetricInfo` for details.

        Warning: This function is only called once and the result is cached for all
        following .info() calls.

        Returns:
            info: (MetricInfo) The metrics information
        """
        raise NotImplementedError

    def download_and_prepare(
        self,
        download_config: Optional[DownloadConfig] = None,
        dl_manager: Optional[DownloadManager] = None,
        **download_and_prepare_kwargs,
    ):
        """Downloads and prepares dataset for reading.

        Args:
            download_config (Optional ``datasets.DownloadConfig``): specific download configuration parameters.
            dl_manager (Optional ``datasets.DownloadManager``): specific Download Manager to use.
        """
        if dl_manager is None:
            if download_config is None:
                download_config = DownloadConfig()
                download_config.cache_dir = os.path.join(self.data_dir, "downloads")
                download_config.force_download = False

            dl_manager = DownloadManager(
                dataset_name=self.name, download_config=download_config, data_dir=self.data_dir
            )

        self._download_and_prepare(dl_manager)

    def _download_and_prepare(self, dl_manager):
        """Downloads and prepares resources for the metric.

        This is the internal implementation to overwrite called when user calls
        `download_and_prepare`. It should download all required resources for the metric.

        Args:
            dl_manager: (DownloadManager) `DownloadManager` used to download and cache
                data.
        """
        return None

    def _compute(self, *, predictions=None, references=None, **kwargs) -> Dict[str, Any]:
        """ This method defines the common API for all the metrics in the library """
        raise NotImplementedError

    def __del__(self):
        if self.filelock is not None:
            self.filelock.release()
        if self.rendez_vous_lock is not None:
            self.rendez_vous_lock.release()
        if hasattr(self, "writer"):  # in case it was already deleted
            del self.writer
        if hasattr(self, "data"):  # in case it was already deleted
            del self.data
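
The most reusable piece of this class is the retry logic in _create_cache_file(): lock the default cache path and, when another process already owns it, fall back to a uuid-suffixed name. Stripped down to a sketch (PyPI filelock; locked_cache_file and max_attempts are invented names):

import os
import uuid

from filelock import FileLock, Timeout


def locked_cache_file(data_dir, base_name, max_attempts=100, timeout=1):
    file_path = os.path.join(data_dir, base_name + '.arrow')
    for attempt in range(max_attempts):
        filelock = FileLock(file_path + '.lock')
        try:
            filelock.acquire(timeout=timeout)
        except Timeout:
            if attempt == max_attempts - 1:
                raise
            # The default name is taken; retry with a random one.
            file_path = os.path.join(data_dir, '%s-%s.arrow' % (base_name, uuid.uuid4()))
        else:
            return file_path, filelock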
Example #14
from transformers.utils import check_min_version


# Will raise an error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.6.0.dev0")

logger = logging.getLogger(__name__)

try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    if is_offline_mode():
        raise LookupError(
            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
        )
    with FileLock(".lock") as lock:
        nltk.download("punkt", quiet=True)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
Example #15
def get_from_cache(url,
                   cache_dir=None,
                   force_download=False,
                   proxies=None,
                   etag_timeout=10,
                   resume_download=False,
                   user_agent=None):
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
        cache_dir = str(cache_dir)

    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url, proxies=proxies)
    else:
        try:
            response = requests.head(url,
                                     allow_redirects=True,
                                     proxies=proxies,
                                     timeout=etag_timeout)
            if response.status_code != 200:
                etag = None
            else:
                etag = response.headers.get("ETag")
        except (EnvironmentError, requests.exceptions.Timeout):
            etag = None

    if sys.version_info[0] == 2 and etag is not None:
        etag = etag.decode("utf-8")
    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # If we don't have a connection (etag is None) and can't identify the file
    # try to get the last downloaded one
    if not os.path.exists(cache_path) and etag is None:
        matching_files = [
            file
            for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
            if not file.endswith(".json") and not file.endswith(".lock")
        ]
        if matching_files:
            cache_path = os.path.join(cache_dir, matching_files[-1])

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager():
                with open(incomplete_path, "a+b") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile,
                                        dir=cache_dir,
                                        delete=False)
            resume_size = 0

        if etag is not None and (not os.path.exists(cache_path)
                                 or force_download):
            # Download to temporary file, then copy to cache dir once finished.
            # Otherwise you get corrupt cache entries if the download gets interrupted.
            with temp_file_manager() as temp_file:
                logger.info(
                    "%s not found in cache or force_download set to True, downloading to %s",
                    url, temp_file.name)

                # GET file object
                if url.startswith("s3://"):
                    if resume_download:
                        logger.warn(
                            'Warning: resumable downloads are not implemented for "s3://" urls'
                        )
                    s3_get(url, temp_file, proxies=proxies)
                else:
                    http_get(url,
                             temp_file,
                             proxies=proxies,
                             resume_size=resume_size,
                             user_agent=user_agent)

                # we are copying the file before closing it, so flush to avoid truncation
                temp_file.flush()

                logger.info("storing %s in cache at %s", url, cache_path)
                os.rename(temp_file.name, cache_path)

                logger.info("creating metadata file for %s", cache_path)
                meta = {"url": url, "etag": etag}
                meta_path = cache_path + ".json"
                with open(meta_path, "w") as meta_file:
                    output_string = json.dumps(meta)
                    if sys.version_info[0] == 2 and isinstance(
                            output_string, str):
                        output_string = unicode(output_string,
                                                "utf-8")  # noqa: F821
                    meta_file.write(output_string)

    return cache_path
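
The lock here guards a classic pattern: one process downloads into a temporary file while any others wait on the sidecar .lock, and the finished file is then moved into place in a single step. The essence as a sketch; download stands in for the http_get/s3_get helpers, and the early return inside the lock mirrors what the newer variant in Example #17 below does:

import os
import tempfile

from filelock import FileLock


def fetch_once(url, cache_path, download):
    with FileLock(cache_path + '.lock'):
        if os.path.exists(cache_path):
            return cache_path  # another process finished while we waited
        fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(cache_path))
        with os.fdopen(fd, 'wb') as tmp_file:
            download(url, tmp_file)
        os.replace(tmp_path, cache_path)  # atomic within one filesystem
    return cache_path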
Example #16
    def read(
        self, file_path: Union[Path, str]
    ) -> Union[AllennlpDataset, AllennlpLazyDataset]:
        """
        Returns a dataset containing all the instances that can be read from the file path.

        If `self.lazy` is `False`, this eagerly reads all instances from `self._read()`
        and returns an `AllennlpDataset`.

        If `self.lazy` is `True`, this returns an `AllennlpLazyDataset`, which internally
        relies on the generator created from `self._read()` to lazily produce `Instance`s.
        In this case your implementation of `_read()` must also be lazy
        (that is, not load all instances into memory at once), otherwise
        you will get a `ConfigurationError`.

        In either case, the returned `Iterable` can be iterated
        over multiple times. It's unlikely you want to override this function,
        but if you do your result should likewise be repeatedly iterable.
        """
        if not isinstance(file_path, str):
            file_path = str(file_path)

        lazy = getattr(self, "lazy", None)

        if lazy is None:
            warnings.warn(
                "DatasetReader.lazy is not set, "
                "did you forget to call the superclass constructor?",
                UserWarning,
            )

        if lazy:
            return AllennlpLazyDataset(self._instance_iterator, file_path)
        else:
            cache_file: Optional[str] = None
            if self._cache_directory:
                cache_file = self._get_cache_location_for_file_path(file_path)

            if cache_file is not None and os.path.exists(cache_file):
                try:
                    # Try to acquire a lock just to make sure another process isn't in the middle
                    # of writing to the cache.
                    cache_file_lock = FileLock(
                        cache_file + ".lock",
                        timeout=self.CACHE_FILE_LOCK_TIMEOUT)
                    cache_file_lock.acquire()
                    # We make an assumption here that if we can obtain the lock, no one will
                    # be trying to write to the file anymore, so it should be safe to release the lock
                    # before reading so that other processes can also read from it.
                    cache_file_lock.release()
                    logger.info("Reading instances from cache %s", cache_file)
                    instances = self._instances_from_cache_file(cache_file)
                except Timeout:
                    logger.warning(
                        "Failed to acquire lock on dataset cache file within %d seconds. "
                        "Cannot use cache to read instances.",
                        self.CACHE_FILE_LOCK_TIMEOUT,
                    )
                    instances = self._multi_worker_islice(
                        self._read(file_path))
            else:
                instances = self._multi_worker_islice(self._read(file_path))

            # Then some validation.
            if not isinstance(instances, list):
                instances = list(instances)

            if not instances:
                raise ConfigurationError(
                    "No instances were read from the given filepath {}. "
                    "Is the path correct?".format(file_path))

            # And finally we try writing to the cache.
            if cache_file is not None and not os.path.exists(cache_file):
                if self.max_instances is not None:
                    # But we don't write to the cache when max_instances is specified.
                    logger.warning(
                        "Skipping writing to data cache since max_instances was specified."
                    )
                elif util.is_distributed() or (get_worker_info() and
                                               get_worker_info().num_workers):
                    # We also shouldn't write to the cache if there's more than one process loading
                    # instances since each worker only receives a partial share of the instances.
                    logger.warning(
                        "Can't cache data instances when there are multiple processes loading data"
                    )
                else:
                    try:
                        with FileLock(cache_file + ".lock",
                                      timeout=self.CACHE_FILE_LOCK_TIMEOUT):
                            self._instances_to_cache_file(
                                cache_file, instances)
                    except Timeout:
                        logger.warning(
                            "Failed to acquire lock on dataset cache file within %d seconds. "
                            "Cannot write to cache.",
                            self.CACHE_FILE_LOCK_TIMEOUT,
                        )

            return AllennlpDataset(instances)
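
Acquiring the cache lock only to release it straight away looks odd, but it is simply a way to wait until any writer currently filling the cache has finished, after which reading without the lock is safe. The check in isolation, as a sketch (read_cache and fallback are placeholders):

from filelock import FileLock, Timeout


def read_cache_safely(cache_file, read_cache, fallback, lock_timeout=10):
    lock = FileLock(cache_file + '.lock', timeout=lock_timeout)
    try:
        lock.acquire()
        lock.release()  # nobody is writing any more, so reading is safe
        return read_cache(cache_file)
    except Timeout:
        return fallback()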
Example #17
def get_from_cache(
    url,
    cache_dir=None,
    force_download=False,
    proxies=None,
    etag_timeout=10,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    """
    Given a URL, look for the corresponding file in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    Return:
        None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
        Local path (string) otherwise
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    etag = None
    if not local_files_only:
        try:
            response = requests.head(url,
                                     allow_redirects=True,
                                     proxies=proxies,
                                     timeout=etag_timeout)
            if response.status_code == 200:
                etag = response.headers.get("ETag")
        except (EnvironmentError, requests.exceptions.Timeout):
            # etag is already None
            pass

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
    # try to get the last downloaded one
    if etag is None:
        if os.path.exists(cache_path):
            return cache_path
        else:
            matching_files = [
                file
                for file in fnmatch.filter(os.listdir(cache_dir), filename +
                                           ".*")
                if not file.endswith(".json") and not file.endswith(".lock")
            ]
            if len(matching_files) > 0:
                return os.path.join(cache_dir, matching_files[-1])
            else:
                # If files cannot be found and local_files_only=True,
                # the models might've been found if local_files_only=False
                # Notify the user about that
                if local_files_only:
                    raise ValueError(
                        "Cannot find the requested files in the cached path and outgoing traffic has been"
                        " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
                        " to False.")
                return None

    # From now on, etag is not None.
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        # If the download just completed while the lock was activated.
        if os.path.exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager():
                with open(incomplete_path, "a+b") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile,
                                        dir=cache_dir,
                                        delete=False)
            resume_size = 0

        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            logger.info(
                "%s not found in cache or force_download set to True, downloading to %s",
                url, temp_file.name)

            http_get(url,
                     temp_file,
                     proxies=proxies,
                     resume_size=resume_size,
                     user_agent=user_agent)

        logger.info("storing %s in cache at %s", url, cache_path)
        os.replace(temp_file.name, cache_path)

        logger.info("creating metadata file for %s", cache_path)
        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
Example #18
def cached_path(
    url_or_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    extract_compressed_file=False,
    force_extract=False,
    local_files_only=False,
) -> Optional[str]:
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.
    Args:
        cache_dir: specify a cache directory to save the file to (override the default cache dir).
        force_download: if True, re-download the file even if it's already cached in the cache dir.
        resume_download: if True, resume the download if an incompletely received file is found.
        user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
        extract_compressed_file: if True and the path points to a zip or tar file, extract the compressed
            file in a folder alongside the archive.
        force_extract: if True when extract_compressed_file is True and the archive was already extracted,
            re-extract the archive and override the folder where it was extracted.

    Return:
        None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
        Local path (string) otherwise
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if is_remote_url(url_or_filename):
        # URL, so get it from the cache (downloading if necessary)
        output_path = get_from_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            user_agent=user_agent,
            local_files_only=local_files_only,
        )
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        output_path = url_or_filename
    elif urlparse(url_or_filename).scheme == "":
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError(
            "unable to parse {} as a URL or as a local path".format(
                url_or_filename))

    if extract_compressed_file:
        if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
            return output_path

        # Path where we extract compressed archives
        # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
        output_dir, output_file = os.path.split(output_path)
        output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
        output_path_extracted = os.path.join(output_dir,
                                             output_extract_dir_name)

        if os.path.isdir(output_path_extracted) and os.listdir(
                output_path_extracted) and not force_extract:
            return output_path_extracted

        # Prevent parallel extractions
        lock_path = output_path + ".lock"
        with FileLock(lock_path):
            shutil.rmtree(output_path_extracted, ignore_errors=True)
            os.makedirs(output_path_extracted)
            if is_zipfile(output_path):
                with ZipFile(output_path, "r") as zip_file:
                    zip_file.extractall(output_path_extracted)
            elif tarfile.is_tarfile(output_path):
                with tarfile.open(output_path) as tar_file:
                    tar_file.extractall(output_path_extracted)
            else:
                raise EnvironmentError(
                    "Archive format of {} could not be identified".format(
                        output_path))

        return output_path_extracted

    return output_path
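A hedged usage sketch for `cached_path` as defined above; the URL is a placeholder and the surrounding module constants (`TRANSFORMERS_CACHE`, `get_from_cache`, etc.) are assumed to be importable alongside the function:

# Hypothetical example: download (or reuse) a remote archive, then also unpack it.
archive_url = "https://example.com/models/model.tar.gz"    # placeholder URL
local_archive = cached_path(archive_url)                   # cached download, returns a file path
extracted_dir = cached_path(archive_url,
                            extract_compressed_file=True)  # returns a "...-extracted" directory
print(local_archive, extracted_dir)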
Ejemplo n.º 19
0
def get_from_cache(url,
                   cache_dir=None,
                   force_download=False,
                   proxies=None,
                   etag_timeout=10,
                   resume_download=False,
                   user_agent=None) -> Optional[str]:
    """
    Given a URL, look for the corresponding file in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    Return:
        None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
        Local path (string) otherwise
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url, proxies=proxies)
    else:
        try:
            response = requests.head(url,
                                     allow_redirects=True,
                                     proxies=proxies,
                                     timeout=etag_timeout)
            if response.status_code != 200:
                etag = None
            else:
                etag = response.headers.get("ETag")
        except (EnvironmentError, requests.exceptions.Timeout):
            etag = None

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # etag is None == we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
    # try to get the last downloaded one
    if etag is None:
        if os.path.exists(cache_path):
            return cache_path
        else:
            matching_files = [
                file
                for file in fnmatch.filter(os.listdir(cache_dir), filename +
                                           ".*")
                if not file.endswith(".json") and not file.endswith(".lock")
            ]
            if len(matching_files) > 0:
                return os.path.join(cache_dir, matching_files[-1])
            else:
                return None

    # From now on, etag is not None.
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager():
                with open(incomplete_path, "a+b") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile,
                                        dir=cache_dir,
                                        delete=False)
            resume_size = 0

        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            logger.info(
                "%s not found in cache or force_download set to True, downloading to %s",
                url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                if resume_download:
                    logger.warning(
                        'Resumable downloads are not implemented for "s3://" urls'
                    )
                s3_get(url, temp_file, proxies=proxies)
            else:
                http_get(url,
                         temp_file,
                         proxies=proxies,
                         resume_size=resume_size,
                         user_agent=user_agent)

        logger.info("storing %s in cache at %s", url, cache_path)
        os.rename(temp_file.name, cache_path)

        logger.info("creating metadata file for %s", cache_path)
        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
Ejemplo n.º 20
0
def get_from_cache(
    url: str,
    cache_dir=None,
    force_download=False,
    proxies=None,
    etag_timeout=10,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    """
    Given a URL, look for the corresponding file in the local cache. If it's not there, download it. Then return the
    path to the cached file.

    Return:
        Local path (string) of file or if networking is off, last version of file cached on disk.

    Raises:
        In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    url_to_download = url
    etag = None
    if not local_files_only:
        try:
            headers = {"user-agent": http_user_agent(user_agent)}
            r = requests.head(url,
                              headers=headers,
                              allow_redirects=False,
                              proxies=proxies,
                              timeout=etag_timeout)
            r.raise_for_status()
            etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
            # We favor a custom header indicating the etag of the linked resource, and
            # we fallback to the regular etag header.
            # If we don't have any of those, raise an error.
            if etag is None:
                raise OSError(
                    "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
                )
            # In case of a redirect,
            # save an extra redirect on the request.get call,
            # and ensure we download the exact atomic version even if it changed
            # between the HEAD and the GET (unlikely, but hey).
            if 300 <= r.status_code <= 399:
                url_to_download = r.headers["Location"]
        except (requests.exceptions.ConnectionError,
                requests.exceptions.Timeout):
            # etag is already None
            pass

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # etag is None == we don't have a connection or we passed local_files_only.
    # try to get the last downloaded one
    if etag is None:
        if os.path.exists(cache_path):
            return cache_path
        else:
            matching_files = [
                file for file in fnmatch.filter(os.listdir(cache_dir),
                                                filename.split(".")[0] + ".*")
                if not file.endswith(".json") and not file.endswith(".lock")
            ]
            if len(matching_files) > 0:
                return os.path.join(cache_dir, matching_files[-1])
            else:
                # If files cannot be found and local_files_only=True,
                # the models might've been found if local_files_only=False
                # Notify the user about that
                if local_files_only:
                    raise ValueError(
                        "Cannot find the requested files in the cached path and outgoing traffic has been"
                        " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
                        " to False.")
                else:
                    raise ValueError(
                        "Connection error, and we cannot find the requested files in the cached path."
                        " Please try again or make sure your Internet connection is on."
                    )

    # From now on, etag is not None.
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        # If the download just completed while the lock was activated.
        if os.path.exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager() -> "io.BufferedWriter":
                with open(incomplete_path, "ab") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile,
                                        mode="wb",
                                        dir=cache_dir,
                                        delete=False)
            resume_size = 0

        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            logger.info(
                "%s not found in cache or force_download set to True, downloading to %s",
                url, temp_file.name)

            http_get(url_to_download,
                     temp_file,
                     proxies=proxies,
                     resume_size=resume_size,
                     user_agent=user_agent)

        logger.info("storing %s in cache at %s", url, cache_path)
        os.replace(temp_file.name, cache_path)

        logger.info("creating metadata file for %s", cache_path)
        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
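The cache key used above comes from `url_to_filename(url, etag)`, which is not shown in this example. A plausible stand-in (an assumption, not necessarily the exact library implementation) hashes both values so that a changed ETag produces a new cache entry:

import hashlib


def url_to_filename_sketch(url, etag=None):
    """Sketch: derive a cache filename from the URL, suffixed by the ETag hash when available."""
    filename = hashlib.sha256(url.encode("utf-8")).hexdigest()
    if etag:
        filename += "." + hashlib.sha256(etag.encode("utf-8")).hexdigest()
    return filename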
Ejemplo n.º 21
0
 def initialize_control_file(self) -> None:
     with FileLock(LTIGradesSenderControlFile.lock_file):
         with Path(self.config_fullname).open('w+') as new_file:
             json.dump(LTIGradesSenderControlFile.cache_sender_data, new_file)
             logger.debug('Control file initialized.')
Ejemplo n.º 22
0
    def __init__(
        self,
        name: str = None,
        experiment_id: Optional[str] = None,
        process_id: int = 0,
        num_process: int = 1,
        data_dir: Optional[str] = None,
        in_memory: bool = False,
        hash: str = None,
        seed: Optional[int] = None,
        **kwargs,
    ):
        """ A Metrics is the base class and common API for all metrics.
            Args:
                process_id (``int``): specify the id of the node in a distributed settings between 0 and num_nodes-1
                    This can be used, to compute metrics on distributed setups
                    (in particular non-additive metrics like F1).
                data_dir (``str``): path to a directory in which temporary data will be stored.
                    This should be a shared file-system for distributed setups.
                hash (``str``): can be used to define a hash specific to the metrics computation script
                    This prevents the metric's data to be overridden when the metric loading script is modified.
                experiment_id (Optional ``str``): Should be used if you perform several concurrent experiments using
                    the same caching directory (will be indicated in the raise error)
                in_memory (``bool``): keep all predictions and references in memory. Not possible in distributed settings.
                seed (Optional ``int``): If specified, this will temporarily set numpy's random seed when :func:`nlp.Metric.compute` is run.
        """
        # Safety checks
        assert isinstance(
            process_id, int
        ) and process_id >= 0, "'process_id' should be a non-negative integer"
        assert (isinstance(num_process, int) and num_process > process_id
                ), "'num_process' should be a number greater than process_id"
        assert (
            process_id == 0 or not in_memory
        ), "Using 'in_memory' is not possible in distributed setting (process_id > 0)."

        # Metric name
        self.name = camelcase_to_snakecase(self.__class__.__name__)
        # Configuration name
        self.config_name: str = name or "default"

        self.process_id = process_id
        self.num_process = num_process
        self.in_memory = in_memory
        self.experiment_id = experiment_id if experiment_id is not None else "cache"
        self.hash = hash
        self._version = "1.0.0"
        self._data_dir_root = os.path.expanduser(data_dir or HF_METRICS_CACHE)
        self.data_dir = self._build_data_dir()
        self.seed: int = seed or np.random.get_state()[1][0]

        # prepare info
        info = self._info()
        info.metric_name = self.name
        info.config_name = self.config_name
        info.version = self._version
        self.info = info

        # Update 'compute' and 'add' docstring
        # methods need to be copied otherwise it changes the docstrings of every instance
        self.compute = types.MethodType(copyfunc(self.compute), self)
        self.add_batch = types.MethodType(copyfunc(self.add_batch), self)
        self.add = types.MethodType(copyfunc(self.add), self)
        self.compute.__func__.__doc__ += self.info.inputs_description
        self.add_batch.__func__.__doc__ += self.info.inputs_description
        self.add.__func__.__doc__ += self.info.inputs_description

        self.arrow_schema = pa.schema(field
                                      for field in self.info.features.type)
        self.buf_writer = None
        self.writer = None
        self.writer_batch_size = None
        self.data = None

        # Check we can write on the cache file without competitors
        self.cache_file_name = self._get_cache_path(self.process_id)
        self.filelock = FileLock(self.cache_file_name + ".lock")
        try:
            self.filelock.acquire(timeout=1)
        except Timeout:
            raise ValueError(
                "Cannot acquire lock, caching file might be used by another process, "
                "you should setup a unique 'experiment_id' for this run.")
Ejemplo n.º 23
0
class Metric(object):
    def __init__(
        self,
        name: str = None,
        experiment_id: Optional[str] = None,
        process_id: int = 0,
        num_process: int = 1,
        data_dir: Optional[str] = None,
        in_memory: bool = False,
        hash: str = None,
        seed: Optional[int] = None,
        **kwargs,
    ):
        """ A Metrics is the base class and common API for all metrics.
            Args:
                process_id (``int``): specify the id of the node in a distributed settings between 0 and num_nodes-1
                    This can be used, to compute metrics on distributed setups
                    (in particular non-additive metrics like F1).
                data_dir (``str``): path to a directory in which temporary data will be stored.
                    This should be a shared file-system for distributed setups.
                hash (``str``): can be used to define a hash specific to the metrics computation script
                    This prevents the metric's data to be overridden when the metric loading script is modified.
                experiment_id (Optional ``str``): Should be used if you perform several concurrent experiments using
                    the same caching directory (will be indicated in the raise error)
                in_memory (``bool``): keep all predictions and references in memory. Not possible in distributed settings.
                seed (Optional ``int``): If specified, this will temporarily set numpy's random seed when :func:`nlp.Metric.compute` is run.
        """
        # Safety checks
        assert isinstance(
            process_id, int
        ) and process_id >= 0, "'process_id' should be a non-negative integer"
        assert (isinstance(num_process, int) and num_process > process_id
                ), "'num_process' should be a number greater than process_id"
        assert (
            process_id == 0 or not in_memory
        ), "Using 'in_memory' is not possible in distributed setting (process_id > 0)."

        # Metric name
        self.name = camelcase_to_snakecase(self.__class__.__name__)
        # Configuration name
        self.config_name: str = name or "default"

        self.process_id = process_id
        self.num_process = num_process
        self.in_memory = in_memory
        self.experiment_id = experiment_id if experiment_id is not None else "cache"
        self.hash = hash
        self._version = "1.0.0"
        self._data_dir_root = os.path.expanduser(data_dir or HF_METRICS_CACHE)
        self.data_dir = self._build_data_dir()
        self.seed: int = seed or np.random.get_state()[1][0]

        # prepare info
        info = self._info()
        info.metric_name = self.name
        info.config_name = self.config_name
        info.version = self._version
        self.info = info

        # Update 'compute' and 'add' docstring
        # methods need to be copied otherwise it changes the docstrings of every instance
        self.compute = types.MethodType(copyfunc(self.compute), self)
        self.add_batch = types.MethodType(copyfunc(self.add_batch), self)
        self.add = types.MethodType(copyfunc(self.add), self)
        self.compute.__func__.__doc__ += self.info.inputs_description
        self.add_batch.__func__.__doc__ += self.info.inputs_description
        self.add.__func__.__doc__ += self.info.inputs_description

        self.arrow_schema = pa.schema(field
                                      for field in self.info.features.type)
        self.buf_writer = None
        self.writer = None
        self.writer_batch_size = None
        self.data = None

        # Check we can write on the cache file without competitors
        self.cache_file_name = self._get_cache_path(self.process_id)
        self.filelock = FileLock(self.cache_file_name + ".lock")
        try:
            self.filelock.acquire(timeout=1)
        except Timeout:
            raise ValueError(
                "Cannot acquire lock, caching file might be used by another process, "
                "you should setup a unique 'experiment_id' for this run.")

    def _relative_data_dir(self, with_version=True):
        """ Relative path of this metric in cache_dir:
            Will be:
                self.name/self.config_name/self.config.version/self.hash/
            If any of these elements is missing, or if ``with_version=False``, the corresponding subfolders are dropped.
        """
        builder_data_dir = os.path.join(self.name, self.config_name)
        if with_version:
            builder_data_dir = os.path.join(builder_data_dir,
                                            str(self._version))
        if self.hash:
            builder_data_dir = os.path.join(builder_data_dir, self.hash)
        return builder_data_dir

    def _build_data_dir(self):
        """ Return the directory for the current version.
        """
        builder_data_dir = os.path.join(
            self._data_dir_root, self._relative_data_dir(with_version=False))
        version_data_dir = os.path.join(
            self._data_dir_root, self._relative_data_dir(with_version=True))

        def _other_versions_on_disk():
            """Returns previous versions on disk."""
            if not os.path.exists(builder_data_dir):
                return []

            version_dirnames = []
            for dir_name in os.listdir(builder_data_dir):
                try:
                    version_dirnames.append((Version(dir_name), dir_name))
                except ValueError:  # Invalid version (ex: incomplete data dir)
                    pass
            version_dirnames.sort(reverse=True)
            return version_dirnames

        # Check and warn if other versions exist on disk
        version_dirs = _other_versions_on_disk()
        if version_dirs:
            other_version = version_dirs[0][0]
            if other_version != self._version:
                warn_msg = (
                    "Found a different version {other_version} of metric {name} in "
                    "data_dir {data_dir}. Using currently defined version "
                    "{cur_version}.".format(
                        other_version=str(other_version),
                        name=self.name,
                        data_dir=self._data_dir_root,
                        cur_version=str(self._version),
                    ))
                logger.warning(warn_msg)

        os.makedirs(version_data_dir, exist_ok=True)
        return version_data_dir

    def _get_cache_path(self, node_id):
        return os.path.join(
            self.data_dir, f"{self.experiment_id}-{self.name}-{node_id}.arrow")

    def finalize(self, timeout=120):
        """ Close all the writing process and load/gather the data
            from all the nodes if main node or all_process is True.
        """
        self.writer.finalize()
        self.writer = None
        self.buf_writer = None
        self.filelock.release()

        if self.process_id == 0:
            # Let's acquire a lock on each node files to be sure they are finished writing
            node_files = []
            locks = []
            for node_id in range(self.num_process):
                node_file = self._get_cache_path(node_id)
                filelock = FileLock(node_file + ".lock")
                filelock.acquire(timeout=timeout)
                node_files.append({"filename": node_file})
                locks.append(filelock)

            # Read the predictions and references
            reader = ArrowReader(path=self.data_dir, info=None)
            self.data = Dataset(**reader.read_files(node_files))

            # Release all of our locks
            for lock in locks:
                lock.release()

    def compute(self,
                predictions=None,
                references=None,
                timeout=120,
                **metrics_kwargs):
        """ Compute the metrics.
        """
        if predictions is not None:
            self.add_batch(predictions=predictions, references=references)
        self.finalize(timeout=timeout)

        self.data.set_format(type=self.info.format)

        predictions = self.data["predictions"]
        references = self.data["references"]
        with temp_seed(self.seed):
            output = self._compute(predictions=predictions,
                                   references=references,
                                   **metrics_kwargs)
        return output

    def add_batch(self, predictions=None, references=None, **kwargs):
        """ Add a batch of predictions and references for the metric's stack.
        """
        batch = {"predictions": predictions, "references": references}
        if self.writer is None:
            self._init_writer()
        self.writer.write_batch(batch)

    def add(self, prediction=None, reference=None, **kwargs):
        """ Add one prediction and reference for the metric's stack.
        """
        example = {"predictions": prediction, "references": reference}
        example = self.info.features.encode_example(example)
        if self.writer is None:
            self._init_writer()
        self.writer.write(example)

    def _init_writer(self):
        if self.in_memory:
            self.buf_writer = pa.BufferOutputStream()
            self.writer = ArrowWriter(schema=self.arrow_schema,
                                      stream=self.buf_writer,
                                      writer_batch_size=self.writer_batch_size)
        else:
            self.buf_writer = None
            self.writer = ArrowWriter(schema=self.arrow_schema,
                                      path=self.cache_file_name,
                                      writer_batch_size=self.writer_batch_size)

    def _info(self) -> MetricInfo:
        """Construct the MetricInfo object. See `MetricInfo` for details.

        Warning: This function is only called once and the result is cached for all
        following .info() calls.

        Returns:
            info: (MetricInfo) The metrics information
        """
        raise NotImplementedError

    def download_and_prepare(
        self,
        download_config: Optional[DownloadConfig] = None,
        dl_manager: Optional[DownloadManager] = None,
        **download_and_prepare_kwargs,
    ):
        """Downloads and prepares dataset for reading.

        Args:
            download_config (Optional ``nlp.DownloadConfig``): specific download configuration parameters.
            dl_manager (Optional ``nlp.DownloadManager``): specific DownloadManager to use.
        """
        if dl_manager is None:
            if download_config is None:
                download_config = DownloadConfig()
                download_config.cache_dir = os.path.join(
                    self.data_dir, "downloads")
                download_config.force_download = False

            dl_manager = DownloadManager(dataset_name=self.name,
                                         download_config=download_config,
                                         data_dir=self.data_dir)

        self._download_and_prepare(dl_manager)

    def _download_and_prepare(self, dl_manager):
        """Downloads and prepares resources for the metric.

        This is the internal implementation, meant to be overridden, which is called when the user
        calls `download_and_prepare`. It should download all required resources for the metric.

        Args:
            dl_manager: (DownloadManager) `DownloadManager` used to download and cache
                data.
        """
        return None

    def _compute(self,
                 predictions=None,
                 references=None,
                 **kwargs) -> Dict[str, Any]:
        """ This method defines the common API for all the metrics in the library """
        raise NotImplementedError
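The `finalize` method above waits for every node by acquiring each per-node lock before reading the Arrow files. A condensed sketch of that choreography, using only the path scheme visible in `_get_cache_path` (the function name is illustrative):

import os
from filelock import FileLock


def wait_for_node_files(data_dir, experiment_id, name, num_process, timeout=120):
    """Sketch: block until every node's writer has released its lock, then return the file paths."""
    node_files, locks = [], []
    try:
        for node_id in range(num_process):
            node_file = os.path.join(data_dir, f"{experiment_id}-{name}-{node_id}.arrow")
            lock = FileLock(node_file + ".lock")
            lock.acquire(timeout=timeout)    # raises filelock.Timeout if a writer is stuck
            locks.append(lock)
            node_files.append(node_file)
        return node_files
    finally:
        for lock in locks:                   # release whatever was acquired, even on Timeout
            lock.release()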
Ejemplo n.º 24
0
class CustomLogGui(wx.LogGui):
    """Logger."""

    def __init__(self, log_file, debug=False):
        """Initialize."""

        self.format = "%(message)s"
        self.file_name = log_file
        self.file_lock = FileLock(self.file_name + '.lock')
        self.debug = debug

        try:
            with self.file_lock.acquire(1):
                with codecs.open(self.file_name, "w", "utf-8") as f:
                    f.write("")
        except Exception:
            self.file_name = None

        wx.LogGui.__init__(self)

    def DoLogText(self, msg):
        """Log the text."""

        try:
            if self.file_name is not None:
                with self.file_lock.acquire(1):
                    with codecs.open(self.file_name, 'a', encoding='utf-8') as f:
                        f.write(msg)
            else:
                msg = "[ERROR] Could not acquire lock for log!\n" + msg
        except Exception:
            self.file_name = None

        if self.debug:
            sys.stdout.write(
                (self.format % {"message": msg})
            )

        wx.LogGui.DoLogText(self, msg)

    def DoLogTextAtLevel(self, level, msg):
        """Perform log at level."""

        current = self.GetLogLevel()

        if level <= current and level == wx.LOG_Info:
            self._debug(msg)
        elif level <= current and level == wx.LOG_FatalError:
            self._critical(msg)
        elif level <= current and level == wx.LOG_Warning:
            self._warning(msg)
        elif level <= current and level == wx.LOG_Error:
            self._error(msg)

    def formatter(self, lvl, log_fmt, msg, msg_fmt=None):
        """Special formatters for log message."""

        return log_fmt % {
            "loglevel": lvl,
            "message": util.to_ustr(msg if msg_fmt is None else msg_fmt(msg))
        }

    def _log(self, msg):
        """Base logger."""

        return self.format % {"message": msg}

    def _debug(self, msg, log_fmt="%(loglevel)s: %(message)s\n"):
        """Debug level log."""

        self.DoLogText(self._log(self.formatter("DEBUG", log_fmt, msg)))

    def _critical(self, msg, log_fmt="%(loglevel)s: %(message)s\n"):
        """Critical level log."""

        self.DoLogText(self._log(self.formatter("CRITICAL", log_fmt, msg)))

    def _warning(self, msg, log_fmt="%(loglevel)s: %(message)s\n"):
        """Warning level log."""

        self.DoLogText(self._log(self.formatter("WARNING", log_fmt, msg)))

    def _error(self, msg, log_fmt="%(loglevel)s: %(message)s\n"):
        """Error level log."""

        self.DoLogText(self._log(self.formatter("ERROR", log_fmt, msg)))
Ejemplo n.º 25
0
        def __init__(
            self,
            data_dir: Optional[str],
            tokenizer: PreTrainedTokenizer,
            labels: Optional[List[str]],
            model_type: str,
            max_seq_length: Optional[int] = None,
            overwrite_cache=False,
            mode: Split = Split.train,
            # or init from features instead of data_dir; in this case we'll use data_dir as cache_dir or set it to a tmp dir
            features: Optional[List[InputFeatures]] = None):
            if not data_dir:
                assert features is not None, "Must provide either `data_dir` with dataset files, or `features` to init dataset from."
                data_dir = get_tmp_dir()
                os.makedirs(data_dir, exist_ok=True)
            # Load data features from cache or dataset file
            cached_features_file = os.path.join(
                data_dir,
                "cached_{}_{}_{}".format(mode.value,
                                         tokenizer.__class__.__name__,
                                         str(max_seq_length)),
            )

            if features is None:  # init from data_dir
                # Make sure only the first process in distributed training processes the dataset,
                # and the others will use the cache.
                lock_path = cached_features_file + ".lock"
                with FileLock(lock_path):

                    if os.path.exists(
                            cached_features_file) and not overwrite_cache:
                        logger.info(
                            f"Loading features from cached file {cached_features_file}"
                        )
                        self.features = torch.load(cached_features_file)
                    else:
                        logger.info(
                            f"Creating features from dataset file at {data_dir}"
                        )
                        examples = read_examples_from_file(data_dir, mode)
                        # TODO clean up all this to leverage built-in features of tokenizers
                        self.features = convert_examples_to_features(
                            examples,
                            labels,
                            max_seq_length,
                            tokenizer,
                            cls_token_at_end=bool(model_type in ["xlnet"]),
                            # xlnet has a cls token at the end
                            cls_token=tokenizer.cls_token,
                            cls_token_segment_id=2
                            if model_type in ["xlnet"] else 0,
                            sep_token=tokenizer.sep_token,
                            sep_token_extra=bool(model_type in ["roberta"]),
                            # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
                            pad_on_left=bool(tokenizer.padding_side == "left"),
                            pad_token=tokenizer.pad_token_id,
                            pad_token_segment_id=tokenizer.pad_token_type_id,
                            pad_token_label_id=NomIdDataset.pad_token_label_id,
                        )
                        logger.info(
                            f"Saving features into cached file {cached_features_file}"
                        )
                        torch.save(self.features, cached_features_file)

            else:  # init from features
                self.features = features
                torch.save(self.features, cached_features_file)
Ejemplo n.º 26
0
    def _instance_iterator(self, file_path: str) -> Iterable[Instance]:
        cache_file: Optional[str] = None
        if self._cache_directory:
            cache_file = self._get_cache_location_for_file_path(file_path)

        if cache_file is not None and os.path.exists(cache_file):
            cache_file_lock = FileLock(cache_file + ".lock",
                                       timeout=self.CACHE_FILE_LOCK_TIMEOUT)
            try:
                cache_file_lock.acquire()
                # We make an assumption here that if we can obtain the lock, no one will
                # be trying to write to the file anymore, so it should be safe to release the lock
                # before reading so that other processes can also read from it.
                cache_file_lock.release()
                logger.info("Reading instances from cache %s", cache_file)
                with open(cache_file) as data_file:
                    yield from self._multi_worker_islice(
                        data_file, transform=self.deserialize_instance)
            except Timeout:
                logger.warning(
                    "Failed to acquire lock on dataset cache file within %d seconds. "
                    "Cannot use cache to read instances.",
                    self.CACHE_FILE_LOCK_TIMEOUT,
                )
                yield from self._multi_worker_islice(self._read(file_path),
                                                     ensure_lazy=True)
        elif cache_file is not None and not os.path.exists(cache_file):
            instances = self._multi_worker_islice(self._read(file_path),
                                                  ensure_lazy=True)
            # The cache file doesn't exist so we'll try writing to it.
            if self.max_instances is not None:
                # But we don't write to the cache when max_instances is specified.
                logger.warning(
                    "Skipping writing to data cache since max_instances was specified."
                )
                yield from instances
            elif util.is_distributed() or (get_worker_info()
                                           and get_worker_info().num_workers):
                # We also shouldn't write to the cache if there's more than one process loading
                # instances since each worker only receives a partial share of the instances.
                logger.warning(
                    "Can't cache data instances when there are multiple processes loading data"
                )
                yield from instances
            else:
                try:
                    with FileLock(cache_file + ".lock",
                                  timeout=self.CACHE_FILE_LOCK_TIMEOUT):
                        with CacheFile(cache_file, mode="w+") as cache_handle:
                            logger.info("Caching instances to temp file %s",
                                        cache_handle.name)
                            for instance in instances:
                                cache_handle.write(
                                    self.serialize_instance(instance) + "\n")
                                yield instance
                except Timeout:
                    logger.warning(
                        "Failed to acquire lock on dataset cache file within %d seconds. "
                        "Cannot write to cache.",
                        self.CACHE_FILE_LOCK_TIMEOUT,
                    )
                    yield from instances
        else:
            # No cache.
            yield from self._multi_worker_islice(self._read(file_path),
                                                 ensure_lazy=True)
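The reader above briefly acquires and immediately releases the cache lock: if the lock can be taken, no writer is still producing the file, so it is safe to read without holding the lock. A small sketch of that probe (hypothetical helper name):

from filelock import FileLock, Timeout


def cache_is_readable(cache_file, timeout=10):
    """Sketch: return True if no writer currently holds the cache file's lock."""
    lock = FileLock(cache_file + ".lock", timeout=timeout)
    try:
        lock.acquire()
    except Timeout:
        return False     # a writer (or a stuck process) still holds the lock
    else:
        lock.release()   # release immediately; we only read the file afterwards
        return True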
Ejemplo n.º 27
0
 def add(tool, bugid, ts=lmdutils.get_timestamp("now"), extra=""):
     check(BugChange.__tablename__)
     with FileLock(lock_path):
         session.add(BugChange(tool, ts, bugid, extra))
         session.commit()
Ejemplo n.º 28
0
        version = None

    if len(id) != 16:
        raise IOError('Invalid title id format')

    if Titles.contains(id):
        title = Titles.get(id)
        cdn.downloadTitle(title.id.lower(), version, key or title.key)
    else:
        cdn.downloadTitle(id.lower(), version, key)
    return True


if __name__ == '__main__':
    try:
        with FileLock("nut2.lock") as lock:
            urllib3.disable_warnings()

            #signal.signal(signal.SIGINT, handler)

            parser = argparse.ArgumentParser()
            parser.add_argument('file', nargs='*')
            parser.add_argument('-g',
                                '--ganymede',
                                help='ganymede config file')
            parser.add_argument('-i',
                                '--info',
                                help='show info about title or file')
            parser.add_argument('--depth',
                                type=int,
                                default=1,
Ejemplo n.º 29
0
    def __init__(self, lock_path, save_path, provider_config):
        self.lock = RLock()
        self.file_lock = FileLock(lock_path)
        self.save_path = save_path

        with self.lock:
            with self.file_lock:
                if os.path.exists(self.save_path):
                    workers = json.loads(open(self.save_path).read())
                    head_config = workers.get(provider_config["head_ip"])
                    if (not head_config or
                            head_config.get("tags", {}).get(TAG_RAY_NODE_KIND)
                            != NODE_KIND_HEAD):
                        workers = {}
                        logger.info("Head IP changed - recreating cluster.")
                else:
                    workers = {}
                logger.info("ClusterState: "
                            "Loaded cluster state: {}".format(list(workers)))
                for worker_ip in provider_config["worker_ips"]:
                    if worker_ip not in workers:
                        workers[worker_ip] = {
                            "tags": {
                                TAG_RAY_NODE_KIND: NODE_KIND_WORKER
                            },
                            "state": "terminated",
                        }
                    else:
                        assert (workers[worker_ip]["tags"][TAG_RAY_NODE_KIND]
                                == NODE_KIND_WORKER)
                if provider_config["head_ip"] not in workers:
                    workers[provider_config["head_ip"]] = {
                        "tags": {
                            TAG_RAY_NODE_KIND: NODE_KIND_HEAD
                        },
                        "state": "terminated",
                    }
                else:
                    assert (workers[provider_config["head_ip"]]["tags"][
                        TAG_RAY_NODE_KIND] == NODE_KIND_HEAD)
                # Relevant when a user reduces the number of workers
                # without changing the headnode.
                list_of_node_ips = list(provider_config["worker_ips"])
                list_of_node_ips.append(provider_config["head_ip"])
                for worker_ip in list(workers):
                    if worker_ip not in list_of_node_ips:
                        del workers[worker_ip]

                # Set external head ip, if provided by user.
                # Necessary if calling `ray up` from outside the network.
                # Refer to LocalNodeProvider.external_ip function.
                external_head_ip = provider_config.get("external_head_ip")
                if external_head_ip:
                    head = workers[provider_config["head_ip"]]
                    head["external_ip"] = external_head_ip

                assert len(workers) == len(provider_config["worker_ips"]) + 1
                with open(self.save_path, "w") as f:
                    logger.debug("ClusterState: "
                                 "Writing cluster state: {}".format(workers))
                    f.write(json.dumps(workers))
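The constructor above layers a threading `RLock` (for threads inside the same process) on top of a `FileLock` (for other processes on the same machine). A stripped-down sketch of that two-level pattern around a JSON state file (class and method names are illustrative):

import json
import os
from threading import RLock
from filelock import FileLock


class JsonState:
    """Sketch: thread-safe and process-safe read-modify-write of a small JSON file."""

    def __init__(self, save_path):
        self.save_path = save_path
        self.lock = RLock()                             # guards against threads in this process
        self.file_lock = FileLock(save_path + ".lock")  # guards against other processes

    def update(self, key, value):
        with self.lock, self.file_lock:
            state = {}
            if os.path.exists(self.save_path):
                with open(self.save_path) as f:
                    state = json.load(f)
            state[key] = value
            with open(self.save_path, "w") as f:
                json.dump(state, f)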
Ejemplo n.º 30
0
def train_func(config):
    epochs = config.pop("epochs", 3)
    model = ResNet18(config)
    model = train.torch.prepare_model(model)

    # Create optimizer.
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=config.get("lr", 0.1),
        momentum=config.get("momentum", 0.9),
    )

    # Load in training and validation data.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])  # meanstd transformation

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    with FileLock(".ray.lock"):
        train_dataset = CIFAR10(root="~/data",
                                train=True,
                                download=True,
                                transform=transform_train)
        validation_dataset = CIFAR10(root="~/data",
                                     train=False,
                                     download=False,
                                     transform=transform_test)

    if config.get("test_mode"):
        train_dataset = Subset(train_dataset, list(range(64)))
        validation_dataset = Subset(validation_dataset, list(range(64)))

    worker_batch_size = config["batch_size"] // session.get_world_size()

    train_loader = DataLoader(train_dataset, batch_size=worker_batch_size)
    validation_loader = DataLoader(validation_dataset,
                                   batch_size=worker_batch_size)

    train_loader = train.torch.prepare_data_loader(train_loader)
    validation_loader = train.torch.prepare_data_loader(validation_loader)

    # Create loss.
    criterion = nn.CrossEntropyLoss()

    results = []
    for _ in range(epochs):
        train_epoch(train_loader, model, criterion, optimizer)
        result = validate_epoch(validation_loader, model, criterion)
        session.report(result)
        results.append(result)

    # return required for backwards compatibility with the old API
    # TODO(team-ml) clean up and remove return
    return results
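The `FileLock(".ray.lock")` above ensures that when several training workers share one node, only the first one downloads CIFAR-10 while the others wait and then reuse it. A generic sketch of that idiom (the `download_dataset` callable and the marker file are assumptions for illustration):

import os
from filelock import FileLock


def download_once(data_root, download_dataset):
    """Sketch: perform a slow download at most once per machine; later callers just wait."""
    os.makedirs(data_root, exist_ok=True)
    with FileLock(os.path.join(data_root, ".download.lock")):
        marker = os.path.join(data_root, ".downloaded")
        if not os.path.exists(marker):
            download_dataset(data_root)   # slow, network-bound step
            open(marker, "w").close()     # record completion for the workers that were waiting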
Ejemplo n.º 31
0
def cache_adata(session_ID,
                adata=None,
                group=None,
                store_dir=None,
                store_name=None):
    if ((store_dir is None) or (store_name is None)):
        save_dir = save_analysis_path + str(session_ID) + "/"
        filename = save_dir + "adata_cache"
        chunk_factors = [150, 3]  #faster, hot storage
    else:
        save_dir = store_dir
        filename = save_dir + store_name
        chunk_factors = [3, 3]  #slower, cold storage

    if not (os.path.isdir(save_dir)):
        try:
            print("[DEBUG] making directory:" + str(save_dir))
            os.mkdir(save_dir)
        except:
            return None

    lock_filename = (save_analysis_path + str(session_ID) + "/" + "adata.lock")
    lock = FileLock(lock_filename, timeout=lock_timeout)

    compressor = Blosc(cname='blosclz', clevel=3, shuffle=Blosc.SHUFFLE)
    zarr_cache_dir = filename + ".zarr"
    attribute_groups = [
        "obs", "var", "obsm", "varm", "obsp", "varp", "layers", "X", "uns",
        "raw"
    ]
    extra_attribute_groups = ["X_dense", "layers_dense"]

    if (adata is None):  # then -> read it from the store
        if (os.path.exists(zarr_cache_dir) is True):
            store_store = zarr.DirectoryStore(zarr_cache_dir)
            store = zarr.open_group(store=store_store, mode='r')
            if (group in attribute_groups
                ):  # then -> return only that part of the object (fast)
                group_exists = adata_cache_group_exists(session_ID,
                                                        group,
                                                        store=store)
                if (group_exists is True):
                    ret = read_attribute(store[group])
                else:
                    ret = None
                #store_store.close()
                return ret
            elif (group is
                  None):  # then -> return the whole adata object (slow)
                #adata = ad.read_zarr(zarr_cache_dir)
                d = {}
                for g in attribute_groups:
                    if (g in store.keys()):
                        if (adata_cache_group_exists(session_ID,
                                                     g,
                                                     store=store)):
                            if (g in ["obs", "var"]):
                                d[g] = read_dataframe(store[g])
                            else:
                                d[g] = read_attribute(store[g])
                #store_store.close()
                adata = ad.AnnData(**d)
                if not (adata is None):
                    return adata
                else:
                    print("[ERROR] adata object not saved at: " +
                          str(filename))
                    return None
    else:  # then -> update the state dictionary and write adata to the store
        if (group is None):
            cache_state(session_ID,
                        key="# cells/obs",
                        val=len(adata.obs.index))
            cache_state(session_ID,
                        key="# genes/var",
                        val=len(adata.var.index))
            if ("total_counts" in adata.obs):
                cache_state(session_ID,
                            key="# counts",
                            val=int(np.sum(adata.obs["total_counts"])))
            else:
                cache_state(session_ID,
                            key="# counts",
                            val=int(np.sum(adata.X)))

        elif (group == "obs"):
            cache_state(session_ID, key="# cells/obs", val=len(adata.index))
        elif (group == "var"):
            cache_state(session_ID, key="# genes/var", val=len(adata.index))
        with lock:
            store_store = zarr.DirectoryStore(zarr_cache_dir)
            store = zarr.open_group(store=store_store, mode='a')
            if (group in attribute_groups
                ):  # then -> write only that part of the object (fast)
                if (group == "var"):
                    if (np.nan in adata.var.index):
                        adata.var.index = pd.Series(adata.var.index).replace(
                            np.nan, 'nanchung')
                        adata.var["gene_ID"] = pd.Series(
                            adata.var["gene_ID"]).replace(np.nan, 'nanchung')
                        adata.var["gene_ids"] = pd.Series(
                            adata.var["gene_ids"]).replace(np.nan, 'nanchung')
                write_attribute(
                    store, group,
                    adata)  # here "adata" is actually just a subset of adata

                # write dense copies of X or layers if they're what was passed
                if (group == "X"):
                    dense_name = "X_dense"
                    write_dense.delay(zarr_cache_dir, "X", dense_name,
                                      chunk_factors)

                if (group == "layers"):
                    for l in list(adata.keys(
                    )):  #layers was passed with parameter name "adata"
                        dense_name = "layers_dense/" + str(l)
                        write_dense.delay(zarr_cache_dir, "layers/" + l,
                                          dense_name, chunk_factors)
                #store_store.flush()
                #store_store.close()
                lock.release()
            else:
                # check that necessary fields are present in adata object
                if not ("leiden_n" in adata.obs):
                    if ("leiden" in adata.obs):
                        adata.obs["leiden_n"] = pd.to_numeric(
                            adata.obs["leiden"])
                if not ("cell_ID" in adata.obs):
                    adata.obs["cell_ID"] = adata.obs.index
                if not ("cell_numeric_index" in adata.obs):
                    adata.obs["cell_numeric_index"] = pd.to_numeric(
                        list(range(0, len(adata.obs.index))))
                for i in ["user_" + str(j) for j in range(0, 6)]:
                    if not (i in adata.obs.columns):
                        adata.obs[i] = ["0" for k in adata.obs.index.to_list()]
                if not ("gene_ID" in adata.var):
                    adata.var["gene_ID"] = adata.var.index

                # make sure that there are no "nan" genes in the var index
                if (np.nan in adata.var.index):
                    adata.var.index = pd.Series(adata.var.index).replace(
                        np.nan, 'nanchung')
                    adata.var["gene_ID"] = pd.Series(
                        adata.var["gene_ID"]).replace(np.nan, 'nanchung')
                    adata.var["gene_ids"] = pd.Series(
                        adata.var["gene_ids"]).replace(np.nan, 'nanchung')

                # save it all to the cache, but make dense copies of X and layers
                write_attribute(store, "obs", adata.obs)
                write_attribute(store, "var", adata.var)
                write_attribute(store, "obsm", adata.obsm)
                write_attribute(store, "varm", adata.varm)
                write_attribute(store, "obsp", adata.obsp)
                write_attribute(store, "varp", adata.varp)
                write_attribute(store, "uns", adata.uns)
                write_attribute(store, "raw", adata.raw)
                write_attribute(store, "X", adata.X)
                write_attribute(store, "layers", adata.layers)

                # making dense copies of X and layers (compressed to save disk space)

                dense_name = "X_dense"
                write_dense.delay(zarr_cache_dir, "X", dense_name,
                                  chunk_factors)

                for l in list(adata.layers.keys()):
                    dense_name = "layers_dense/" + str(l)
                    write_dense.delay(zarr_cache_dir, "layers/" + l,
                                      dense_name, chunk_factors)

                #store_store.flush()
                #store_store.close()
                lock.release()
            # set the file mod and access times to current time
            # then return adata as usual
            os.utime(zarr_cache_dir)
            return adata
Ejemplo n.º 32
0
 def __init__(self, lock_file, *args, **kwargs):
     self.filelock = FileLock(lock_file)
     super().__init__(lock_file, *args, **kwargs)
Ejemplo n.º 33
0
def cacheLock(cache):
    lock = FileLock("x", timeout=2)
    lock.lockfile = os.path.join(cache.cacheDirectory(), "cache.lock")
    return lock
Ejemplo n.º 34
0
def _set_latest_prequal_id(value):
    lock = FileLock(prequal_id_counter_lock_file)
    with lock:
        with open(prequal_id_counter_file, "w") as f:
            f.write(value)
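This setter and the mailer-file generator in the next example share one counter file guarded by the same lock path. A sketch of that read-modify-write cycle under a single lock (it ignores the letter rollover handled in the next example, and the path arguments are placeholders):

from filelock import FileLock


def next_counter_value(counter_file, counter_lock_file, default="A10000"):
    """Sketch: atomically read, increment and persist a file-backed counter like 'A10000'."""
    with FileLock(counter_lock_file):
        try:
            with open(counter_file) as f:
                current = f.read().strip() or default
        except FileNotFoundError:
            current = default
        letter, number = current[0], int(current[1:]) + 1   # 'A10000' -> ('A', 10001)
        new_value = f"{letter}{number}"
        with open(counter_file, "w") as f:
            f.write(new_value)
    return new_value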
Ejemplo n.º 35
0
def generate_mailer_file(campaign_id):
    app.logger.info('Executing generate_mailer_file...')
    campaign = Campaign.query.get(campaign_id)

    candidates = campaign.candidates.all()

    mapping = {
        'candidate.first_name': 'first',
        'candidate.last_name': 'last',
        'candidate.address': 'address',
        'candidate.city': 'city',
        'candidate.state': 'st',
        'candidate.zip5': 'zip',
        'campaign.phone': 'phone_numb',
        'campaign.job_number': 'job_number',
        'campaign.mailing_date': 'mailing_da',
        'campaign.offer_expire_date': 'offer_expi',
        'candidate.prequal_number': 'prequal',
        'candidate.estimated_debt': 'debt',
        'candidate.debt3': 'debt3',
        'candidate.debt15': 'debt15',
        'candidate.debt2': 'debt2',
        'candidate.debt215': 'debt215',
        'candidate.debt3_2': 'debt3_2',
        'candidate.checkamt': 'checkamt',
        'candidate.spellamt': 'spellamt',
        'candidate.debt315': 'debt315',
        'candidate.year_interest': 'int_yr',
        'candidate.total_interest': 'tot_int',
        'candidate.sav215': 'sav215',
        'candidate.sav15': 'sav15',
        'candidate.sav315': 'sav315'
    }

    filters = {
        'debt': _money,
        'debt3': _money,
        'debt15': _money,
        'debt2': _money,
        'debt215': _money,
        'debt3_2': _money,
        'checkamt': _money,
        'debt315': _money,
        'int_yr': _money,
        'tot_int': _money,
        'sav215': _money,
        'sav15': _money,
        'sav315': _money
    }

    try:
        file_path = _get_mailer_file(campaign)

        lock = FileLock(prequal_id_counter_lock_file)
        with lock:
            if not path.isfile(prequal_id_counter_file):
                latest_prequal_id = 'A10000'
            else:
                latest_prequal_id = open(prequal_id_counter_file, "r").read()

            gen_prequal_func = _generate_prequal_id(latest_prequal_id)

            with open(file_path, 'w') as csvFile:
                writer = csv.DictWriter(csvFile,
                                        fieldnames=mapping.values(),
                                        quoting=csv.QUOTE_ALL)
                writer.writeheader()

                for candidate in candidates:
                    record = {}
                    if not candidate.prequal_number:
                        try:
                            latest_prequal_id = next(gen_prequal_func)
                        except StopIteration:
                            letter, number = latest_prequal_id[:1], latest_prequal_id[1:]
                            new_prequal_id = f'{chr(ord(letter) + 1)}10000'
                            gen_prequal_func = _generate_prequal_id(
                                new_prequal_id)
                            latest_prequal_id = next(gen_prequal_func)

                        candidate.prequal_number = latest_prequal_id

                    for source, key in mapping.items():
                        model, attr = source.split('.')
                        if model == 'candidate':
                            record[key] = _filter(
                                filters, key,
                                getattr(candidate, attr, 'MISSING_VALUE'))
                        elif model == 'campaign':
                            record[key] = _filter(
                                filters, key,
                                getattr(campaign, attr, 'MISSING_VALUE'))
                        else:
                            record[key] = 'UNKNOWN_SOURCE_VALUE'

                    writer.writerow(record)

            campaign.mailer_file = file_path
            open(prequal_id_counter_file, 'w').write(latest_prequal_id)
            db.session.commit()
    finally:
        # The 'with open(...)' block above already closes the CSV file; closing
        # it again here could raise NameError if an error occurred before the
        # file was opened.
        pass
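
generate_mailer_file consumes _generate_prequal_id as a generator and, when it is exhausted, rolls the letter prefix over and restarts at <letter>10000. The helper itself is not shown; a sketch consistent with that calling convention (an assumption, not the project's actual implementation):

def _generate_prequal_id(start):
    # yield successive ids with the same letter prefix, e.g. 'A10001' ... 'A99999';
    # the caller handles StopIteration by moving on to the next letter
    letter, number = start[:1], int(start[1:])
    for n in range(number + 1, 100000):
        yield f'{letter}{n}'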
Ejemplo n.º 36
0
#!/usr/bin/env python3

import sys
from time import sleep
from filelock import FileLock

lock = FileLock("my_lock")
with lock:
    print("This is process {}.".format(sys.argv[1]))
    sleep(1)
    print("Bye.")
Ejemplo n.º 37
0
	chsum.update( data)

	read_len += len(data)
	new_progress = int(100.0 * read_len / file_in_size)
	if new_progress != progress:
		progress = new_progress
		print('PROGRESS: %d%%' % progress)
		sys.stdout.flush()

file_in.close()
result = chsum.hexdigest()
print( result)

obj = dict()
if os.path.isfile( Options.output):
	lock = FileLock( Options.output, 100, 1)
	file_out = open( Options.output, 'r')
	obj = json.load( file_out)
	file_out.close()
	lock.release()
	#print( json.dumps( obj))

file_in_name = os.path.basename( Options.input)
if not 'files' in obj: obj['files'] = dict()
if not file_in_name in obj['files']: obj['files'][file_in_name] = dict()
if not 'checksum' in obj['files'][file_in_name]: obj['files'][file_in_name]['checksum'] = dict()
obj['files'][file_in_name]['checksum'][Options.type] = result
obj['files'][file_in_name]['checksum']['time'] = time.time()
result = json.dumps( obj, indent=1)

lock = FileLock( Options.output, 100, 1)
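
The snippet stops right after re-acquiring the lock on the output file, so the write-back step is cut off. A guess at how the missing tail might continue, matching the read path and spacing style above (not the original continuation):

file_out = open( Options.output, 'w')
file_out.write( result)
file_out.close()
lock.release()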
Ejemplo n.º 38
0
recipientsFile = 'config/recipients.json'
# report file path
reportFile = 'output/report.md'
# Size of report file with no data or header data
reportFileEmptySize = 0
entries = []
recipients = []
dataIsChanged = False
recipientsIsChanged = False

# Html fetcher default data - configured for bankier.pl
defaultHtmlElement = 'div'
defaultHtmlElementClasses = 'box300 boxGrey border3 right'

# Locks creation
lockConfig = FileLock(configFile + '.lock', timeout=lockTimeout)
lockRecipents = FileLock(recipientsFile + '.lock', timeout=lockTimeout)

# Emergency ForceExit


def ForceExit(message, ErrorCode=1):
    global lockRecipents
    global lockConfig
    print('(Stock-manager) %s\n' % (message))
    lockRecipents.release()
    lockConfig.release()
    sys.exit(ErrorCode)


# Entry handling
Ejemplo n.º 39
0
    pmode = RESET_MODE
    # get prefix for environment name
    envprefix = sys.argv[1]
    # get number of environments to create
    nenvs = sys.argv[2]
elif nargs > 1:
    # returns env to the list of environments
    pmode = WRITE_MODE
    # get name of environment to return
    env = sys.argv[1]
else:
    # gets name of an environment to use
    pmode = READ_MODE

# creates a lock for the file so it can only be accessed one at a time
lock = FileLock(lock_path, timeout=time_out_secs)

with lock:
    if pmode == RESET_MODE:
        # create a list (named clist) of nenvs environments with the
        # prefix envprefix
        # add code here
        clist = []
        for i in range(0, int(nenvs)):
            b = envprefix + str(i)
            clist.append(b)
    else:
        # load hickle file
        clist = hickle.load(file_path)

        if pmode == WRITE_MODE:
Ejemplo n.º 40
0
def mirror_file_with_lock(fname, lockfile="/tmp/.lockfile_hk_vectorgen_mirror"):
    lock = FileLock(lockfile, timeout=60*10)
    lock.lock()
    result = mirror_file(fname)
    lock.release()
    return result
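
If mirror_file raises, the lock in this example is never released. A variant that guarantees the release, keeping the lock()/release() methods exactly as the snippet uses them (they come from whatever FileLock flavor this project ships, not necessarily the PyPI filelock package):

def mirror_file_with_lock(fname, lockfile="/tmp/.lockfile_hk_vectorgen_mirror"):
    lock = FileLock(lockfile, timeout=60*10)
    lock.lock()
    try:
        return mirror_file(fname)
    finally:
        # release the lock even when mirror_file() fails
        lock.release()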
Ejemplo n.º 41
0
def ReportAssets(filepath):
    with FileLock(filepath + '.lock', timeout=lockTimeout):
        if os.path.isfile(filepath):
            with open(filepath, 'a+') as f:
                stockAssets.Report(f, 'zl')
Ejemplo n.º 42
0
def main():
    dirs_to_process = []

    parser = argparse.ArgumentParser(description='Manage parallel workers.',
                                     parents=[
                                         get_baseparser(
                                             tool_version=__version__)
                                     ])
    parser.add_argument('-w', '--workers', default=multiprocessing.cpu_count(),
                        help='Number of worker processes')

    # There can be only one way how to supply list of projects to process.
    group1 = parser.add_mutually_exclusive_group()
    group1.add_argument('-d', '--directory',
                        help='Directory to process')
    group1.add_argument('-P', '--projects', nargs='*',
                        help='List of projects to process')

    parser.add_argument('-I', '--indexed', action='store_true',
                        help='Sync indexed projects only')
    parser.add_argument('-i', '--ignore_errors', nargs='*',
                        help='ignore errors from these projects')
    parser.add_argument('-c', '--config', required=True,
                        help='config file in JSON format')
    parser.add_argument('-U', '--uri', default='http://localhost:8080/source',
                        help='URI of the webapp with context path')
    parser.add_argument('-f', '--driveon', action='store_true', default=False,
                        help='continue command sequence processing even '
                        'if one of the commands requests break')
    try:
        args = parser.parse_args()
    except ValueError as e:
        fatal(e)

    logger = get_console_logger(get_class_basename(), args.loglevel)

    uri = args.uri
    if not is_web_uri(uri):
        logger.error("Not a URI: {}".format(uri))
        sys.exit(1)
    logger.debug("web application URI = {}".format(uri))

    # First read and validate configuration file as it is mandatory argument.
    config = read_config(logger, args.config)
    if config is None:
        logger.error("Cannot read config file from {}".format(args.config))
        sys.exit(1)

    # Changing working directory to root will avoid problems when running
    # programs via sudo/su. Do this only after the config file was read
    # so that its path can be specified as relative.
    try:
        os.chdir("/")
    except OSError:
        logger.error("cannot change working directory to /",
                     exc_info=True)
        sys.exit(1)

    try:
        commands = config["commands"]
    except KeyError:
        logger.error("The config file has to contain key \"commands\"")
        sys.exit(1)

    directory = args.directory
    if not args.directory and not args.projects and not args.indexed:
        # Assume directory, get the source root value from the webapp.
        directory = get_config_value(logger, 'sourceRoot', uri)
        if not directory:
            logger.error("Neither -d or -P or -I specified and cannot get "
                         "source root from the webapp")
            sys.exit(1)
        else:
            logger.info("Assuming directory: {}".format(directory))

    ignore_errors = []
    if args.ignore_errors:
        ignore_errors = args.ignore_errors
    else:
        try:
            ignore_errors = config["ignore_errors"]
        except KeyError:
            pass
    logger.debug("Ignored projects: {}".format(ignore_errors))

    lock = FileLock(os.path.join(tempfile.gettempdir(),
                                 "opengrok-sync.lock"))
    try:
        with lock.acquire(timeout=0):
            if args.projects:
                dirs_to_process = args.projects
                logger.debug("Processing directories: {}".
                             format(dirs_to_process))
            elif args.indexed:
                indexed_projects = list_indexed_projects(logger, uri)
                logger.debug("Processing indexed projects: {}".
                             format(indexed_projects))

                if indexed_projects:
                    for line in indexed_projects:
                        dirs_to_process.append(line.strip())
                else:
                    logger.error("cannot get list of projects")
                    sys.exit(1)
            else:
                logger.debug("Processing directory {}".format(directory))
                for entry in os.listdir(directory):
                    if path.isdir(path.join(directory, entry)):
                        dirs_to_process.append(entry)

            logger.debug("to process: {}".format(dirs_to_process))

            cmds_base = []
            for d in dirs_to_process:
                cmd_base = CommandSequenceBase(d, commands, args.loglevel,
                                               config.get("cleanup"),
                                               args.driveon)
                cmds_base.append(cmd_base)

            # Map the commands into pool of workers so they can be processed.
            with Pool(processes=int(args.workers)) as pool:
                try:
                    cmds_base_results = pool.map(worker, cmds_base, 1)
                except KeyboardInterrupt:
                    sys.exit(1)
                else:
                    for cmds_base in cmds_base_results:
                        logger.debug("Checking results of project {}".
                                     format(cmds_base))
                        cmds = CommandSequence(cmds_base)
                        cmds.fill(cmds_base.retcodes, cmds_base.outputs,
                                  cmds_base.failed)
                        cmds.check(ignore_errors)
    except Timeout:
        logger.warning("Already running, exiting.")
        sys.exit(1)
Ejemplo n.º 43
0
def ReportsAppend(filepath, data):
    with FileLock(filepath + '.lock', timeout=lockTimeout):
        if os.path.isfile(filepath):
            with open(filepath, 'a+') as f:
                f.write(data)
Ejemplo n.º 44
0
def allocate(length, place, tm=None):
    """
    1. Find a usable file at minimal cost, starting from _last_id; -1 == stat('%d.lock')
    2. try lock
    3. stat again

    # To cope with too many files in a single directory:
    - carry the bundle path information in the URL, encrypted

    data = open('sample.jpg', 'rb').read()

    w = bundle.allocate(len(data), 'p')
    w.ensure_url(prefix='fmn04', postfix='.jpg')
    w.write(data)

    del w   # important to release the file lock

    >>> if not os.path.exists('p'): os.makedirs('p')
    >>> if os.path.exists('p/20110920/A'): os.unlink('p/20110920/A')
    >>> data = open('sample.jpg', 'rb').read()
    >>> w = allocate(len(data), place='p', tm=time.strptime('20110919', '%Y%m%d'))
    >>> w.ensure_url(prefix='fmn04', postfix='.jpg')
    >>> w.write(data)
    83283
    """

    global _last_id
    if _last_id is None:
        _last_id = os.getpid() % 10

    if tm is None:
        tm = time.localtime()

    loop_count = 0

    while True:
        # 0
        # Check whether we have looped through a full cycle; if so, fall back to allocating a file id beyond kBundleCountPerDay
        loop_count += 1
        if loop_count >= kBundleCountPerDay:
            logging.info('allocation policy loop failed')
            _last_id = kBundleCountPerDay + random.randint(1, 100)

        # 1
        bid = new_bid(_last_id % kBundleCountPerDay, tm)
        fn_bundle = bundle_filename(bid, place)
        try:
            statinfo = os.stat(fn_bundle)
        except OSError as e:
            if e.errno != errno.ENOENT:
                _last_id += 1 # TODO: lock
                continue

            statinfo = None

        if statinfo and statinfo.st_size > kMaxBundleSize:
            _last_id += 1 # TODO: lock
            continue

        # 2
        fn_lock = '%s/%s' % (_lock_place, bid) # TODO: /place/.lock/sub-place/
        dir_lock = os.path.dirname(fn_lock)
        if not os.path.exists(dir_lock):
            os.makedirs(dir_lock)

        flock = FileLock(fn_lock, timeout=0.5)
        try:
            if not flock.try_acquire():
                _last_id += 1 # TODO: lock
                continue
        except:
            _last_id += 1 # TODO: lock
            continue

        assert flock.is_locked

        # 3
        if statinfo is None:
            fp = create(fn_bundle)
            fp.close()
            offset = kBundleHeaderSize
        else:
            statinfo = os.stat(fn_bundle)
            if statinfo.st_size > kMaxBundleSize:
                _last_id += 1 # TODO: lock
                continue
            offset = statinfo.st_size
        assert offset >= kBundleHeaderSize, offset
        return Writer(flock, filename=fn_bundle, bid=bid, offset=offset, length=length)
Ejemplo n.º 45
0
 def read_at_position(self, position: int) -> EntityType:
     with FileLock(f"{self.file_path}.lock"):
         with open(self.file_path, "r") as file:
             file.seek(position)  # move the file cursor to the required position
             return self._load_entity(file.readline())  # deserialize the line that was read
Ejemplo n.º 46
0
class Store(object):
    auto_sync = True
    auto_mem_resync_counter = 0

    @staticmethod
    def is_valid_schema(schema):
        return all(len(x) == 3 and x[1] in DUMPERS for x in schema)

    def __init__(self, db_name, schema, log_level=1):
        """
        schema example = (
            (name, type, len),
            ...
        )

        int
        str
        """
        self._log_lvl = log_level
        db_name = os.path.join('db', db_name)

        self.lock = FileLock(db_name)

        if not self.is_valid_schema(schema):
            self.error("Invalid DB schema")
            raise Exception('Schema is bad!')

        all = self.init_mem(db_name, schema)

        self._memory = all
        self._db_name = db_name
        self._schema = schema

    def info(self, msg):
        if self._log_lvl > 3:
            print(msg)

    def error(self, msg):
        if self._log_lvl > 0:
            print(msg)

    def _insert(self, object_dict_):
        assert type(object_dict_) == dict, 'Wow! Not dict!'
        self._memory.append(object_dict_)

    def sync(self):
        self.lock.acquire()
        f = open(self._db_name, 'wb')
        for d in self._memory:
            for name_, type_, len_ in self._schema:
                bytes_bytes = DUMPERS[type_](d[name_], len_)
                f.write(bytes_bytes)
            f.write('\n')
        f.close()
        self.lock.release()
        self._memory = self.init_mem(self._db_name, self._schema)

    def execute(self, row):
        """
        Execute SQL query. Support: select, delete, insert.
        """
        # row = "select where value[1] == 'd' and id >= 0 limit 2"
        # row = "insert into ... values (2, 'awdwa')"
        # row = "delete where ... limit k"
        try:
            method, tail = row.split(' ', 1)
            method = method.lower()
            tail = tail.strip(' ')
            rez = None

            self.info(u'-- SQL {0} {1} --'.format(method, tail))

            if method == 'insert':
                r = re.compile(r'^.*?values?[ ]*(\(.*?\))$', re.M)
                z = r.match(tail)
                if z:
                    rez = [self.insert(*z.groups())]
                    if self.auto_sync:
                        self.sync()

            elif method in ['select', 'delete']:
                r = re.compile(r'^.*?(?:(?:where)[ ]*(.*?))?[ ]*(?:limit[ ]*(\d+))?[ ]*([dD][Ee][ScCs][ckCK])?[ ]*$')
                z = r.match(tail)
                if z:
                    rez = self.__getattribute__('go_go')(method, z.groups())
                else:
                    rez = self.__getattribute__(method)()
            elif method == 'last':
                rez = [self.last()]

            if hasattr(rez, '__len__') and rez.__len__() == 1:
                return rez[0]

            return rez
        except Exception as e:
            self.error("Invalid SQL syntax detected: {0!r} by {1}".format(row, e))
            raise Exception('Invalid SQL syntax!!')

    def go_go(self, method, args):
        return self.__getattribute__(method)(*args)

    def delete(self, where=None, limit=None, desk=None):
        limit = int(limit.strip()) if limit else 0
        where = 'True' if not where else where
        where = self.fix_where(where)

        rez = 0
        del_indexes = []
        l = locals()
        i = 0
        mem = self._memory if not desk else reversed(self._memory)
        for d in mem:
            for name_, type_, len_ in self._schema:
                l[name_] = d[name_]

            st = parser.expr(where)
            is_ok = eval(st.compile())

            if is_ok:
                rez += 1
                del_indexes.append(i)

            i += 1

            if limit and rez >= limit:
                break

        z = 0
        for x in sorted(del_indexes):
            self._delete_dy_index(x - z)
            z += 1

        return rez

    def _delete_dy_index(self, index):
        if 0 <= index < len(self._memory):
            del self._memory[index]
            return 1
        return 0

    def _delete_dy_indexes(self, *indexes):
        del_counter = 0
        for index in sorted(indexes):
            deleted = self._delete_dy_index(index - del_counter)
            del_counter += deleted
        return del_counter

    def _memory_dump(self):
        print('\n-- dump --')
        for d in self._memory:
            print(d.values())
        print('-- |--| --\n')

    def select(self, where=None, limit=None, desk=None):
        limit = int(limit.strip()) if limit else 0
        where = 'True' if not where else where
        where = self.fix_where(where)

        rez = []
        l = locals()
        mem = self._memory if not desk else reversed(self._memory)
        for d in mem:
            for name_, type_, len_ in self._schema:
                l[name_] = d[name_]

            st = parser.expr(where)
            is_ok = eval(st.compile())

            if limit and len(rez) >= limit:
                return rez

            if is_ok:
                rez.append(d)
        return rez

    def insert(self, insert_obj_row):
        if not insert_obj_row.startswith('(') or not insert_obj_row.endswith(')'):
            return
        insert_obj_row = insert_obj_row.replace("'", "'''")
        insert_obj_row = insert_obj_row.replace("\"", "'''")
        try:
            st = parser.expr(insert_obj_row)
            obj = eval(st.compile())
            if type(obj) != tuple or len(obj) != len(self._schema):
                return

            d = {}
            i = 0
            for name_, type_, len_ in self._schema:
                d[name_] = obj[i]
                i += 1

                _ck = CHECKERS[type_]
                if _ck(d[name_]):
                    return

            self._insert(d)
            return d
        except Exception as e:
            self.error('Insertion error! {0!r}'.format(e))
            return

    def fix_where(self, where):
        z = where.replace(' = ', ' == ')
        z = z.replace('__', '')
        return z.replace('import', '')

    def last(self):
        pass

    def init_mem(self, db_name, schema):
        i = 0
        all = []
        data = {}

        self.lock.acquire()

        def read_tail(f):
            while True:
                c = f.read(1)
                if not c or c == '\n':
                    return

        try:
            _r = open(db_name, 'rb')
        except:
            self.info("Create New DB")
            try:
                t = open(db_name, 'a')
                t.close()
            except Exception as e:
                self.error("Can not create NEW DB", e)
                raise Exception('Can not create new DB')

        try:
            _r = open(db_name, 'rb')
        except:
            self.error("I/0 Error #1! Can not open!")
            raise Exception("I/0 Error #1! Can not open!")

        while 1:
            if i == len(schema):
                all.append(data)
                read_tail(_r)
                i = 0
                data = {}

            name_, type_, len_ = schema[i]

            d = _r.read(len_).replace('\0', '')
            if not d: break

            zero = _r.read(1)
            if not zero or ord(zero) != 0:
                read_tail(_r)
                i = 0
                data = {}
                continue

            try:
                data[name_] = type_(d)
            except:
                read_tail(_r)
                i = 0
                data = {}
            else:
                i += 1

        _r.close()

        self.lock.release()

        return all
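
The class depends on module-level DUMPERS and CHECKERS tables and on a db/ directory, none of which are shown. A hypothetical usage sketch, only to illustrate the constructor signature and the small SQL-like dialect accepted by execute():

# assumed schema: (name, type, length) triples, as described in the docstring
schema = (
    ('id', int, 8),
    ('name', str, 32),
)

store = Store('users.db', schema, log_level=4)
store.execute("insert into users values (1, 'alice')")
rows = store.execute("select where id >= 1 limit 10")
store.sync()  # flush _memory back to disk under the file lock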
Ejemplo n.º 47
0
 def write(self, entity: EntityType) -> int:
     with FileLock(f"{self.file_path}.lock"):
         with open(self.file_path, "a") as file:
             position = file.tell()  # remember the current position in the file
             file.write(self._dump_entity(entity) + self.new_line)  # serialize the entity and append it to the file
             return position  # return the position of the record in the file
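
Both write() and read_at_position() (a few examples above) delegate to _dump_entity and _load_entity, which are not shown. A minimal sketch of JSON-based versions, purely an assumption about their shape:

import json

def _dump_entity(self, entity) -> str:
    # assumed: serialize the entity to a single JSON line (write() appends self.new_line)
    return json.dumps(entity, ensure_ascii=False)

def _load_entity(self, line: str):
    # assumed: rebuild the entity from one JSON line
    return json.loads(line)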
Ejemplo n.º 48
0
def main():
    parser = argparse.ArgumentParser(description='project management.',
                                     formatter_class=argparse.
                                     ArgumentDefaultsHelpFormatter,
                                     parents=[get_baseparser(
                                         tool_version=__version__)
                                     ])

    parser.add_argument('-b', '--base', default="/var/opengrok",
                        help='OpenGrok instance base directory')
    parser.add_argument('-R', '--roconfig',
                        help='OpenGrok read-only configuration file')
    parser.add_argument('-U', '--uri', default='http://localhost:8080/source',
                        help='URI of the webapp with context path')
    parser.add_argument('-c', '--configmerge',
                        help='path to the ConfigMerge binary')
    parser.add_argument('--java', help='Path to java binary '
                                       '(needed for config merge program)')
    parser.add_argument('-j', '--jar', help='Path to jar archive to run')
    parser.add_argument('-u', '--upload', action='store_true',
                        help='Upload configuration at the end')
    parser.add_argument('-n', '--noop', action='store_true', default=False,
                        help='Do not run any commands or modify any config'
                             ', just report. Usually implies '
                             'the --debug option.')
    parser.add_argument('-N', '--nosourcedelete', action='store_true',
                        default=False, help='Do not delete source code when '
                                            'deleting a project')

    group = parser.add_mutually_exclusive_group()
    group.add_argument('-a', '--add', metavar='project', nargs='+',
                       help='Add project (assumes its source is available '
                            'under source root)')
    group.add_argument('-d', '--delete', metavar='project', nargs='+',
                       help='Delete project and its data and source code')
    group.add_argument('-r', '--refresh', action='store_true',
                       help='Refresh configuration. If read-only '
                            'configuration is supplied, it is merged '
                            'with current '
                            'configuration.')

    try:
        args = parser.parse_args()
    except ValueError as e:
        fatal(e)

    doit = not args.noop
    configmerge = None

    #
    # Setup logger as a first thing after parsing arguments so that it can be
    # used through the rest of the program.
    #
    logger = get_console_logger(get_class_basename(), args.loglevel)

    if args.nosourcedelete and not args.delete:
        logger.error("The no source delete option is only valid for delete")
        sys.exit(1)

    # Set the base directory
    if args.base:
        if path.isdir(args.base):
            logger.debug("Using {} as instance base".
                         format(args.base))
        else:
            logger.error("Not a directory: {}\n"
                         "Set the base directory with the --base option."
                         .format(args.base))
            sys.exit(1)

    # If read-only configuration file is specified, this means read-only
    # configuration will need to be merged with active webapp configuration.
    # This requires config merge tool to be run so couple of other things
    # need to be checked.
    if args.roconfig:
        if path.isfile(args.roconfig):
            logger.debug("Using {} as read-only config".format(args.roconfig))
        else:
            logger.error("File {} does not exist".format(args.roconfig))
            sys.exit(1)

        configmerge_file = get_command(logger, args.configmerge,
                                       "opengrok-config-merge")
        if configmerge_file is None:
            logger.error("Use the --configmerge option to specify the path to"
                         "the config merge script")
            sys.exit(1)

        configmerge = [configmerge_file]
        if args.loglevel:
            configmerge.append('-l')
            configmerge.append(str(args.loglevel))

        if args.jar is None:
            logger.error('jar file needed for config merge tool, '
                         'use --jar to specify one')
            sys.exit(1)

    uri = args.uri
    if not is_web_uri(uri):
        logger.error("Not a URI: {}".format(uri))
        sys.exit(1)
    logger.debug("web application URI = {}".format(uri))

    lock = FileLock(os.path.join(tempfile.gettempdir(),
                                 os.path.basename(sys.argv[0]) + ".lock"))
    try:
        with lock.acquire(timeout=0):
            if args.add:
                for proj in args.add:
                    project_add(doit=doit, logger=logger,
                                project=proj,
                                uri=uri)

                config_refresh(doit=doit, logger=logger,
                               basedir=args.base,
                               uri=uri,
                               configmerge=configmerge,
                               jar_file=args.jar,
                               roconfig=args.roconfig,
                               java=args.java)
            elif args.delete:
                for proj in args.delete:
                    project_delete(logger=logger,
                                   project=proj,
                                   uri=uri, doit=doit,
                                   deletesource=not args.nosourcedelete)

                config_refresh(doit=doit, logger=logger,
                               basedir=args.base,
                               uri=uri,
                               configmerge=configmerge,
                               jar_file=args.jar,
                               roconfig=args.roconfig,
                               java=args.java)
            elif args.refresh:
                config_refresh(doit=doit, logger=logger,
                               basedir=args.base,
                               uri=uri,
                               configmerge=configmerge,
                               jar_file=args.jar,
                               roconfig=args.roconfig,
                               java=args.java)
            else:
                parser.print_help()
                sys.exit(1)

            if args.upload:
                main_config = get_config_file(basedir=args.base)
                if path.isfile(main_config):
                    if doit:
                        with io.open(main_config, mode='r',
                                     encoding="utf-8") as config_file:
                            config_data = config_file.read().encode("utf-8")
                            if not set_configuration(logger,
                                                     config_data, uri):
                                sys.exit(1)
                else:
                    logger.error("file {} does not exist".format(main_config))
                    sys.exit(1)
    except Timeout:
        logger.warning("Already running, exiting.")
        sys.exit(1)
Ejemplo n.º 49
0
 def load_json_file(self, filename):
     """Loads json file"""
     with FileLock(filename + ".lock"):
         with open(filename) as data_file:
             return json.load(data_file)
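
A matching writer would normally take the same per-file lock so that readers never observe a half-written file; a hedged counterpart sketch (not part of the original class):

 def save_json_file(self, filename, data):
     """Writes json file under the same lock used by the loader"""
     with FileLock(filename + ".lock"):
         with open(filename, "w") as data_file:
             json.dump(data, data_file, indent=2)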