Example #1
def testBasicOperation():
    def addone(x):
        x['int'] += 1

    default_db = TinyDB('./db_location/default.json')
    real_table = default_db.table('real')
    print(f"{'*'*20}打开了数据库{default_db.name}{'*'*20}")
    for i in range(7):
        default_db.insert({"int": i + 1, "char": chr(65 + i)})
    print("对每一个元素进行打印操作:")
    print("default_db 中每一个int字段加1")
    default_db.update(addone)

    print("对每一个元素进行打印操作:")

    print("default_db中的所有表段为:", default_db.tables())
    print("default_db中所有的数据为:", default_db.all())

    default_db.truncate()
    print(f"{'*'*20}清除了所有表{'*'*20}")
    print("db中所有的表段为:", default_db.tables())
    print("default_db中所有的数据为:", default_db.all())

    print(f"{'*'*20}关闭了表{default_db.name}{'*'*20}")
    default_db.close()
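TinyDB's truncate() only empties the table it is called on (the default table when called on the TinyDB object itself), so the 'real' table created above survives the db.truncate() call. A minimal sketch (TinyDB 4.x, hypothetical file name) contrasting truncate() with drop_tables():

from tinydb import TinyDB

db = TinyDB('sketch.json')           # hypothetical throwaway file
db.insert({'int': 1})                # goes into the '_default' table
db.table('real').insert({'int': 2})  # goes into the 'real' table

db.truncate()                        # clears '_default' only
assert db.all() == []
assert len(db.table('real')) == 1    # the named table is untouched

db.drop_tables()                     # removes every table, including 'real'
assert db.tables() == set()
db.close()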
Example #2
class JsonDb:
	def __init__(self):
		self.db = TinyDB('bot_api/Database/db.json')
		self.querier = Query()
		# db.insert({'email': '*****@*****.**', 'telegram_id': '1234',
		# 	'mail_unique_id_count':{'mail123':1,'mail124':1},
		# 	'mail_last_read':{'mail123':'today','mail124':'yesterday'},
		#	'mail_comment': {'mail123': "testing", 'mail124': "testing2"}})

	def checkUserExists(self,chat_id):
		result = self.db.search(self.querier.chat_id == str(chat_id))
		if len(result) == 0:
			return False,None
		else:
			return True,result[0]['email']

	def setUserName(self,username,chat_id):
		chk,old_username = self.checkUserExists(chat_id)
		if chk:
			return False,old_username
		else:
			try:
				self.db.insert({"email": username,"chat_id": str(chat_id),"encrypted":str(hash(username)),"mail_unique_id_count": {},"mail_last_read": {},"mail_comment": {},"config_count": {}})
				return True,username
			except Exception as e:
				print(e)
				return False,"error"

	def deleteDb(self):
		self.db.truncate()

	def getData(self):
		return self.db.all()
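A hypothetical usage sketch of the JsonDb wrapper above (the chat id and address are made up, and it assumes the bot_api/Database/ directory exists so TinyDB can create db.json):

db = JsonDb()
created, email = db.setUserName('user@example.com', chat_id=1234)
if created:
    print(f"registered {email}")
print(db.checkUserExists(1234))   # -> (True, 'user@example.com')
db.deleteDb()                     # truncate() empties the default table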
Example #3
class AddOutputBase(ABC):
    """Base class for updating DB with parsed portscan output."""
    def __init__(self, db, rm, scanner_name, scanner_args):
        """
		Constructor.

		:param db: a tinydb database file path
		:type db: tinydb.TinyDB
		:param rm: a flag showing if we need to drop the DB before updating its values
		:type rm: bool
		:param scanner_name: name of the port scanner to run
		:type scanner_name: str
		:param scanner_args: port scanner arguments
		:type scanner_args: str
		:return: base class object
		:rtype: das.modules.add.AddOutputBase
		"""
        self.db = TinyDB(db)
        if rm:
            self.db.truncate()

        self.portscan_out = f'.db/raw/{scanner_name}-{datetime.now().strftime("%Y%m%dT%H%M%S")}.out'
        self.command = f"""sudo {scanner_name} {scanner_args} | tee {self.portscan_out}"""

        Logger.print_cmd(self.command)
        os.system(self.command)

        with open(self.portscan_out, 'r', encoding='utf-8') as fd:
            self.portscan_raw = fd.read().splitlines()

    @abstractmethod
    def parse(self):
        """Interface for a parsing method."""
        raise NotImplementedError
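AddOutputBase leaves parse() abstract; a subclass is expected to walk self.portscan_raw and write documents into self.db. A hypothetical sketch of such a subclass, assuming greppable Nmap output ('Host: 1.2.3.4 ... Ports: 80/open/tcp//http///, ...') rather than any format the original project actually parses:

class AddNmapGreppableOutput(AddOutputBase):
    def parse(self):
        for line in self.portscan_raw:
            if 'Ports:' not in line:
                continue
            ip = line.split()[1]
            open_ports = [
                int(entry.split('/')[0])
                for entry in line.split('Ports:')[1].split(',')
                if '/open/' in entry
            ]
            self.db.insert({'ip': ip, 'ports': sorted(open_ports)})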
Example #4
class PastesTinyDB(PastesDBBase):
    PASTE_DB_FILENAME = 'pastes.db.json'
    META_DB_FILENAME = 'meta.db.json'

    def __init__(self):
        super().__init__()

        if not hasattr(self, '_pastes_db'):
            self._pastes_db = TinyDB(self.__class__.PASTE_DB_FILENAME)
        if not hasattr(self, '_meta_db'):
            self._meta_db = TinyDB(self.__class__.META_DB_FILENAME)

    def insert_paste(self, paste):
        if isinstance(paste, Paste):
            paste = paste.to_dict()
        self._pastes_db.insert(paste)

        self._update_latest_paste(paste)

    def get_latest_paste(self):
        meta_records = self._meta_db.all()
        return meta_records[0] if meta_records else None

    def _update_latest_paste(self, paste):
        meta_record = self.get_latest_paste()
        if meta_record:
            record_ts = PasteBinTime(meta_record['date']).to_ts()
            paste_ts = PasteBinTime(paste['date']).to_ts()

            # if no need to change the current record - return before change
            if meta_record['key'] == paste['date'] or record_ts > paste_ts:
                return

        self._meta_db.truncate()
        self._meta_db.insert(paste)
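_update_latest_paste keeps _meta_db as a single-document table: it wipes the table and re-inserts whenever a newer paste arrives. A standalone sketch of that pattern (hypothetical file and fields):

from tinydb import TinyDB

meta = TinyDB('meta.sketch.json')

def set_latest(doc: dict) -> None:
    meta.truncate()   # drop the previous "latest" document
    meta.insert(doc)  # the table now holds exactly one document

set_latest({'key': 'abc123', 'date': '2021-01-01'})
set_latest({'key': 'def456', 'date': '2021-01-02'})
assert len(meta) == 1 and meta.all()[0]['key'] == 'def456'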
Example #5
async def update_payment():
    accounts_db = TinyDB(settings.root_dir + '/db/accounts.json')
    subaccounts = FtxClient(api_key=api_key,
                            api_secret=api_secret).get_subaccounts()
    account_names = [''] + [s['nickname'] for s in subaccounts]

    accounts = await asyncio.gather(*[
        asyncio.create_task(_collect_payments(name)) for name in account_names
    ])

    accounts_db.truncate()
    for account in accounts:
        accounts_db.insert(account)
Example #6
 def update_article(self):
     inputs = self.get_input()
     if web.ctx.method == "GET":
         article_id = inputs.get("article_id")
         category_list = Categories.select().where(Categories.status == 0)
         article = Articles.get_or_none(Articles.id == article_id)
         print(article.id)
         self.private_data["article"] = article
         self.private_data["category_list"] = category_list
         return self.display("admin/update_article")
     else:
         article_id = inputs.get("article_id")
         name = inputs.get('name')
         content = inputs.get('content')
         summary = inputs.get("summary")
         category_id = inputs.get("category_id")
         source_url = inputs.get("source_url", "")
         keywords = str(inputs.get("keywords", "")).strip()
         article = Articles.get_or_none(Articles.id == article_id)
         try:
             tags_list = keywords.split(",") if keywords else []
             if tags_list:
                 got_tags = Tags.select().where(Tags.name.in_(tags_list))
                 tmp_list = []
                 for tag in got_tags:
                     tmp_list.append(tag.name)
                 for tag_str in tags_list:
                     tag_str = tag_str.strip()
                     if tag_str not in tmp_list:
                         t = Tags(name=tag_str)
                         t.save()
                 db = TinyDB('settings/db.json')
                 db.truncate()
                 db.close()
             article.update(name=name,
                            content=content,
                            summary=summary,
                            category_id=category_id,
                            original_address=source_url,
                            keywords=keywords,
                            updateTime=time()).where(
                                Articles.id == article_id).execute()
             self.private_data["update_success"] = True
             return web.seeother(self.make_url('articles'))
         except Exception as e:
             log.error('update article failed %s' % traceback.format_exc())
             log.error('input params %s' % inputs)
             self.private_data["update_success"] = False
             return web.seeother(self.make_url('update_article'))
Example #7
 def get_tags(self):
     db = TinyDB('settings/db.json')
     table = db.table('_default')
     res = table.get(where('name') == 'tags')
     if res:
         data = res.get("data")
         tags_dict_list = json.loads(data)
     else:
         tags_list = Tags.select().where(Tags.status == 0)
         tags_dict_list = []
         for tag in tags_list:
             count = Articles.select().where(
                 Articles.keywords.contains(str(tag.name))).count()
             tags_dict_list.append({tag.name: count})
         db.truncate()
         table = db.table('_default')
         table.insert({"name": "tags", "data": json.dumps(tags_dict_list)})
         db.close()
     return tags_dict_list
Example #8
def main():
    CONFIG_FILE_NAME = "GroupAPI.yaml"
    JSON_DB_NAME_ORIGEM = "final_database.json"
    JSON_DB_NAME_DESTINO = "processos.json"
    CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

    # Load configuration
    ConfigFile = os.path.join(CURRENT_DIR, CONFIG_FILE_NAME)
    with open(ConfigFile, "r", encoding="utf-8") as config_fh:
        Configuration = yaml.load(config_fh, Loader=yaml.SafeLoader)
    RootDataDir = pathlib.Path(Configuration["Input"]["RootDataDir"])
    JsonDB = os.path.join(RootDataDir, JSON_DB_NAME_DESTINO)
    JsonDBFonte = os.path.join(RootDataDir, JSON_DB_NAME_ORIGEM)

    db = TinyDB(JsonDB)
    db.truncate()

    with io.open(JsonDBFonte, "r", encoding="utf-8") as json_file:
        data = json.load(json_file)
    db.insert_multiple(data)
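The truncate() + insert_multiple() pair above replaces the whole database on every run; note that insert_multiple() expects a list of documents, so final_database.json must be a JSON array of objects. A sketch with inline data instead of a source file (file name is hypothetical):

from tinydb import TinyDB

db = TinyDB('processos.sketch.json')
data = [{"id": "001", "status": "open"}, {"id": "002", "status": "closed"}]
db.truncate()             # drop whatever the previous run left behind
db.insert_multiple(data)  # one document per list element
assert len(db) == len(data)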
Example #9
def run_new(args):
    """ Create a new podcast

        This creates a new directory and fills it with a few standard
        templates, creating a new IPNS key

        Accepts
        -------
        args: a Namespace resulting from ArgumentParser.parse_args
    """
    channel_name = Path(args.channel_name).name
    title = args.title or channel_name.replace("_", " ")
    home = Path('./channels').joinpath(args.channel_name).absolute()
    home.mkdir(parents=True, exist_ok=True)

    metadata = dict(
        title=title,
        description=args.description or title,
        link=args.link or "https://ipfs.io/",
        copyright=args.copyright or "N/A",
        language=args.language or "en",
        managing_editor=args.managing_editor or "anonymous",
        ttl=args.ttl,
        key=''
    )

    print(
        f"Generating a new channel {title} in {home.as_posix()}"
        " with the following properties:"
    )
    pprint(metadata)

    db = TinyDB(home.joinpath("channel.json").as_posix())
    db.truncate()

    db.insert(metadata)
Example #10
class IAddPortscanOutput(ABC):
    """Base class for updating DB with parsed portscan output."""
    def __init__(self, db_path, rm, scanner_name, scanner_args):
        """
		Constructor.

		:param db_path: a TinyDB database file path
		:type db_path: tinydb.TinyDB
		:param rm: a flag showing if we need to drop the DB before updating its values
		:type rm: bool
		:param scanner_name: name of the port scanner to run
		:type scanner_name: str
		:param scanner_args: port scanner arguments
		:type scanner_args: str
		:return: base class object
		:rtype: das.parsers.IAddPortscanOutput
		"""
        self.db = TinyDB(db_path)
        if rm:
            self.db.truncate()

        self.portscan_out = f'{Path.home()}/.das/db/raw/{scanner_name}-{datetime.now().strftime("%Y%m%dT%H%M%S")}.out'
        self.command = f"""sudo {scanner_name} {scanner_args} | tee {self.portscan_out}"""

        Logger.print_cmd(self.command)
        os.system(self.command)

        with open(self.portscan_out, 'r+', encoding='utf-8') as fd:
            content = fd.read()
            fd.seek(0)
            fd.write(f'# {self.command}\n\n{content}')
            self.portscan_raw = content.splitlines()

    @abstractmethod
    def parse(self):
        """Interface for a parsing method."""
        raise NotImplementedError
Example #11
0
    ["short.st", "http://gestyy.com/eq0haV"],
    ["short.st", "http://gestyy.com/eq0hss"],
    ["short.st", "http://gestyy.com/eq0hsY"],
    ["short.st", "http://gestyy.com/eq0hdt"],
    ["short.st", "http://gestyy.com/eq0hd8"],
    ["short.st", "http://gestyy.com/eq0hfW"],
    ["short.st", "http://gestyy.com/eq0hgq"],
    ["short.st", "http://gestyy.com/eq0hgT"],
    ["short.st", "http://gestyy.com/eq0hhE"],
    ["short.st", "http://gestyy.com/eq0hh2"],
    ["short.st", "http://gestyy.com/eq0hjs"],
    ["short.st", "http://gestyy.com/eq0hjY"],
    ["short.st", "http://gestyy.com/eq0hj5"],
    ["short.st", "http://gestyy.com/eq0hk0"],
    ["short.st", "http://gestyy.com/eq0hk3"],
    ["short.st", "http://gestyy.com/eq0hle"],
    ["short.st", "http://gestyy.com/eq0hli"],
    ["short.st", "http://gestyy.com/eq0hlk"],
    ["short.st", "http://gestyy.com/eq0hlJ"],
    ["short.st", "http://gestyy.com/eq0hzq"],
    ["short.st", "http://gestyy.com/eq0hzP"],
    ["short.st", "http://gestyy.com/eq0hz9"],
    ["short.st", "http://gestyy.com/eq0hcw"],
]

lnkstbl = TinyDB('db/bot.json').table('links')
lnkstbl.truncate()
for lnk in lnks:
    lnkstbl.insert({'type': lnk[0], 'url': lnk[1]})
Example #12
class Analyser:
    def __init__(
        self,
        host="localhost",
        port="8086",
        signal_group_id=None,
        signal_cli="/usr/bin/signal-cli",
        max_authorized_delta=0.5,
        verbose=False,
        dry_run=False,
        send_test_message=False,
        db='db.json',
        reset_db=False,
    ):
        self._client = InfluxDBClient(host, port)
        self.signal_group_id = signal_group_id
        self.verbose = verbose
        self.dry_run = dry_run
        self.send_test_message = send_test_message
        self.signal_cli = signal_cli
        self.max_authorized_delta = max_authorized_delta

        self.db = TinyDB(db)
        if reset_db:
            self.db.truncate()

        if self.verbose:
            def _query(*args, **kwargs):
                print(kwargs["query"])
                return self._client._query(*args, **kwargs)

            self._client._query = self._client.query
            self._client.query = _query

    def log(self, channel, message=""):
        fg.orange = Style(RgbFg(255, 150, 50))
        icons = {
            "logo": (fg.white + ASCII_ART, fg.rs),
            "header": (" " + ef.bold, rs.bold_dim),
            "subheader": (ef.i + fg.white + " ", fg.rs + rs.i + "\r\n"),
            "info": ("  " + "🤷 " + fg.white, fg.rs),
            "error": ("  " + "💥 " + fg.orange, fg.rs),
            "check": ("  " + "🎉 " + fg.green, fg.rs),
            "phone": ("  " + "📱", ""),
            "debug": ("  " + "🐛", fg.rs),
            "end": ("\r\n", bg.rs),
        }
        before, after = icons.get(channel, "")
        print(f"{bg.black}{before} {message} {after}")

    def update_state(self, fermenter, state):
        Fermenter = Query()
        self.db.upsert({'id': fermenter, 'state': state}, Fermenter.id == fermenter)

    def get_state(self, fermenter):
        Fermenter = Query()
        data = self.db.get(Fermenter.id == fermenter)
        if data:
            return data['state']
        else:
            return STATE.OK # Consider things are okay by default.


    def run(self, fermenters, date, group_time):
        self.log("logo")
        self.log(
            "header", f"Recherche d'anomalies pour les fermenteurs {unpack(fermenters)}"
        )

        msg = ""
        if date != "now":
            msg += f"pour la date {date}, "
        msg += f"par tranches de {group_time} minutes."
        self.log("subheader", msg)

        for fermenter in fermenters:
            try:
                context = self.analyse(
                    fermenter, start_time=date, group_time=group_time
                )
            except Anomaly as e:
                self.send_alert(e)
            else:
                self.log(
                    "check",
                    f"Pas d'anomalies détectées pour {fermenter} (consigne à {context['setpoint']}°C): {unpack_and_round(context['temperatures'])}.",
                )
                self.update_state(fermenter, state=STATE.OK)

        if self.send_test_message:
            self.send_alert(Anomaly(
                ("Ceci est un message de test envoyé par le système "
                 "de supervision de la brasserie")))

        self.log("end")

    def get_temperatures(self, fermenter, start_time, group_time, tries=2):
        if start_time == "now":
            start_time = "now()"

        since = group_time * 3

        query = f"""
        SELECT mean("value") FROM "autogen"."mqtt_consumer_float"
        WHERE ("topic" = 'fermenters/{fermenter}/temperature')
        AND time >=  {start_time} -{since}m
        AND time <= {start_time}
        GROUP BY time({group_time}m) fill(previous)
        """

        response = self._client.query(query=query, database="telegraf")
        if not response:
            if tries:
                return self.get_temperatures(
                    fermenter, start_time, group_time * 2, tries - 1
                )
            else:
                raise Anomaly(STATE.NO_DATA, {"fermenter": fermenter})

        temperatures = [temp for _, temp in response.raw["series"][0]["values"] if temp]
        if not temperatures:
            raise Anomaly(STATE.NO_DATA, {"fermenter": fermenter})
        return temperatures

    def get_setpoint(self, fermenter):
        query = f"""
        SELECT last("value")
        FROM "autogen"."mqtt_consumer_float"
        WHERE ("topic" = 'fermenters/{fermenter}/setpoint')
        """
        response = self._client.query(query=query, database="telegraf")
        return response.raw["series"][0]["values"][0][-1]

    def get_cooling_info(self, fermenter, start_time="now"):
        if start_time == "now":
            start_time = "now()"

        query = f"""
        SELECT last("value")
        FROM "autogen"."mqtt_consumer_float"
        WHERE ("topic" = 'fermenters/{fermenter}/cooling')
        AND time <= {start_time}
        """
        response = self._client.query(query=query, database="telegraf")
        return response.raw["series"][0]["values"][0][1]

    def analyse(self, fermenter, start_time, group_time):
        all_temperatures = self.get_temperatures(fermenter, start_time, group_time)
        # Do the computation on the last 6 values (= last 30mn)
        context = dict(
            fermenter=fermenter,
            temperatures=all_temperatures,
            is_cooling=self.get_cooling_info(fermenter, start_time),
            setpoint=self.get_setpoint(fermenter),
            acceptable_delta=self.max_authorized_delta
        )
        if self.verbose:
            pprint(context)
        self.check_temperature_convergence(**context)
        return context


    def check_temperature_convergence(
        self,
        fermenter,
        temperatures,
        is_cooling,
        setpoint,
        acceptable_delta,
        *args,
        **kwargs
    ):
        # This is where we detect whether problems occurred.
        # We check:
        # - Should the temperature be falling? Rising?
        # - Is it rising or falling? Are we going in the right direction?
        # - If we are going in the wrong direction, at what pace? Is it acceptable?
        # - If we are about to send an alert, filter out false positives:
        #   - delta to setpoint > 0.5°C
        #   -

        # If setpoint < last_temp, the temperature should be decreasing.
        # e.g. setpoint = 0
        # Measured temperatures = 21, 20, 19, 18
        # Then we're OK.
        #
        # But… setpoint = 0
        # Measured temperatures = 6, 7, 8
        # We should raise.
        # So we need to know:
        # 1. Whether we're increasing or decreasing
        # 2. Whether we should be increasing or decreasing.

        last_temp = temperatures[-1]

        should_decrease = setpoint < last_temp
        should_increase = setpoint > last_temp

        inner_delta = temperatures[0] - temperatures[-1]
        absolute_delta = last_temp - setpoint

        is_decreasing = inner_delta > 0
        is_increasing = inner_delta < 0

        if (should_decrease
            and is_increasing
            and is_cooling
            and abs(inner_delta) > acceptable_delta
            and abs(absolute_delta) > acceptable_delta
        ):
            raise Anomaly(
                STATE.TEMP_RISING,
                {
                    "fermenter": fermenter,
                    "temperatures": temperatures,
                    "setpoint": setpoint,
                },
            )
        elif (
            should_increase
            and is_decreasing
            and abs(inner_delta) > acceptable_delta
            and abs(absolute_delta) > acceptable_delta
        ):
            raise Anomaly(
                STATE.TEMP_FALLING,
                {
                    "fermenter": fermenter,
                    "temperatures": temperatures,
                    "setpoint": setpoint,
                },
            )

    def send_alert(self, anomaly):
        context = anomaly.context
        anomaly_type = anomaly.message

        send = True
        message_type = "error"

        if anomaly_type == STATE.TEMP_RISING:
            message = (
                f"""Le fermenteur {context['fermenter']} grimpe en température """
                f"""({unpack_and_round(context['temperatures'])}), alors qu'il est """
                f"""sensé refroidir (consigne à {context['setpoint']}°C)!"""
            )
        elif anomaly_type == STATE.TEMP_FALLING:
            message = (
                f"Attention, le fermenteur {context['fermenter']} descends en temperature "
                f"({unpack_and_round(context['temperatures'])}) alors qu'il est sensé monter"
                f" (consigne à {context['setpoint']}°C)"
            )
        elif anomaly_type == STATE.NO_DATA:
            message = f"Aucune température n'est enregistrée par le fermenteur {context['fermenter']}."
        else:
            message = anomaly_type

        self.log(message_type, message)
        if send and not self.dry_run:
            # Send the message first, then change the state in the database.
            if self.get_state(anomaly.context['fermenter']) == anomaly.message:
                self.log('debug', 'message already sent, not sending it again')
            else:
                self.send_signal_message(message)
                self.update_state(anomaly.context['fermenter'], anomaly.message)

    def send_signal_message(self, message):
        command = f'{self.signal_cli} send -m "{message}" -g {self.signal_group_id}'
        resp = delegator.run(command)
        self.log("debug", command)
        if resp.err:
            self.log("error", resp.err)
        else:
            self.log("phone", f"Message de groupe envoyé à {self.signal_group_id}")
Example #13
class DataLoggerMixin(LoggerMixin):
    def __init__(self, *args, **kwargs):
        """Class constructor
        """

        super().__init__(*args, **kwargs)
        self.datalogger_current = {}
        """ 多线程并发
            可以使一个线程等待其他线程的通知
        """
        self.__dump_event = threading.Event()
        self.__start_event = threading.Event()
        self.__stop_event = threading.Event()
        self.__is_running = threading.Event()
        self.__logger = logging.getLogger('MCS.Datalogger')
        self.__logger.info('Initialize DataLoggerMixin')
        self.__files = {}

    ################
    # Classmethods #
    ################
    """读取默认值
    """

    @classmethod
    def get_defaults(cls):

        return {
            'active': False,
            'index': 0,
            'start_date': None,
            'stop_date': None,
            'stop_count': 1e9,
            'backend': 'sqlite3',
            'data_dir': '.',
            'table_name': 'datalogger',
            'delimiter': ';',
            'clear_on_export': True
        }

    ##############
    # Properties #
    ##############

    @property
    def datalogger_status(self):

        status = self._settings['datalogger'].copy()
        status['start_date'] = None
        status['stop_date'] = None
        status['files'] = [key for key, value in self.__files.items()]
        return status

    @property
    def __delimiter(self) -> str:

        return self._settings['datalogger']['delimiter']

    @__delimiter.setter
    def __delimiter(self, delimiter: str):

        self._settings['datalogger']['delimiter'] = delimiter

    @property
    def __clear_on_export(self) -> bool:

        return self._settings['datalogger']['clear_on_export']

    @__clear_on_export.setter
    def __clear_on_export(self, clear_flag: bool):

        self._settings['datalogger']['clear_on_export'] = clear_flag

    @property
    def __index(self):

        return self._settings['datalogger']['index']

    @__index.setter
    def __index(self, value):

        try:
            value = int(value)

        except ValueError:
            pass

        self._settings['datalogger']['index'] = value

    @property
    def datalogger_files(self) -> list:
        return self.__files

    @property
    def datalogger_backend(self) -> str:

        return self._settings['datalogger']['backend']

    @datalogger_backend.setter
    def datalogger_backend(self, backend_name: str):

        try:
            backend_name = str(backend_name)
        except ValueError:
            backend_name = None

        if backend_name not in ['tinydb', 'sqlite3']:
            self.__logger.error(f'Backend {backend_name} is not supported!')
            return

        self._settings['datalogger']['backend'] = backend_name
        self.__logger.info(f'Use {self.datalogger_backend} for datalogger')

    @property
    def datalogger_active(self):

        return self.__is_running.is_set()

    @property
    def __table_name(self):

        return self._settings['datalogger']['table_name']

    ###########
    # Methods #
    ###########

    def update_settings(self, settings: dict = {}, *args, **kwargs):
        """Update settings
        :param settings: A dictionary with all settings for this instance
        :type settings: dict, optional
        """

        super().update_settings(settings=settings, *args, **kwargs)
        # TODO: Use instance internal settings here

        if 'datalogger' not in self._settings:
            self._settings['datalogger'] = DataLoggerMixin.get_defaults()
            settings = self._settings['datalogger'].copy(
            )  # Trigger the updates
            self.__logger.info('Load default settings for DataLoggerMixin')

        if 'datalogger' not in settings:
            return

        if 'start_date' in settings['datalogger']:
            self.__start_date = settings['datalogger']['start_date']

        if 'stop_date' in settings['datalogger']:
            self.__stop_date = settings['datalogger']['stop_date']

        if 'stop_count' in settings['datalogger']:
            self.__stop_count = settings['datalogger']['stop_count']

    def datalogger_stop(self):
        """Stop the datalogger
        This will set a flag so the main datalogger loop can react accordingly.
        """

        if self.datalogger_active:
            self.__stop_event.set()
            self.__logger.info('Receive STOP command')
        else:
            self.__logger.error('Stop failed: datalogger is not running')

    def datalogger_start(self):
        """Start the datalogger
        This will set a flag so the main datalogger loop can react accordingly.
        """

        if not self.datalogger_active:
            self.__stop_event.clear()
            self.__start_event.set()
            self.__dump_event.clear()
        else:
            self.__logger.error('Start failed: datalogger is already running')

    def datalogger_dump(self,
                        delimiter: str = ';',
                        clear_database: bool = True):
        """Export all data to a CSV file
        This will set a flag so the main datalogger loop can react accordingly.
        """

        if self.datalogger_active:
            self.__dump_event.set()
        else:
            self.__logger.error('Export failed: datalogger is not running')

    def datalogger_run(self):
        """Main datalogger method
        This method should be hooked in a thread and the datalogger controlled
        through the start/stop/dump events.
        """

        self._settings['datalogger']['active'] = self.datalogger_active
        self._settings['datalogger']['data_dir'] = self.data_dir

        if not self.datalogger_active:

            if self.__start_date is not None:
                if time.time() > self.__start_date:
                    self.__start_event.set()

            if self.__start_event.is_set():
                self.__setup()
                self.__start_event.clear()
            else:
                return  # Do nothing if datalogger is not enabled

        # Add values to the database
        self.__add_multiple(self.datalogger_current)

        # stop the datalogger if a maximum count is set and reached
        if self.__stop_count is not None:
            if self.__index > self.__stop_count:
                self.__stop_event.set()
        # stop the datalogger if a maximum time is set and reached
        elif self.__stop_date is not None:
            if time.time() > self.__stop_date:
                self.__stop_event.set()

        # Dismantle datalogger if deactivated
        if self.__stop_event.is_set():
            self.__close()
            self.__stop_event.clear()
        # Perform DB dump if required
        elif self.__dump_event.is_set():
            self.__export_to_csv()
            self.__dump_event.clear()

    ####################
    # Internal methods #
    ####################

    def __setup(self):
        """Setup the database based on the configured backend.
        """

        if self.datalogger_backend == 'sqlite3':
            self.__setup_sqlite3()
        elif self.datalogger_backend == 'tinydb':
            self.__setup_tinydb()

        self.__is_running.set()

        if self.__start_date is None:
            self._settings['datalogger']['start_date'] = time.time()

    def __setup_tinydb(self):

        self.__db = TinyDB('datalogger.json')

    def __setup_sqlite3(self):
        """Create and connect to a memory SQLITE3 database
        """

        self.__db_connection = sqlite3.connect(':memory:')
        self.__logger.info('Connect to SQLITE3 database')
        create_datalogger_table = f'CREATE TABLE IF NOT EXISTS {self.__table_name} (id INTEGER PRIMARY KEY, timestamp REAL NOT NULL, elapsed REAL NOT NULL)'
        cursor = self.__db_connection.cursor()
        cursor.execute(create_datalogger_table)

        for name, value in flatten(self.datalogger_current).items():
            cursor.execute(
                f'ALTER TABLE {self.__table_name} ADD COLUMN {name} TEXT')

        self.__logger.info('Complete SQLITE3 database setup')

    def __add_multiple(self, data: dict = {}):
        """Add multiple values to the database
        :param data: A dictionary with the data to store. The keys correspond to the column names.
        :type data: dict, optional
        """

        # Do nothing if datalogger is disabled
        if not self.datalogger_active:
            self.__logger.error('Add failed: datalogger is not running')
            return

        self.__index += 1
        now = time.time()
        started = self._settings['datalogger']['start_date']
        keys = ['timestamp', 'elapsed']
        values = [f"'{val}'" for val in [now, now - started]]

        for key, value in flatten(data).items():
            key = str(key if key is not None else '')
            value = str(value if value is not None else '')
            keys.append(key)
            values.append("'" + value + "'")

        keys_formatted = ', '.join(keys)
        values_formatted = ', '.join(values)
        cursor = self.__db_connection.cursor()
        sql = f'INSERT INTO {self.__table_name} ({keys_formatted}) VALUES ({values_formatted})'
        cursor.execute(sql)

    def __close(self, export_data: bool = True):
        """Close the database connection
        :param export_data: Export all data to CSV before clearing and closing the database. Default is ``True``
        :type export_data: bool, optional.
        """

        if not self.datalogger_active:
            self.__logger.error('Close failed: datalogger is not running')
            return  # Do nothing if datalogger is disabled

        if self.datalogger_backend == 'tinydb':
            self.__db.truncate()
            self.__db.close()
        elif self.datalogger_backend == 'sqlite3':
            if export_data:
                self.__export_to_csv()
            cursor = self.__db_connection.cursor()
            cursor.execute(f'DROP TABLE {self.__table_name}')
            self.__logger.info(f'Drop SQLITE3 table "{self.__table_name}"')
            self.__db_connection.close()
            self.__logger.info('Close SQLITE3 database connection')

        self.__is_running.clear()
        self.__start_date = None
        self.__stop_date = None
        self.__stop_count = None

    def __export_to_csv(self,
                        delimiter: str = ";",
                        clear_database: bool = True):
        """Export all recorded data to a csv file
        :param delimiter: Delimiter for the csv file. Default is ``;``
        :type delimiter: str, optional
        :param clear_database: Delete all entries in the database after export. Default is ``True``
        :type clear_database: bool, optional
        The csv file holds all recorded values since the start or the last export of the datalogger.
        Files are stored in ``.__files`` as ``DataLoggerFile`` object.
        """

        if not self.datalogger_active:
            self.__logger.error('Export failed: datalogger is not running')
            return

        filename = f'{self.__table_name}_{int(time.time())}.csv'
        file = os.path.join(self.data_dir, filename)
        self.__logger.info(f'Write {self.__table_name} to {filename}')

        if self.datalogger_backend == 'sqlite3':
            cursor = self.__db_connection.cursor()
            cursor.execute(f'SELECT * from {self.__table_name}')
            new_file = DataLoggerFile(name=filename, records=self.__index)

            with io.StringIO() as csv_stream:
                csv_writer = csv.writer(csv_stream, delimiter=delimiter)
                csv_writer.writerow([i[0] for i in cursor.description])
                csv_writer.writerows(cursor)
                new_file.data = csv_stream.getvalue()

            self.__files[new_file.name] = new_file

            if clear_database:
                self.__dump_event.clear()
                cursor.execute(f'DELETE from {self.__table_name}')
                self.__logger.info(
                    f'Clear SQLITE3 table "{self.__table_name}"')
                self.__index = 0

    @property
    def __start_date(self):

        return self._settings['datalogger']['start_date']

    @__start_date.setter
    def __start_date(self, value):

        self.__set_start_date(value)

    def __set_start_date(self, new_date=None):
        """Start the datalogger with a specific time delay
        """

        if new_date == self.__start_date:
            return  # Nothing to do

        if new_date is None:
            self._settings['datalogger']['start_date'] = None
            self.__logger.info('Deactivate automatic start time for datalogger')
            return True

        if new_date < time.time():
            self.__logger.error(
                'Start time for datalogger must be in the future')
            return False

        if self.__stop_date is not None:
            if self.__stop_date < new_date:
                self.__logger.error(
                    'Start time for the datalogger must be before stop time')
                return False

        self._settings['datalogger']['start_date'] = new_date

        return True

    @property
    def __stop_date(self):

        return self._settings['datalogger']['stop_date']

    @__stop_date.setter
    def __stop_date(self, value):

        self.__set_stop_date(value)

    def __set_stop_date(self, new_date=None):
        """Limit the duration the datalogger should store values
        """

        if new_date == self.__stop_date:
            return  # Nothing to do

        if new_date is None:
            self._settings['datalogger']['stop_date'] = None
            self.__logger.info('Deactivate time limit for datalogger')
            return True

        if new_date < time.time():
            self.__logger.error(
                'Time limit for datalogger must be in the future')
            return False

        if self.__start_date is not None:
            if self.__start_date > new_date:
                self.__logger.error(
                    'Stop time for the datalogger must be after start time')
                return False

        self._settings['datalogger']['stop_date'] = new_date

        return True

    @property
    def __stop_count(self) -> int:

        return self._settings['datalogger']['stop_count']

    @__stop_count.setter
    def __stop_count(self, value: int):

        self.__set_stop_count(value)

    def __set_stop_count(self, new_count: int):
        """Limit the number of entries to be recorded in the datalogger
        :param new_count: The number of entries after which the datalogger should stop
        :type new_count: int
        """

        if new_count == self.__stop_count:
            return  # Nothing to do

        if new_count is None:
            self._settings['datalogger']['stop_count'] = None
            self.__logger.info('Deactivate entry limit for datalogger')
            return True

        try:
            new_count = int(new_count)
        except ValueError:
            self.__logger.info('Entry limit for datalogger must be an integer')
            return False

        if new_count <= self.__index:
            self.__logger.error(
                f'Entry limit for datalogger must be larger than the current amount of entries ({self.__index})'
            )
            return False

        self._settings['datalogger']['stop_count'] = new_count

    def datalogger_collect(self):

        pass

    print("Hello World")
Example #14
class Cache:
    def __init__(
        self,
        central=None,
        data: Union[List[dict], dict] = None,
        refresh: bool = False,
    ) -> None:
        self.updated: list = []
        self.central = central
        self.DevDB = TinyDB(config.cache_file)
        self.SiteDB = self.DevDB.table("sites")
        self.GroupDB = self.DevDB.table("groups")
        self.TemplateDB = self.DevDB.table("templates")
        self._tables = [self.DevDB, self.SiteDB, self.GroupDB, self.TemplateDB]
        self.Q = Query()
        if data:
            self.insert(data)
        if central:
            self.check_fresh(refresh)

    def __call__(self, refresh=False) -> None:
        if refresh:
            self.check_fresh(refresh)

    def __iter__(self) -> list:
        for db in self._tables:
            yield db.name, db.all()

    @property
    def devices(self) -> list:
        return self.DevDB.all()

    @property
    def sites(self) -> list:
        return self.SiteDB.all()

    @property
    def groups(self) -> list:
        return self.GroupDB.all()

    @property
    def group_names(self) -> list:
        return [g["name"] for g in self.GroupDB.all()]

    @property
    def templates(self) -> list:
        return self.TemplateDB.all()

    @property
    def all(self) -> dict:
        return {t.name: getattr(self, t.name) for t in self._tables}

    # TODO ??deprecated?? should be able to remove this method. don't remember this note. looks used
    def insert(
        self,
        data: Union[List[dict], dict],
    ) -> bool:
        _data = data
        if isinstance(data, list) and data:
            _data = data[1]

        table = self.DevDB
        if "zipcode" in _data.keys():
            table = self.SiteDB

        data = data if isinstance(data, list) else [data]
        ret = table.insert_multiple(data)

        return len(ret) == len(data)

    async def update_dev_db(self):
        resp = await self.central.get_all_devicesv2()
        if resp.ok:
            resp.output = utils.listify(resp.output)
            self.updated.append(self.central.get_all_devicesv2)
            self.DevDB.truncate()
            return self.DevDB.insert_multiple(resp.output)

    async def update_site_db(self):
        resp = await self.central.get_all_sites()
        if resp.ok:
            resp.output = utils.listify(resp.output)
            # TODO time this to see which is more efficient
            # start = time.time()
            # upd = [self.SiteDB.upsert(site, cond=self.Q.id == site.get("id")) for site in site_resp.output]
            # upd = [item for in_list in upd for item in in_list]
            self.updated.append(self.central.get_all_sites)
            self.SiteDB.truncate()
            # print(f" site db Done: {time.time() - start}")
            return self.SiteDB.insert_multiple(resp.output)

    async def update_group_db(self):
        resp = await self.central.get_all_groups()
        if resp.ok:
            resp.output = utils.listify(resp.output)
            self.updated.append(self.central.get_all_groups)
            self.GroupDB.truncate()
            return self.GroupDB.insert_multiple(resp.output)

    async def update_template_db(self):
        groups = self.groups if self.central.get_all_groups in self.updated else None
        resp = await self.central.get_all_templates(groups=groups)
        if resp.ok:
            resp.output = utils.listify(resp.output)
            self.updated.append(self.central.get_all_templates)
            self.TemplateDB.truncate()
            return self.TemplateDB.insert_multiple(resp.output)

    async def _check_fresh(self,
                           dev_db: bool = False,
                           site_db: bool = False,
                           template_db: bool = False,
                           group_db: bool = False):
        update_funcs = []
        if dev_db:
            update_funcs += [self.update_dev_db]
        if site_db:
            update_funcs += [self.update_site_db]
        if template_db:
            update_funcs += [self.update_template_db]
        if group_db:
            update_funcs += [self.update_group_db]
        async with ClientSession() as self.central.aio_session:
            if update_funcs:
                if await update_funcs[0]():
                    if len(update_funcs) > 1:
                        await asyncio.gather(*[f() for f in update_funcs[1:]])

            # update groups first so template update can use the result, and to trigger token_refresh if necessary
            elif await self.update_group_db():
                await asyncio.gather(self.update_dev_db(),
                                     self.update_site_db(),
                                     self.update_template_db())

    def check_fresh(
        self,
        refresh: bool = False,
        site_db: bool = False,
        dev_db: bool = False,
        template_db: bool = False,
        group_db: bool = False,
    ) -> None:
        if True in [site_db, dev_db, group_db, template_db]:
            refresh = True

        if refresh or not config.cache_file.is_file(
        ) or not config.cache_file.stat().st_size > 0:
            #  or time.time() - config.cache_file.stat().st_mtime > 7200:
            start = time.time()
            print(typer.style("-- Refreshing Identifier mapping Cache --",
                              fg="cyan"),
                  end="")
            db_res = asyncio.run(
                self._check_fresh(dev_db=dev_db,
                                  site_db=site_db,
                                  template_db=template_db,
                                  group_db=group_db))
            if db_res and False in db_res:
                log.error("TinyDB returned an error during db update")

            log.info(
                f"Cache Refreshed in {round(time.time() - start, 2)} seconds")
            typer.secho(
                f"-- Cache Refresh Completed in {round(time.time() - start, 2)} sec --",
                fg="cyan")

    def handle_multi_match(
        self,
        match: list,
        query_str: str = None,
        query_type: str = "device",
        multi_ok: bool = False,
    ) -> List[Dict[str, Any]]:
        # typer.secho(f" -- Ambiguos identifier provided.  Please select desired {query_type}. --\n", color="cyan")
        typer.echo()
        if query_type == "site":
            fields = ("name", "city", "state", "type")
        elif query_type == "template":
            fields = ("name", "group", "model", "device_type", "version")
        else:  # device
            fields = ("name", "serial", "mac", "type")
        out = utils.output(
            [{k: d[k] for k in d if k in fields} for d in match],
            title="Ambiguous identifier. Select desired device.")
        menu = out.menu(data_len=len(match))

        if query_str:
            menu = menu.replace(query_str, typer.style(query_str, fg="green"))
            menu = menu.replace(query_str.upper(),
                                typer.style(query_str.upper(), fg="green"))
        typer.echo(menu)
        selection = ""
        valid = [str(idx + 1) for idx, _ in enumerate(match)]
        try:
            while selection not in valid:
                selection = typer.prompt(f"Select {query_type.title()}")
                if not selection or selection not in valid:
                    typer.secho(f"Invalid selection {selection}, try again.")
        except KeyboardInterrupt:
            raise typer.Abort()

        return [match.pop(int(selection) - 1)]

    def get_identifier(
        self,
        qry_str: str,
        qry_funcs: tuple,
        device_type: str = None,
        group: str = None,
        multi_ok: bool = False,
    ) -> CentralObject:
        match = None
        default_kwargs = {"retry": False}
        for _ in range(0, 2):
            for q in qry_funcs:
                kwargs = default_kwargs.copy()
                if q == "dev":
                    kwargs["dev_type"] = device_type
                elif q == "template":
                    kwargs["group"] = group
                match: CentralObject = getattr(self,
                                               f"get_{q}_identifier")(qry_str,
                                                                      **kwargs)

                if match:
                    return match

            # No match found trigger refresh and try again.
            if not match:
                self.check_fresh(
                    dev_db=True if "dev" in qry_funcs else False,
                    site_db=True if "site" in qry_funcs else False,
                    template_db=True if "template" in qry_funcs else False,
                    group_db=True if "group" in qry_funcs else False,
                )

        if not match:
            typer.secho(
                f"Unable to find a matching identifier for {qry_str}, tried: {qry_funcs}",
                fg="red")
            raise typer.Exit(1)

    def get_dev_identifier(
        self,
        query_str: Union[str, List[str], tuple],
        dev_type: str = None,
        ret_field: str = "serial",
        retry: bool = True,
        multi_ok: bool = True,
    ) -> CentralObject:

        # TODO dev_type currently not passed in or handled identifier for show switches would also
        # try to match APs ...  & (self.Q.type == dev_type)
        # TODO refactor to single test function usable by all identifier methods 1 search with a more involved test
        if isinstance(query_str, (list, tuple)):
            query_str = " ".join(query_str)

        match = None
        for _ in range(0, 2 if retry else 1):
            # Try exact match
            match = self.DevDB.search(
                (self.Q.name == query_str)
                | (self.Q.ip.test(lambda v: v.split("/")[0] == query_str))
                | (self.Q.mac == utils.Mac(query_str).cols)
                | (self.Q.serial == query_str))

            # retry with case insensitive name match if no match with original query
            if not match:
                match = self.DevDB.search(
                    (self.Q.name.test(lambda v: v.lower() == query_str.lower())
                     )
                    | self.Q.mac.test(lambda v: v.lower() == utils.Mac(
                        query_str).cols.lower())
                    | self.Q.serial.test(
                        lambda v: v.lower() == query_str.lower()))

            # retry name match swapping - for _ and _ for -
            if not match:
                if "-" in query_str:
                    match = self.DevDB.search(
                        self.Q.name.test(lambda v: v.lower() == query_str.
                                         lower().replace("-", "_")))
                elif "_" in query_str:
                    match = self.DevDB.search(
                        self.Q.name.test(lambda v: v.lower() == query_str.
                                         lower().replace("_", "-")))

            # Last Chance try to match name if it startswith provided value
            if not match:
                match = self.DevDB.search(
                    self.Q.name.test(
                        lambda v: v.lower().startswith(query_str.lower()))
                    | self.Q.serial.test(
                        lambda v: v.lower().startswith(query_str.lower()))
                    | self.Q.mac.test(lambda v: v.lower().startswith(
                        utils.Mac(query_str).cols.lower())))

            if retry and not match and self.central.get_all_devicesv2 not in self.updated:
                typer.secho(
                    f"No Match Found for {query_str}, Updating Device Cachce",
                    fg="red")
                self.check_fresh(refresh=True, dev_db=True)
            if match:
                break

        all_match = None
        if dev_type:
            all_match = match
            match = [
                d for d in match if d["type"].lower() in "".join(
                    dev_type[0:len(d["type"])]).lower()
            ]

        if match:
            if len(match) > 1:
                match = self.handle_multi_match(match,
                                                query_str=query_str,
                                                multi_ok=multi_ok)

            return CentralObject("dev", match)
        elif retry:
            log.error(
                f"Unable to gather device {ret_field} from provided identifier {query_str}",
                show=True)
            if all_match:
                all_match = all_match[-1]
                log.error(
                    f"The Following device matched {all_match.get('name')} excluded as {all_match.get('type')} != {dev_type}",
                    show=True,
                )
            raise typer.Abort()
        # else:
        #     log.error(f"Unable to gather device {ret_field} from provided identifier {query_str}", show=True)

    def get_site_identifier(
        self,
        query_str: Union[str, List[str], tuple],
        ret_field: str = "id",
        retry: bool = True,
        multi_ok: bool = False,
    ) -> CentralObject:
        if isinstance(query_str, (list, tuple)):
            query_str = " ".join(query_str)

        match = None
        for _ in range(0, 2 if retry else 1):
            # try exact site match
            match = self.SiteDB.search(
                (self.Q.name == query_str)
                | (self.Q.id.test(lambda v: str(v) == query_str))
                | (self.Q.zipcode == query_str)
                | (self.Q.address == query_str)
                | (self.Q.city == query_str)
                | (self.Q.state == query_str))

            # retry with case insensitive name & address match if no match with original query
            if not match:
                match = self.SiteDB.search(
                    (self.Q.name.test(lambda v: v.lower() == query_str.lower())
                     )
                    | self.Q.address.test(lambda v: v.lower().replace(
                        " ", "") == query_str.lower().replace(" ", "")))

            # retry name match swapping - for _ and _ for -
            if not match:
                if "-" in query_str:
                    match = self.SiteDB.search(
                        self.Q.name.test(lambda v: v.lower() == query_str.
                                         lower().replace("-", "_")))
                elif "_" in query_str:
                    match = self.SiteDB.search(
                        self.Q.name.test(lambda v: v.lower() == query_str.
                                         lower().replace("_", "-")))

            # Last Chance try to match name if it startswith provided value
            if not match:
                match = self.SiteDB.search(
                    self.Q.name.test(
                        lambda v: v.lower().startswith(query_str.lower())))

            if retry and not match and self.central.get_all_sites not in self.updated:
                typer.secho(
                    f"No Match Found for {query_str}, Updating Site Cachce",
                    fg="red")
                self.check_fresh(refresh=True, site_db=True)
            if match:
                break

        if match:
            if len(match) > 1:
                match = self.handle_multi_match(match,
                                                query_str=query_str,
                                                query_type="site",
                                                multi_ok=multi_ok)

            # return match[0].get(ret_field)
            return CentralObject("site", match)

        elif retry:
            log.error(
                f"Unable to gather site {ret_field} from provided identifier {query_str}",
                show=True)
            raise typer.Abort()

    def get_group_identifier(
        self,
        query_str: str,
        ret_field: str = "name",
        retry: bool = True,
        multi_ok: bool = False,
    ) -> CentralObject:
        """Allows Case insensitive group match"""
        for _ in range(0, 2):
            match = self.GroupDB.search(
                (self.Q.name == query_str)
                | self.Q.name.test(lambda v: v.lower() == query_str.lower()))
            if retry and not match and self.central.get_all_groups not in self.updated:
                typer.secho(
                    f"No Match Found for {query_str}, Updating group Cachce",
                    fg="red")
                self.check_fresh(refresh=True, group_db=True)
            if match:
                break

        if match:
            if len(match) > 1:
                match = self.handle_multi_match(match,
                                                query_str=query_str,
                                                query_type="group",
                                                multi_ok=multi_ok)

            return CentralObject("group", match)
        elif retry:
            log.error(
                f"Unable to gather group {ret_field} from provided identifier {query_str}",
                show=True)
            valid_groups = "\n".join(self.group_names)
            typer.secho(f"{query_str} appears to be invalid", fg="red")
            typer.secho(f"Valid Groups:\n--\n{valid_groups}\n--\n", fg="cyan")
            raise typer.Abort()
        else:
            log.error(
                f"Unable to gather template {ret_field} from provided identifier {query_str}",
                show=True)

    def get_template_identifier(
        self,
        query_str: str,
        ret_field: str = "name",
        group: str = None,
        retry: bool = True,
        multi_ok: bool = False,
    ) -> CentralObject:
        """Allows case insensitive template match by template name"""
        match = None
        for _ in range(0, 2 if retry else 1):
            match = self.TemplateDB.search(
                (self.Q.name == query_str)
                | self.Q.name.test(lambda v: v.lower() == query_str.lower()))

            if not match:
                match = self.TemplateDB.search(
                    self.Q.name.test(lambda v: v.lower() == query_str.lower().
                                     replace("_", "-")))

            if not match:
                match = self.TemplateDB.search(
                    self.Q.name.test(
                        lambda v: v.lower().startswith(query_str.lower())))

            if retry and not match and self.central.get_all_templates not in self.updated:
                typer.secho(
                    f"No Match Found for {query_str}, Updating template Cachce",
                    fg="red")
                self.check_fresh(refresh=True, template_db=True)
            if match:
                break

        if match:
            if len(match) > 1:
                if group:
                    match = [{k: d[k]
                              for k in d} for d in match
                             if d["group"].lower() == group.lower()]

            if len(match) > 1:
                match = self.handle_multi_match(
                    match,
                    query_str=query_str,
                    query_type="template",
                    multi_ok=multi_ok,
                )

            return CentralObject("template", match)

        elif retry:
            log.error(
                f"Unable to gather template {ret_field} from provided identifier {query_str}",
                show=True)
            raise typer.Abort()
        else:
            log.warning(
                f"Unable to gather template {ret_field} from provided identifier {query_str}",
                show=False)
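# --- Added sketch (not part of the original snippet) -----------------------
# A minimal, self-contained illustration of the case-insensitive TinyDB
# lookup pattern that get_group_identifier() above relies on; the table name
# and the sample document are invented for this sketch.
from tinydb import TinyDB, Query
from tinydb.storages import MemoryStorage

_db = TinyDB(storage=MemoryStorage)
GroupDB = _db.table("groups")
GroupDB.insert({"name": "Branch-AP"})

Q = Query()
query_str = "branch-ap"
match = GroupDB.search(
    (Q.name == query_str)
    | Q.name.test(lambda v: v.lower() == query_str.lower()))
print(match)  # [{'name': 'Branch-AP'}]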
Example #15
0
class DataConverter:
    def __init__(self, m_cfg, c_cfg, options):
        self.m_cfg = m_cfg
        self.c_cfg = c_cfg
        self.opts = options

        if len(m_cfg['fields']) != len(c_cfg['fields']):
            raise ValueError("m_cfg and c_cfg must define the same number of fields")

        if self.opts.get('has_datestamps', True):
            datestamps = self.opts.get('datestamps',
                                       ['created', 'updated', 'deleted'])
            m_dstamps = []
            c_dstamps = []
            for ds_name in datestamps:
                m_dstamps.append(F"{ds_name}_date")
                c_dstamps.append(F"{ds_name}_at")

            self.m_cfg['fields'].extend(m_dstamps)
            self.c_cfg['fields'].extend(c_dstamps)

        # Connect to MySQL DB
        self.metiisto = mysql.connector.connect(
            host="localhost",
            user="******",
            passwd="",
            # database="metiisto_devel"
            database="metiisto_prod")

        # Connect to / Create TinyDB
        env = os.getenv('CARTARO_ENV', 'dev')
        db_sfx = '' if env == "prod" else F"-{env}"
        self.cartaro = TinyDB(
            F"{options.get('out_dir', '.')}/{self.c_cfg['name']}{db_sfx}.json")
        if not self.opts.get('preserve_data', False):
            self.cartaro.truncate()

    def convert(self):
        print(
            F"--- Converting Metiisto/{self.m_cfg['name']} to Cartaro/{self.c_cfg['name']} ---"
        )
        obj_cursor = self.metiisto.cursor()

        # Fetch Metiisto data
        fld_str = ",".join(self.m_cfg['fields'])
        obj_sql = F"select id,{fld_str} from {self.m_cfg['name']} order by id"
        if self.opts.get('limit', None):
            obj_sql += F" limit {self.opts['limit']}"
        print(F"     => {obj_sql}")

        tags = {}
        if self.opts.get("has_tags", False):
            tag_class = self.opts['tag_class']
            tag_sql = F"select tagged_object.obj_id, tags.name from tags, tagged_object where obj_class='{tag_class}' and tags.id = tagged_object.tag_id order by obj_id"
            print(F"     => {tag_sql}")

            tag_cursor = self.metiisto.cursor()
            tag_cursor.execute(tag_sql)
            for row in tag_cursor:
                note_id = row[0]
                if note_id not in tags:
                    tags[note_id] = []

                tags[note_id].append(Tag.normalize(row[1]))

        count = 0
        obj_cursor.execute(obj_sql)
        for row in obj_cursor:
            record = {}
            obj_id = row[0]

            if self.opts.get('interactive', False):
                print(
                    "-------------------------------------------------------")
                pprint.pprint(row, indent=2)
                print(
                    "-------------------------------------------------------")

            # row[0] == DB `id`, skip for mapping
            mapping = zip(self.c_cfg['fields'], row[1:])
            for (name, raw_value) in mapping:
                value = raw_value
                if raw_value and name.endswith('_at'):
                    date = arrow.get(raw_value, Base.TIMEZONE)
                    value = date.timestamp
                elif name.startswith('is_'):
                    value = True if raw_value else False

                transformer = self.opts.get(F"{name}_transformer", None)
                if transformer:
                    value = transformer(
                        value,
                        record=record,
                        interactive=self.opts.get('interactive'))

                # print(F"{name} == {value} - {type(value)}")
                record[name] = value

            if self.opts.get("has_tags", False):
                record['tags'] = tags.get(obj_id, [])

            for tag in self.opts.get("additional_tags", []):
                if 'tags' not in record:
                    record['tags'] = []
                record['tags'].append(Tag.normalize(tag))

            # Metiisto side does not have timestamps; add `created_at` to the Cartaro data
            if not self.opts.get('has_datestamps', True):
                record['created_at'] = record[
                    'created_at'] if 'created_at' in record else arrow.now(
                        Base.TIMEZONE).timestamp
                record['updated_at'] = None
                record['deleted_at'] = None

            self.cartaro.insert(record)
            count += 1
            print(F"{self.m_cfg['name']} - {count}", end="\r")

        print(F"     => Converted {count} records.")
Example #16
0
    print("Filter: Bilateral Filtering")
elif args['medianFiltering']:
    print("Filter: Median Filtering")
elif args['gaussianFiltering']:
    print("Filter: Gaussian Filtering")
else:
    print("Filter: Default")

#Load the classifier
face_cascade = cv2.CascadeClassifier('haarcascade_frontalface_default.xml')

#Connect with the database
db = TinyDB(args['json'])

#Clear the database
db.truncate()

#Load the video
if args['video'] == 0 or args['video'] == '0':
    args['video'] = 0
cap = cv2.VideoCapture(args['video'])

#For each frame apply classification
while True:
    _, frame = cap.read()

    #quit if there are no frames
    if frame is None:
        break

    # resize the frame
Example #17
0
def mock_init_catalog(dummyself):
    test_db_filepath = './tests/resources/test.json'
    test_catalog = TinyDB(test_db_filepath)
    test_catalog.truncate()
    return test_catalog
Example #18
0
)
logger = logging.getLogger()
fileHandler = logging.FileHandler("{0}/{1}.log".format("log", "agenceur"))
fileHandler.setFormatter(logFormatter)
logger.addHandler(fileHandler)
consoleHandler = logging.StreamHandler(sys.stdout)
consoleHandler.setFormatter(logFormatter)
logger.addHandler(consoleHandler)
logger.setLevel(logging.INFO)

# init opaque DB
db_opaque = TinyDB("opaque.json")

# init clear measures DB
db_measures = TinyDB("measures.json")
db_measures.truncate()
lock = threading.Lock()  # on received message


def on_message(client, userdata, message):
    with lock:
        logger.debug(
            "rcvd: " + message.topic + "/" + str(message.payload.decode("utf-8"))
        )

        if message.topic == "addToPool":
            # store in DB
            logger.info("storing payload")
            db_opaque.insert({"entry": str(message.payload.decode("utf-8"))})

        if message.topic == "requestPool":
import requests
import a_a_parameters_connection as pc
import c_a_parameters_bible as pb
import json
import os
from tinydb import TinyDB, Query

db_chapters = TinyDB(pb.BIBLE_ID + '/d_c_db_' + pb.BIBLE_ID + '_chapters.json')
db_chapters.truncate()

db_books = TinyDB(pb.BIBLE_ID + '/c_c_db_' + pb.BIBLE_ID + '_books.json')
headers = {'api-key': pc.API_KEY}
for book in db_books:
    url = pc.BASE_API_URL + '/v1/bibles/' + pb.BIBLE_ID + '/books/' + book[
        'id'] + '/chapters'
    r = requests.get(url, headers=headers)
    chapters_dict = json.loads(r.text)
    for chapter in chapters_dict['data']:
        db_chapters.insert(chapter)
    print(book['id'] + ':' + str(len(chapters_dict['data'])))

print(len(db_chapters))
Example #20
0
def db():
    db = TinyDB("db.json")
    print("\n Setup DB")
    yield db
    db.truncate()
    print("\n Teardown DB")
Example #21
0
    email = input("What is the email?\n")
    author = input("Who is the author?\n")

    contacts.insert({"email": email, "author": author})
    print(f"{email} is saved at contacts.json")
elif response.startswith("r"):
    author = input("Who is the author of the email you want to find?\n")

    print(contacts.search(Contact.author == author))
elif response.startswith("u"):
    author = input("Who is the author of the email you want to update?\n")
    email = input("What is the new email?\n")

    contacts.update({"email": email}, Contact.author == author)
    print(f"The new email {email} is saved for the {author} at contacts.json")
elif response.startswith("d"):
    email = input("Which email you want to delete?\n")

    contacts.remove(Contact.email == email)
elif response.startswith("e"):
    contacts.truncate()
else:
    index = 1
    for contact in contacts:
        email = contact["email"]
        author = contact["author"]

        print(f"{index}. {email} ({author})\n")

        index = index + 1
Example #22
0
def buffer_db(in_memory, filename=None):
    """
    buffers the complete SQL database into a TinyDB object (either in memory or into a local JSON file)

    Parameters
    ----------
    in_memory: bool
        if True: the mysql database will be buffered as a TinyDB object that only exists in memory
        if False: the mysql database will be buffered as a TinyDB object and saved in a local json file
    filename: string
        only relevant if `in_memory = False`: the filename of the json file of the TinyDB object
    """
    serialization = SerializationMiddleware()
    serialization.register_serializer(DateTimeSerializer(), 'TinyDate')
    logger.info("buffering SQL database on-the-fly")
    if in_memory:
        db = TinyDB(storage=MemoryStorage)
    else:
        db = TinyDB(filename,
                    storage=serialization,
                    sort_keys=True,
                    indent=4,
                    separators=(',', ': '))
    db.truncate()

    from NuRadioReco.detector import detector_sql
    sqldet = detector_sql.Detector()
    results = sqldet.get_everything_stations()
    table_stations = db.table('stations')
    table_stations.truncate()
    for result in results:
        table_stations.insert({
            'station_id':
            result['st.station_id'],
            'commission_time':
            result['st.commission_time'],
            'decommission_time':
            result['st.decommission_time'],
            'station_type':
            result['st.station_type'],
            'position':
            result['st.position'],
            'board_number':
            result['st.board_number'],
            'MAC_address':
            result['st.MAC_address'],
            'MBED_type':
            result['st.MBED_type'],
            'pos_position':
            result['pos.position'],
            'pos_measurement_time':
            result['pos.measurement_time'],
            'pos_easting':
            result['pos.easting'],
            'pos_northing':
            result['pos.northing'],
            'pos_altitude':
            result['pos.altitude'],
            'pos_zone':
            result['pos.zone'],
            'pos_site':
            result['pos.site']
        })

    table_channels = db.table('channels')
    table_channels.truncate()
    results = sqldet.get_everything_channels()
    for channel in results:
        table_channels.insert({
            'station_id':
            channel['st.station_id'],
            'channel_id':
            channel['ch.channel_id'],
            'commission_time':
            channel['ch.commission_time'],
            'decommission_time':
            channel['ch.decommission_time'],
            'ant_type':
            channel['ant.antenna_type'],
            'ant_orientation_phi':
            channel['ant.orientation_phi'],
            'ant_orientation_theta':
            channel['ant.orientation_theta'],
            'ant_rotation_phi':
            channel['ant.rotation_phi'],
            'ant_rotation_theta':
            channel['ant.rotation_theta'],
            'ant_position_x':
            channel['ant.position_x'],
            'ant_position_y':
            channel['ant.position_y'],
            'ant_position_z':
            channel['ant.position_z'],
            'ant_deployment_time':
            channel['ant.deployment_time'],
            'ant_comment':
            channel['ant.comment'],
            'cab_length':
            channel['cab.cable_length'],
            'cab_reference_measurement':
            channel['cab.reference_measurement'],
            'cab_time_delay':
            channel['cab.time_delay'],
            'cab_id':
            channel['cab.cable_id'],
            'cab_type':
            channel['cab.cable_type'],
            'amp_type':
            channel['amps.amp_type'],
            'amp_reference_measurement':
            channel['amps.reference_measurement'],
            'adc_id':
            channel['adcs.adc_id'],
            'adc_time_delay':
            channel['adcs.time_delay'],
            'adc_nbits':
            channel['adcs.nbits'],
            'adc_n_samples':
            channel['adcs.n_samples'],
            'adc_sampling_frequency':
            channel['adcs.sampling_frequency']
        })

    results = sqldet.get_everything_positions()
    table_positions = db.table('positions')
    table_positions.truncate()
    for result in results:
        table_positions.insert({
            'pos_position':
            result['pos.position'],
            'pos_measurement_time':
            result['pos.measurement_time'],
            'pos_easting':
            result['pos.easting'],
            'pos_northing':
            result['pos.northing'],
            'pos_altitude':
            result['pos.altitude'],
            'pos_zone':
            result['pos.zone'],
            'pos_site':
            result['pos.site']
        })

    logger.info("sql database buffered")
    return db
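# --- Added sketch (assumption, not part of NuRadioReco) ---------------------
# The two storage back-ends buffer_db() switches between: an in-memory TinyDB
# versus a JSON file wrapped in SerializationMiddleware so datetime values
# survive the round trip. DateTimeSerializer is assumed to come from
# tinydb-serialization >= 2.0; the original module imports its own.
from datetime import datetime
from tinydb import TinyDB
from tinydb.storages import MemoryStorage
from tinydb_serialization import SerializationMiddleware
from tinydb_serialization.serializers import DateTimeSerializer

mem_db = TinyDB(storage=MemoryStorage)  # in_memory=True branch

serialization = SerializationMiddleware()
serialization.register_serializer(DateTimeSerializer(), 'TinyDate')
file_db = TinyDB('buffer.json', storage=serialization,
                 sort_keys=True, indent=4, separators=(',', ': '))
file_db.table('stations').insert({'station_id': 1,
                                  'commission_time': datetime(2018, 1, 1)})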
Example #23
0
class Detector(object):
    """
    main detector class which provides access to the detector description

    This class provides functions for all relevant detector properties.
    """
    def __init__(self,
                 source='json',
                 json_filename='ARIANNA/arianna_detector_db.json',
                 dictionary=None,
                 assume_inf=True,
                 antenna_by_depth=True):
        """
        Initialize the stations detector properties.
        By default, a new detector instance is only created if none exists yet, otherwise the existing instance
        is returned. To force the creation of a new detector instance, pass the additional keyword parameter
        `create_new=True` to this function. For more details, check the documentation for the
        `Singleton metaclass <NuRadioReco.utilities.html#NuRadioReco.utilities.metaclasses.Singleton>`_.

        Parameters
        ----------
        source : str
            'json', 'dictionary' or 'sql'
            default value is 'json'
            if dictionary is specified, the dictionary passed to __init__ is used
            if 'sql' is specified, the 'detector_sql_auth.json' file needs to be present in this folder; it
            specifies the sql server credentials (see 'detector_sql_auth.json.sample' for an example of the syntax)
        json_filename : str
            the path to the json detector description file (it first checks a path relative to this directory, then a
            path relative to the current working directory of the user)
            default value is 'ARIANNA/arianna_detector_db.json'
        assume_inf : Bool
            Defaults to True; if True, forces antenna models to have infinite boundary conditions, otherwise the antenna model will be determined by the station geometry.
        antenna_by_depth: bool (default True)
            if True the antenna model is determined automatically depending on the depth of the antenna. This is done by
            appending e.g. '_InfFirn' to the antenna model name.
            if False, the antenna model as specified in the database is used.
        create_new: bool (default:False)
            Can be used to force the creation of a new detector object. By default, the __init__ will only create a new
            object if none already exists.
        """
        self._serialization = SerializationMiddleware()
        self._serialization.register_serializer(DateTimeSerializer(),
                                                'TinyDate')
        if source == 'sql':
            self._db = buffer_db(in_memory=True)
        elif source == 'dictionary':
            self._db = TinyDB(storage=MemoryStorage)
            self._db.truncate()
            stations_table = self._db.table('stations', cache_size=1000)
            for station in dictionary['stations'].values():
                stations_table.insert(station)
            channels_table = self._db.table('channels', cache_size=1000)
            for channel in dictionary['channels'].values():
                channels_table.insert(channel)
        else:
            dir_path = os.path.dirname(
                os.path.realpath(__file__))  # get the directory of this file
            filename = os.path.join(dir_path, json_filename)
            if not os.path.exists(filename):
                # try local folder instead
                filename2 = json_filename
                if not os.path.exists(filename2):
                    logger.error(
                        "can't locate json database file {} or {}".format(
                            filename, filename2))
                    raise NameError
                filename = filename2
            logger.warning("loading detector description from {}".format(
                os.path.abspath(filename)))
            self._db = TinyDB(filename,
                              storage=self._serialization,
                              sort_keys=True,
                              indent=4,
                              separators=(',', ': '))

        self._stations = self._db.table('stations', cache_size=1000)
        self._channels = self._db.table('channels', cache_size=1000)
        self.__positions = self._db.table('positions', cache_size=1000)

        logger.info("database initialized")

        self._buffered_stations = {}
        self.__buffered_positions = {}
        self._buffered_channels = {}
        self.__valid_t0 = astropy.time.Time('2100-1-1')
        self.__valid_t1 = astropy.time.Time('1970-1-1')

        self.__noise_RMS = None

        self.__current_time = None

        self.__assume_inf = assume_inf
        if antenna_by_depth:
            logger.info(
                "the correct antenna model will be determined automatically based on the depth of the antenna"
            )
        self._antenna_by_depth = antenna_by_depth

    def __query_channel(self, station_id, channel_id):
        Channel = Query()
        if self.__current_time is None:
            raise ValueError(
                "Detector time is not set. The detector time has to be set using the Detector.update() function before it can be used."
            )
        res = self._channels.get(
            (Channel.station_id == station_id)
            & (Channel.channel_id == channel_id)
            & (Channel.commission_time <= self.__current_time.datetime)
            & (Channel.decommission_time > self.__current_time.datetime))
        if res is None:
            logger.error(
                "query for station {} and channel {} at time {} returned no results"
                .format(station_id, channel_id, self.__current_time))
            raise LookupError
        return res

    def _query_channels(self, station_id):
        Channel = Query()
        if self.__current_time is None:
            raise ValueError(
                "Detector time is not set. The detector time has to be set using the Detector.update() function before it can be used."
            )
        return self._channels.search(
            (Channel.station_id == station_id)
            & (Channel.commission_time <= self.__current_time.datetime)
            & (Channel.decommission_time > self.__current_time.datetime))

    def _query_station(self, station_id):
        Station = Query()
        if self.__current_time is None:
            raise ValueError(
                "Detector time is not set. The detector time has to be set using the Detector.update() function before it can be used."
            )
        res = self._stations.get(
            (Station.station_id == station_id)
            & (Station.commission_time <= self.__current_time.datetime)
            & (Station.decommission_time > self.__current_time.datetime))
        if res is None:
            logger.error(
                "query for station {} at time {} returned no results".format(
                    station_id, self.__current_time.datetime))
            raise LookupError(
                "query for station {} at time {} returned no results".format(
                    station_id, self.__current_time.datetime))
        return res

    def __query_position(self, position_id):
        Position = Query()
        res = self.__positions.get((Position.pos_position == position_id))
        if self.__current_time is None:
            raise ValueError(
                "Detector time is not set. The detector time has to be set using the Detector.update() function before it can be used."
            )
        if res is None:
            logger.error(
                "query for position {} at time {} returned no results".format(
                    position_id, self.__current_time.datetime))
            raise LookupError(
                "query for position {} at time {} returned no results".format(
                    position_id, self.__current_time.datetime))
        return res

    def get_station_ids(self):
        """
        returns a sorted list of all station ids present in the database
        """
        station_ids = []
        res = self._stations.all()
        if res is None:
            logger.error("query for stations returned no results")
            raise LookupError("query for stations returned no results")
        for a in res:
            if a['station_id'] not in station_ids:
                station_ids.append(a['station_id'])
        return sorted(station_ids)

    def _get_station(self, station_id):
        if station_id not in self._buffered_stations.keys():
            self._buffer(station_id)
        return self._buffered_stations[station_id]

    def get_station(self, station_id):
        return self._get_station(station_id)

    def __get_position(self, position_id):
        if position_id not in self.__buffered_positions.keys():
            self.__buffer_position(position_id)
        return self.__buffered_positions[position_id]

    def __get_channels(self, station_id):
        if station_id not in self._buffered_stations.keys():
            self._buffer(station_id)
        return self._buffered_channels[station_id]

    def __get_channel(self, station_id, channel_id):
        if station_id not in self._buffered_stations.keys():
            self._buffer(station_id)
        return self._buffered_channels[station_id][channel_id]

    def _buffer(self, station_id):
        self._buffered_stations[station_id] = self._query_station(station_id)
        self.__valid_t0 = astropy.time.Time(
            self._buffered_stations[station_id]['commission_time'])
        self.__valid_t1 = astropy.time.Time(
            self._buffered_stations[station_id]['decommission_time'])
        channels = self._query_channels(station_id)
        self._buffered_channels[station_id] = {}
        for channel in channels:
            self._buffered_channels[station_id][
                channel['channel_id']] = channel
            self.__valid_t0 = max(
                self.__valid_t0, astropy.time.Time(channel['commission_time']))
            self.__valid_t1 = min(
                self.__valid_t1,
                astropy.time.Time(channel['decommission_time']))

    def __buffer_position(self, position_id):
        self.__buffered_positions[position_id] = self.__query_position(
            position_id)

    def __get_t0_t1(self, station_id):
        Station = Query()
        res = self._stations.get(Station.station_id == station_id)
        t0 = None
        t1 = None
        if isinstance(res, list):
            for station in res:
                if t0 is None:
                    t0 = station['commission_time']
                else:
                    t0 = min(t0, station['commission_time'])
                if t1 is None:
                    t1 = station['decommission_time']
                else:
                    t1 = max(t1, station['decommission_time'])
        else:
            t0 = res['commission_time']
            t1 = res['decommission_time']
        return astropy.time.Time(t0), astropy.time.Time(t1)

    def has_station(self, station_id):
        """
        checks if a station is present in the database

        Parameters
        ----------
        station_id: int
            the station id

        Returns bool
        """
        Station = Query()
        res = self._stations.get(Station.station_id == station_id)
        return res is not None

    def get_unique_time_periods(self, station_id):
        """
        returns the time periods in which the station configuration (including all channels) was constant

        Parameters
        ----------
        station_id: int
            the station id

        Returns list of astropy.time.Time
        """
        up = []
        t0, t1 = self.__get_t0_t1(station_id)
        self.update(t0)
        while True:
            if len(up) > 0 and up[-1] == t1:
                break
            self._buffer(station_id)
            if len(up) == 0:
                up.append(self.__valid_t0)
            up.append(self.__valid_t1)
            self.update(self.__valid_t1)
        return up

    def update(self, time):
        """
        updates the detector description to a new time

        Parameters
        ----------
        time: astropy.time.Time
            the time to update the detector description to
            for backward compatibility datetime is also accepted, but astropy.time is preferred
        """
        if isinstance(time, datetime):
            self.__current_time = astropy.time.Time(time)
        else:
            self.__current_time = time
        logger.info("updating detector time to {}".format(self.__current_time))
        if not ((self.__current_time > self.__valid_t0) and
                (self.__current_time < self.__valid_t1)):
            self._buffered_stations = {}
            self._buffered_channels = {}
            self.__valid_t0 = astropy.time.Time('2100-1-1')
            self.__valid_t1 = astropy.time.Time('1970-1-1')

    def get_detector_time(self):
        """
        Returns the time that the detector is currently set to
        """
        return self.__current_time

    def get_channel(self, station_id, channel_id):
        """
        returns a dictionary of all channel parameters

        Parameters
        ---------
        station_id: int
            the station id
        channel_id: int
            the channel id

        Return
        -------------
        dict of channel parameters
        """
        return self.__get_channel(station_id, channel_id)

    def get_absolute_position(self, station_id):
        """
        get the absolute position of a specific station

        Parameters
        ---------
        station_id: int
            the station id

        Returns
        ----------------
        3-dim array of absolute station position in easting, northing and depth with respect to snow level at
        time of measurement
        """
        res = self._get_station(station_id)
        easting, northing, altitude = 0, 0, 0
        unit_xy = units.m
        if 'pos_zone' in res and res['pos_zone'] == "SP-grid":
            unit_xy = units.feet
        if res['pos_easting'] is not None:
            easting = res['pos_easting'] * unit_xy
        if res['pos_northing'] is not None:
            northing = res['pos_northing'] * unit_xy
        if res['pos_altitude'] is not None:
            altitude = res['pos_altitude']
        return np.array([easting, northing, altitude])

    def get_absolute_position_site(self, site):
        """
        get the absolute position of a specific station

        Parameters
        ---------
        site: string
            the position identifier e.g. "G"

        Returns
        ---------------
        3-dim array of absolute station position in easting, northing and depth with respect to snow level at
        time of measurement
        """
        res = self.__get_position(site)
        unit_xy = units.m
        if 'pos_zone' in res and res['pos_zone'] == "SP-grid":
            unit_xy = units.feet
        easting, northing, altitude = 0, 0, 0
        if res['pos_easting'] is not None:
            easting = res['pos_easting'] * unit_xy
        if res['pos_northing'] is not None:
            northing = res['pos_northing'] * unit_xy
        if res['pos_altitude'] is not None:
            altitude = res['pos_altitude'] * units.m
        return np.array([easting, northing, altitude])

    def get_relative_position(self, station_id, channel_id):
        """
        get the relative position of a specific channels/antennas with respect to the station center

        Parameters
        ---------
        station_id: int
            the station id
        channel_id: int
            the channel id

        Returns
        ---------------
        3-dim array of relative station position
        """
        res = self.__get_channel(station_id, channel_id)
        return np.array([
            res['ant_position_x'], res['ant_position_y'], res['ant_position_z']
        ])

    def get_site(self, station_id):
        """
        get the site where the station is deployed (e.g. MooresBay or South Pole)

        Parameters
        ---------
        station_id: int
            the station id

        Returns string
        """

        res = self._get_station(station_id)
        return res['pos_site']

    def get_site_coordinates(self, station_id):
        """
        get the (latitude, longitude) coordinates (in degrees) for a given
        detector site.

        Parameters
        -------------
        station_id: int
            the station ID
        """
        sites = {
            'auger': (-35.10, -69.55),
            'mooresbay': (-78.74, 165.09),
            'southpole': (-90., 0.),
            'summit': (72.57, -38.46)
        }
        site = self.get_site(station_id)
        if site in sites.keys():
            return sites[site]
        return (None, None)

    def get_number_of_channels(self, station_id):
        """
        Get the number of channels per station

        Parameters
        ---------
        station_id: int
            the station id

        Returns int
        """
        res = self.__get_channels(station_id)
        return len(res)

    def get_channel_ids(self, station_id):
        """
        get the channel ids of a station

        Parameters
        ---------
        station_id: int
            the station id

        Returns list of ints
        """
        channel_ids = []
        for channel in self.__get_channels(station_id).values():
            channel_ids.append(channel['channel_id'])
        return sorted(channel_ids)

    def get_parallel_channels(self, station_id):
        """
        get a list of parallel antennas

        Parameters
        ---------
        station_id: int
            the station id

        Returns list of list of ints
        """
        res = self.__get_channels(station_id)
        orientations = np.zeros((len(res), 4))
        antenna_types = []
        channel_ids = []
        for iCh, ch in enumerate(res.values()):
            channel_id = ch['channel_id']
            channel_ids.append(channel_id)
            antenna_types.append(self.get_antenna_type(station_id, channel_id))
            orientations[iCh] = self.get_antenna_orientation(
                station_id, channel_id)
            orientations[iCh][3] = hp.get_normalized_angle(
                orientations[iCh][3], interval=np.deg2rad([0, 180]))
        channel_ids = np.array(channel_ids)
        antenna_types = np.array(antenna_types)
        orientations = np.round(np.rad2deg(
            orientations))  # round to one degree to overcome rounding errors
        parallel_antennas = []
        for antenna_type in np.unique(antenna_types):
            for u_zen_ori in np.unique(orientations[:, 0]):
                for u_az_ori in np.unique(orientations[:, 1]):
                    for u_zen_rot in np.unique(orientations[:, 2]):
                        for u_az_rot in np.unique(orientations[:, 3]):
                            mask = (antenna_types == antenna_type) \
                                & (orientations[:, 0] == u_zen_ori) & (orientations[:, 1] == u_az_ori) \
                                & (orientations[:, 2] == u_zen_rot) & (orientations[:, 3] == u_az_rot)
                            if np.sum(mask):
                                parallel_antennas.append(channel_ids[mask])
        return np.array(parallel_antennas)

    def get_cable_delay(self, station_id, channel_id):
        """
        returns the cable delay of a channel

        Parameters
        ---------
        station_id: int
            the station id
        channel_id: int
            the channel id

        Returns float (delay time)
        """
        res = self.__get_channel(station_id, channel_id)
        return res['cab_time_delay']

    def get_cable_type_and_length(self, station_id, channel_id):
        """
        returns the cable type (e.g. LMR240) and its length

        Parameters
        ---------
        station_id: int
            the station id
        channel_id: int
            the channel id

        Returns tuple (string, float)
        """
        res = self.__get_channel(station_id, channel_id)
        return res['cab_type'], res['cab_length'] * units.m

    def get_antenna_type(self, station_id, channel_id):
        """
        returns the antenna type

        Parameters
        ---------
        station_id: int
            the station id
        channel_id: int
            the channel id

        Returns string
        """
        res = self.__get_channel(station_id, channel_id)
        return res['ant_type']

    def get_antenna_deployment_time(self, station_id, channel_id):
        """
        returns the time of antenna deployment

        Parameters
        ---------
        station_id: int
            the station id
        channel_id: int
            the channel id

        Returns datetime
        """
        res = self.__get_channel(station_id, channel_id)
        return res['ant_deployment_time']

    def get_antenna_orientation(self, station_id, channel_id):
        """
        returns the orientation of a specific antenna

        Parameters
        ---------
        station_id: int
            the station id
        channel_id: int
            the channel id

        Returns
        ---------------
        tuple of floats
            * orientation theta: orientation of the antenna, as a zenith angle (0deg is the zenith, 180deg is straight down); for LPDA: outward along boresight; for dipoles: upward along axis of azimuthal symmetry
            * orientation phi: orientation of the antenna, as an azimuth angle (counting from East counterclockwise); for LPDA: outward along boresight; for dipoles: upward along axis of azimuthal symmetry
            * rotation theta: rotation of the antenna, is perpendicular to 'orientation', for LPDAs: vector perpendicular to the plane containing the tines
            * rotation phi: rotation of the antenna, is perpendicular to 'orientation', for LPDAs: vector perpendicular to the plane containing the tines
        """
        res = self.__get_channel(station_id, channel_id)
        return np.deg2rad([
            res['ant_orientation_theta'], res['ant_orientation_phi'],
            res['ant_rotation_theta'], res['ant_rotation_phi']
        ])

    def get_amplifier_type(self, station_id, channel_id):
        """
        returns the type of the amplifier

        Parameters
        ---------
        station_id: int
            the station id
        channel_id: int
            the channel id

        Returns string
        """
        res = self.__get_channel(station_id, channel_id)
        return res['amp_type']

    def get_amplifier_measurement(self, station_id, channel_id):
        """
        returns a unique reference to the amplifier measurement

        Parameters
        ---------
        station_id: int
            the station id
        channel_id: int
            the channel id

        Returns string
        """
        res = self.__get_channel(station_id, channel_id)
        return res['amp_reference_measurement']

    def get_amplifier_response(self, station_id, channel_id, frequencies):
        """
        Returns the amplifier response for the amplifier of a given channel

        Parameters:
        ---------------
        station_id: int
            The ID of the station
        channel_id: int
            The ID of the channel
        frequencies: array of floats
            The frequency array for which the amplifier response shall be returned
        """
        res = self.__get_channel(station_id, channel_id)
        amp_type = None
        if 'amp_type' in res.keys():
            amp_type = res['amp_type']
        if amp_type is None:
            raise ValueError(
                'Amplifier type for station {}, channel {} not in detector description'
                .format(station_id, channel_id))
        amp_response_functions = None
        if amp_type in NuRadioReco.detector.RNO_G.analog_components.get_available_amplifiers(
        ):
            amp_response_functions = NuRadioReco.detector.RNO_G.analog_components.load_amp_response(
                amp_type)
        if amp_type in NuRadioReco.detector.ARIANNA.analog_components.get_available_amplifiers(
        ):
            if amp_response_functions is not None:
                raise ValueError(
                    'Amplifier name {} is not unique'.format(amp_type))
            amp_response_functions = NuRadioReco.detector.ARIANNA.analog_components.load_amplifier_response(
                amp_type)
        if amp_response_functions is None:
            raise ValueError('Amplifier of type {} not found'.format(amp_type))
        amp_gain = amp_response_functions['gain'](frequencies)
        amp_phase = amp_response_functions['phase'](frequencies)
        return amp_gain * amp_phase

    def get_sampling_frequency(self, station_id, channel_id):
        """
        returns the sampling frequency

        Parameters
        ---------
        station_id: int
            the station id
        channel_id: int
            the channel id

        Returns float
        """
        res = self.__get_channel(station_id, channel_id)
        return res['adc_sampling_frequency'] * units.GHz

    def get_number_of_samples(self, station_id, channel_id):
        """
        returns the number of samples of a channel

        Parameters
        ---------
        station_id: int
            the station id
        channel_id: int
            the channel id

        Returns int
        """
        res = self.__get_channel(station_id, channel_id)
        return res['adc_n_samples']

    def get_antenna_model(self, station_id, channel_id, zenith=None):
        """
        determines the correct antenna model from antenna type, position and orientation of antenna

        so far only infinite firn and infinite air cases are differentiated

        Parameters
        ---------
        station_id: int
            the station id
        channel_id: int
            the channel id
        zenith: float or None (default)
            the zenith angle of the incoming signal direction

        Returns string
        """
        antenna_type = self.get_antenna_type(station_id, channel_id)
        antenna_relative_position = self.get_relative_position(
            station_id, channel_id)

        if self._antenna_by_depth:
            if zenith is not None and (antenna_type == 'createLPDA_100MHz'):
                if antenna_relative_position[2] > 0:
                    antenna_model = "{}_InfAir".format(antenna_type)
                    if (not self.__assume_inf) and zenith < 90 * units.deg:
                        antenna_model = "{}_z1cm_InAir_RG".format(antenna_type)
                else:  # antenna in firn
                    antenna_model = "{}_InfFirn".format(antenna_type)
                    if (
                            not self.__assume_inf
                    ) and zenith > 90 * units.deg:  # signal comes from below
                        antenna_model = "{}_z1cm_InFirn_RG".format(
                            antenna_type)
                        # we need to add further distinction here
            elif not antenna_type.startswith('analytic'):
                if antenna_relative_position[2] > 0:
                    antenna_model = "{}_InfAir".format(antenna_type)
                else:
                    antenna_model = "{}_InfFirn".format(antenna_type)
            else:
                antenna_model = antenna_type
        else:
            antenna_model = antenna_type
        return antenna_model

    def get_noise_RMS(self, station_id, channel_id, stage='amp'):
        """
        returns the noise RMS that was precomputed from forced triggers

        Parameters
        ----------
        station_id: int
            station id
        channel_id: int
            the channel id, not used at the moment, only station averages are computed
        stage: string (default 'amp')
            specifies the stage of reconstruction you want the noise RMS for,
            `stage` can be one of
             * 'raw' (raw measured trace)
             * 'amp' (after the amp was deconvolved)
             * 'filt' (after the trace was highpass filtered at 100 MHz)

        Returns
        -------
        RMS: float
            the noise RMS (actually it is the standard deviation but as the mean should be zero it's the same)
        """
        if self.__noise_RMS is None:
            import json
            detector_directory = os.path.dirname(os.path.abspath(__file__))
            with open(os.path.join(detector_directory, 'noise_RMS.json'),
                      'r') as fin:
                self.__noise_RMS = json.load(fin)

        key = "{:d}".format(station_id)
        if key not in self.__noise_RMS.keys():
            rms = self.__noise_RMS['default'][stage]
            logger.warning(
                "no RMS values for station {} available, returning default noise for stage {}: RMS={:.2g} mV"
                .format(station_id, stage, rms / units.mV))
            return rms
        return self.__noise_RMS[key][stage]

    def get_noise_temperature(self, station_id, channel_id):
        """
        returns the noise temperature of the channel

        Parameters
        ----------
        station_id: int
            station id
        channel_id: int
            the channel id

        """
        res = self.__get_channel(station_id, channel_id)
        if 'noise_temperature' not in res:
            raise AttributeError(
                f"field noise_temperature not present in detector description of station {station_id} and channel {channel_id}"
            )
        return res['noise_temperature']

    def is_channel_noiseless(self, station_id, channel_id):
        """
        returns true if the detector description has the field `noiseless` and if this field is True.

        Allows running a noiseless simulation on specific channels (for example to simulate a single-antenna proxy
        along with the phased array)

        Parameters
        ----------
        station_id: int
            station id
        channel_id: int
            the channel id

        """
        res = self.__get_channel(station_id, channel_id)
        if 'noiseless' not in res:
            return False
        return res['noiseless']
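# --- Added sketch (assumption, independent of the Detector class above) -----
# The commission/decommission time-window query pattern used by
# _query_station(), shown on a tiny in-memory table with plain datetimes.
from datetime import datetime
from tinydb import TinyDB, Query
from tinydb.storages import MemoryStorage

_db = TinyDB(storage=MemoryStorage)
stations = _db.table('stations')
stations.insert({'station_id': 51,
                 'commission_time': datetime(2018, 1, 1),
                 'decommission_time': datetime(2030, 1, 1)})

Station = Query()
now = datetime(2021, 6, 1)
res = stations.get((Station.station_id == 51)
                   & (Station.commission_time <= now)
                   & (Station.decommission_time > now))
print(res)  # the station dict, or None if no configuration is valid at `now`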
Example #24
0
import json
from tinydb import TinyDB, Query

db = TinyDB("db.json")  # Initialise the databse in json format
print("\n Setup DB")

db.truncate()  # Purge the database of any previous entries

# Test inserting value into db
db.insert({"name": "James"})  # Insert data into the database
db.insert({"dob": "10 11 91"})
db.insert({"city": "London"})
db.insert({"job": "Network Engineer"})

for i in db:  # Iterate over the db to print it out line by line
    print(i)

print("\n" + "*" * 50)
Example #25
0
class OssIndex():
    """ossindex.py makes a request to OSSIndex"""
    def __init__(self,
                 url='https://ossindex.sonatype.org/api/v3/component-report',
                 cache_location=''):
        self._url = url
        self._headers = DEFAULT_HEADERS
        self._log = logging.getLogger('jake')
        self._maxcoords = 128
        if cache_location == '':
            home = str(Path.home())
            dir_oss = home + "/.ossindex/"
        else:
            dir_oss = cache_location + "/.ossindex/"
        if not Path(dir_oss).exists():
            Path(dir_oss).mkdir(parents=True, exist_ok=True)
        self._db = TinyDB(dir_oss + "jake.json")

    def get_url(self):
        """gets url to use for OSSIndex request"""
        return self._url

    def get_headers(self):
        """gets headers to use for OSSIndex request"""
        return self._headers

    def chunk(self, coords: Coordinates):
        """chunks up purls array into 128-purl subarrays"""
        chunks = []
        divided = []
        length = len(coords.get_coordinates())
        num_chunks = length // self._maxcoords
        if length % self._maxcoords > 0:
            num_chunks += 1
        start_index = 0
        end_index = self._maxcoords
        for i in range(0, num_chunks):
            if i == (num_chunks - 1):
                divided = coords.get_purls()[start_index:length]
            else:
                divided = coords.get_purls()[start_index:end_index]
                start_index = end_index
                end_index += self._maxcoords
            chunks.append(divided)
        return chunks

    def call_ossindex(self, coords: Coordinates) -> (list):
        """makes a request to OSSIndex"""
        self._log.debug("Purls received, total purls before chunk: %s",
                        len(coords.get_coordinates()))

        (coords, results) = self.get_purls_and_results_from_cache(coords)

        self._log.debug(
            "Purls checked against cache, total purls remaining to "
            "call OSS Index: %s", len(coords.get_coordinates()))

        chunk_purls = self.chunk(coords)
        for purls_chunk in chunk_purls:
            data = {}
            data["coordinates"] = purls_chunk
            config_file = Config()
            if config_file.check_if_config_exists() is False:
                response = requests.post(self.get_url(),
                                         data=json.dumps(data),
                                         headers=self.get_headers())
            else:
                auth = config_file.get_config_from_file(".oss-index-config")

                response = requests.post(self.get_url(),
                                         data=json.dumps(data),
                                         headers=self.get_headers(),
                                         auth=(auth["Username"],
                                               auth["Token"]))
            if response.status_code == 200:
                self._log.debug(response.headers)
                first_results = json.loads(response.text, cls=ResultsDecoder)
            else:
                self._log.debug("Response failed, status: %s",
                                response.status_code)
                self._log.debug("Failure reason if any: %s", response.reason)
                self._log.debug("Failure text if any: %s", response.text)
                return None
            results.extend(first_results)

        (cached, num_cached) = self.maybe_insert_into_cache(results)
        self._log.debug("Cached: <%s> num_cached: <%s>", cached, num_cached)
        return results

    def maybe_insert_into_cache(self, results: List[CoordinateResults]):
        """checks to see if result is in cache and if not, stores it"""
        coordinate_query = Query()
        num_cached = 0
        cached = False
        for coordinate in results:
            mydatetime = datetime.now()
            twelvelater = mydatetime + timedelta(hours=12)
            result = self._db.search(
                coordinate_query.purl == coordinate.get_coordinates())
            if len(result) == 0:
                self._db.insert({
                    'purl': coordinate.get_coordinates(),
                    'response': coordinate.to_json(),
                    'ttl': twelvelater.isoformat()
                })
                self._log.debug("Coordinate inserted into cache: <%s>",
                                coordinate.get_coordinates())
                num_cached += 1
                cached = True
            else:
                timetolive = DT.datetime.strptime(result[0]['ttl'],
                                                  '%Y-%m-%dT%H:%M:%S.%f')
                if mydatetime > timetolive:
                    self._db.update(
                        {
                            'response': coordinate.to_json(),
                            'ttl': twelvelater.isoformat()
                        },
                        doc_ids=[result[0].doc_id])
                    self._log.debug(
                        "Coordinate: <%s> updated in cache because TTL"
                        " expired", coordinate.get_coordinates())
                    num_cached += 1
                    cached = True

        return (cached, num_cached)

    def get_purls_and_results_from_cache(
            self, purls: Coordinates) -> (Coordinates, list):
        """get cached purls and results from cache"""
        valid = isinstance(purls, Coordinates)
        if not valid:
            return (None, None)
        new_purls = Coordinates()
        results = []
        coordinate_query = Query()
        for coordinate, purl in purls.get_coordinates().items():
            mydatetime = datetime.now()
            result = self._db.search(coordinate_query.purl == purl)
            if len(result) == 0 or DT.datetime.strptime(
                    result[0]['ttl'], '%Y-%m-%dT%H:%M:%S.%f') < mydatetime:
                new_purls.add_coordinate(coordinate[0], coordinate[1],
                                         coordinate[2])
            else:
                results.append(
                    json.loads(result[0]['response'], cls=ResultsDecoder))
        return (new_purls, results)

    def clean_cache(self):
        """removes all documents from the table"""
        self._db.truncate()
        return True

    def close_db(self):
        """closes connection to TinyDB"""
        self._db.close()
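# --- Added sketch (assumption, separate from the OssIndex class above) ------
# The 12-hour TTL caching pattern maybe_insert_into_cache() implements: store
# a response with an expiry timestamp, then treat it as stale once the ttl has
# passed. The purl and response values are invented for this sketch.
from datetime import datetime, timedelta
from tinydb import TinyDB, Query
from tinydb.storages import MemoryStorage

cache = TinyDB(storage=MemoryStorage)
Entry = Query()
purl = "pkg:pypi/requests@2.25.1"
cache.insert({"purl": purl,
              "response": "{}",
              "ttl": (datetime.now() + timedelta(hours=12)).isoformat()})

hit = cache.search(Entry.purl == purl)
expired = datetime.strptime(hit[0]["ttl"], "%Y-%m-%dT%H:%M:%S.%f") < datetime.now()
print(expired)  # False until 12 hours have passed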
Example #26
0
logstat = os.stat(UPLOAD_DB)
users = TinyDB(USER_DB, indent=4) # user database

def containerStats(): # prints the stats for the container
    print("Current container: " + str(os.getcwd()))
    print("Container size: " + str(sum(os.path.getsize(f) for f in os.listdir(".") if os.path.isfile(f))) + " bytes")
    print("Container file count: " + str(len([name for name in os.listdir(".") if os.path.isfile(name)])))
    print("Upload DB name: " + str(log.name))
    print("Upload DB log count: " + str(len(log)))
    print("Upload DB size: " + str(logstat.st_size) + " bytes.")

#os.chdir(CONTAINER_FOLDER) # change active folder to container
os.chdir(CONTAINER_FOLDER)
print("Simple File Container cleaner \n")
containerStats()

askDel = input("\nDelete all files in the container and clear the upload database? (y/n): ")
if askDel[:1] == "y":
    try:
        #shutil.rmtree(CONTAINER_FOLDER)
        shutil.rmtree(".", ignore_errors=True) # clear the container folder
        log.truncate() # clear the upload database
        print("\nCleared the container successfully!")
        containerStats()

    except:
        print("An error occured while trying to clean the container.")

if askDel[:1] == "n":
    print("Exiting...")
    exit()
Example #27
0
class EnvDb:
    def __init__(self, db_path):
        from tinydb import TinyDB
        self.entries = []
        self.db = TinyDB(db_path)
        # There is a fundamental problem with querying the way it is done here: the content of the DB is only loaded
        # once, so the DB is not queried live and the object must be re-instantiated to refresh the entries.
        self._load_entries()

    def _load_entries(self):
        self.entries = []
        # load everything into memory
        for db_entry in self.db:
            try:
                self.entries.append(EnvDbEntry.from_config(db_entry))
            except Exception as e:
                logger.warning("Could not load entry with cli path {0} due to: {1}. "
                            "Skipping...".format(str(db_entry), str(e)))

    def get_entry_by_model(self, model_name, only_most_recent=True, only_valid=False):
        # iterate over all the entries and select the ones where the model_name is part of one of the listed models
        # For checking split the model_name by "/" as well as the env-compatible model names and then check equality.
        # Select the one with the most recent timestamp
        norm_name = lambda x: x.lstrip("/").rstrip("/")
        norm_model_name = norm_name(model_name)
        query_model_tk_len = len(norm_model_name.split("/"))
        sel_entries = {}
        for entry in self.get_all(only_valid=only_valid):
            pre_sel = [m for m in entry.compatible_models if model_name in m]
            sel = [m for m in pre_sel if "/".join(norm_name(m).split("/")[:query_model_tk_len]) == norm_model_name]
            if len(sel) != 0:
                sel_entries[entry.timestamp] = entry
        ordered_entries = OrderedDict([(k, sel_entries[k]) for k in sorted(list(sel_entries.keys()))][::-1])

        if only_most_recent:
            if len(ordered_entries) == 0:
                return None
            else:
                return list(ordered_entries.values())[0]
        else:
            return list(ordered_entries.values())

    def get_all_unfinished(self):
        unfinished = []
        for e in self.entries:
            if not e.successful or e.cli_path is None or not os.path.exists(e.cli_path):
                unfinished.append(e)
        return unfinished

    def db_remove_unfinished(self):
        [self.remove(e) for e in self.get_all_unfinished()]

    def get_all(self, only_valid=False):
        entries = self.entries
        if only_valid:
            invalid = self.get_all_unfinished()
            entries = [e for e in entries if e not in invalid]
        return entries

    def remove(self, entry):
        self.entries = [e for e in self.entries if e != entry]

    def append(self, entry):
        self.entries.append(entry)

    def save(self):
        self.db.truncate()
        for entry in self.entries:
            self.db.insert(entry.get_config())

    def __del__(self):
        self.db.close()
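# --- Added sketch (assumption, independent of the EnvDb class above) --------
# The persistence pattern save() uses: wipe the table, then re-insert every
# in-memory entry so the JSON file mirrors the current object state.
from tinydb import TinyDB

env_db = TinyDB("env.db.json")
entries = [{"cli_path": "/opt/envs/env1/bin/cli", "successful": True},
           {"cli_path": None, "successful": False}]
env_db.truncate()
for entry in entries:
    env_db.insert(entry)
print(env_db.all())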
Example #28
0
class Cache:
    def __init__(
        self,
        central: CentralApi = None,
        data: Union[List[dict], dict] = None,
        refresh: bool = False,
    ) -> None:
        self.updated: list = []
        self.central = central
        self.DevDB = TinyDB(config.cache_file)
        self.SiteDB = self.DevDB.table("sites")
        self.GroupDB = self.DevDB.table("groups")
        self.TemplateDB = self.DevDB.table("templates")
        # The log DB provides a simple index used to look up log details,
        # vs the actual log id in the form 'audit_trail_2021_2,AXfQAu2hkwsSs1O3R7kv'.
        # It is updated any time `show logs` is run.
        self.LogDB = self.DevDB.table("logs")
        self._tables = [self.DevDB, self.SiteDB, self.GroupDB, self.TemplateDB]
        self.Q = Query()
        if data:
            self.insert(data)
        if central:
            self.check_fresh(refresh)

    def __call__(self, refresh=False) -> None:
        if refresh:
            self.check_fresh(refresh)

    def __iter__(self) -> list:
        for db in self._tables:
            yield db.name(), db.all()

    @property
    def devices(self) -> list:
        return self.DevDB.all()

    @property
    def sites(self) -> list:
        return self.SiteDB.all()

    @property
    def groups(self) -> list:
        return self.GroupDB.all()

    @property
    def logs(self) -> list:
        return self.LogDB.all()

    @property
    def group_names(self) -> list:
        return [g["name"] for g in self.GroupDB.all()]

    @property
    def templates(self) -> list:
        return self.TemplateDB.all()

    @property
    def all(self) -> dict:
        return {t.name: getattr(self, t.name) for t in self._tables}

    @staticmethod
    def account_completion(incomplete: str, ):
        for a in config.defined_accounts:
            if a.lower().startswith(incomplete.lower()):
                yield a

    def null_completion(self, incomplete: str):
        incomplete = "NULL_COMPLETION"
        _ = incomplete
        for m in ["|", "<cr>"]:
            yield m

    def dev_completion(
        self,
        incomplete: str,
        args: List[str] = None,
    ):
        dev_type = None
        if args[-1].lower() == "gateways":
            dev_type = "gateway"
        if args[-1].lower().startswith("switch"):
            dev_type = "switch"
        if args[-1].lower() in ["aps", "ap"]:
            dev_type = "ap"

        match = self.get_dev_identifier(
            incomplete,
            dev_type=dev_type,
            completion=True,
        )
        out = []
        if match:
            for m in sorted(match, key=lambda i: i.name):
                if m.name.startswith(incomplete):
                    out += [tuple([m.name, m.help_text])]
                elif m.serial.startswith(incomplete):
                    out += [tuple([m.serial, m.help_text])]
                elif m.mac.strip(":.-").lower().startswith(
                        incomplete.strip(":.-")):
                    out += [tuple([m.mac, m.help_text])]
                elif m.ip.startswith(incomplete):
                    out += [tuple([m.ip, m.help_text])]
                else:
                    # failsafe, shouldn't hit
                    out += [tuple([m.name, m.help_text])]

        for m in out:
            yield m

    def dev_kwarg_completion(
        self,
        incomplete: str,
        args: List[str] = None,
    ):
        """Completion for commands that allow a list of devices followed by grouop/site.

        i.e. cencli move dev1 dev2 dev3 site site_name group group_name

        Args:
            incomplete (str): The incomplete word for autocompletion
            args (List[str], optional): The prev args passed into the command.

        Yields:
            tuple: matching completion string, help text
        """
        if args[-1].lower() == "group":
            out = [m for m in self.group_completion(incomplete)]
            for m in out:
                yield m

        elif args[-1].lower() == "site":
            out = [m for m in self.site_completion(incomplete)]
            for m in out:
                yield m

        else:
            out = []
            if len(args) > 1:
                if "site" not in args and "site".startswith(
                        incomplete.lower()):
                    out += ("site", )
                if "group" not in args and "group".startswith(
                        incomplete.lower()):
                    out += ("group", )

            if "site" not in args and "group" not in args:
                out += [m for m in self.dev_completion(incomplete)]
            elif "site" in args and "group" in args:
                incomplete = "NULL_COMPLETION"
                out += ["|", "<cr>"]

            for m in out:
                yield m

    def group_completion(
        self,
        incomplete: str,
        args: List[str] = None,
    ):
        match = self.get_group_identifier(
            incomplete,
            completion=True,
        )
        out = []
        if match:
            for m in sorted(match, key=lambda i: i.name):
                out += [tuple([m.name, m.help_text])]

        for m in out:
            yield m[0], m[1]

    def site_completion(
        self,
        incomplete: str,
        args: List[str] = None,
    ):
        match = self.get_site_identifier(
            incomplete.replace('"', "").replace("'", ""),
            completion=True,
        )
        out = []
        if match:
            for m in sorted(match, key=lambda i: i.name):
                out += [
                    tuple([
                        m.name if " " not in m.name else f"'{m.name}'",
                        m.help_text
                    ])
                ]

        for m in out:
            yield m

    def template_completion(
        self,
        incomplete: str,
        args: List[str] = None,
    ):
        match = self.get_template_identifier(
            incomplete,
            completion=True,
        )
        out = []
        if match:
            for m in sorted(match, key=lambda i: i.name):
                out += [tuple([m.name, m.help_text])]

        for m in out:
            yield m

    def dev_template_completion(
        self,
        incomplete: str,
        args: List[str] = None,
    ):
        match = self.get_template_identifier(
            incomplete,
            completion=True,
        )
        match = match or []
        dev_match = self.get_dev_identifier(
            incomplete,
            completion=True,
        )
        match += dev_match or []
        out = []
        if match:
            for m in sorted(match, key=lambda i: i.name):
                out += [tuple([m.name, m.help_text])]

        for m in out:
            yield m

    def dev_site_completion(
        self,
        incomplete: str,
        args: List[str] = None,
    ):
        match = self.get_dev_identifier(
            incomplete,
            completion=True,
        )
        match = match or []
        match += self.get_site_identifier(
            incomplete,
            completion=True,
        )
        out = []
        if match:
            for m in sorted(match, key=lambda i: i.name):
                out += [tuple([m.name, m.help_text])]

        for m in out:
            yield m

    def remove_completion(
        self,
        incomplete: str,
        args: List[str],
    ):
        if args[-1].lower() == "site":
            out = [m for m in self.site_completion(incomplete)]
            for m in out:
                yield m
        else:
            out = []
            if len(args) > 1:
                if "site" not in args and "site".startswith(
                        incomplete.lower()):
                    out += ("site", )

            if "site" not in args:
                out += [m for m in self.dev_completion(incomplete)]
            else:
                out += [m for m in self.null_completion(incomplete)]

            for m in out:
                yield m

    def completion(
        self,
        incomplete: str,
        args: List[str],
    ):
        cache = ()
        if [True for m in DEV_COMPLETION if args[-1].endswith(m)]:
            cache += tuple(["dev"])
        elif [True for m in GROUP_COMPLETION if args[-1].endswith(m)]:
            cache += tuple(["group"])
        elif [True for m in SITE_COMPLETION if args[-1].endswith(m)]:
            cache += tuple(["site"])
        elif [True for m in TEMPLATE_COMPLETION if args[-1].endswith(m)]:
            cache += tuple(["template"])

        if not cache:
            match = self.get_identifier(
                incomplete,
                ("dev", "group", "site", "template"),
                completion=True,
            )
        else:
            match = self.get_identifier(
                incomplete,
                tuple(cache),
                completion=True,
            )

        out = []
        _extra = [e for e in EXTRA_COMPLETION if e in args and args[-1] != e]
        if _extra:
            out += [
                tuple([m, "COMMAND KEYWORD"]) for e in _extra
                for m in EXTRA_COMPLETION[e]
                if m.startswith(incomplete) and args[-1] != m
            ]

        if match:
            for m in sorted(match, key=lambda i: i.name):
                out += [tuple([m.name, m.help_text])]

        for m in out:
            yield m

    # TODO: possibly deprecated -- an earlier note suggested removing this method, but it still appears to be used.
    def insert(
        self,
        data: Union[List[dict], dict],
    ) -> bool:
        # inspect a sample record to decide which table the data belongs in
        _data = data
        if isinstance(data, list) and data:
            _data = data[0]

        table = self.DevDB
        if "zipcode" in _data.keys():
            table = self.SiteDB

        data = data if isinstance(data, list) else [data]
        ret = table.insert_multiple(data)

        return len(ret) == len(data)
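
    # Illustrative note (not in the original source): insert() inspects a sample record and
    # routes site records (those containing a "zipcode" key) to SiteDB; everything else goes
    # to the device (default) table, e.g.:
    #   cache.insert({"zipcode": "30301", "name": "branch-1"})   # -> sites table
    #   cache.insert({"serial": "CN111", "name": "ap-lobby"})    # -> DevDB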

    # TODO have update methods return Response
    async def update_dev_db(self):
        resp = await self.central.get_all_devicesv2()
        if resp.ok:
            resp.output = utils.listify(resp.output)
            self.updated.append(self.central.get_all_devicesv2)
            self.DevDB.truncate()
            return self.DevDB.insert_multiple(resp.output)

    async def update_site_db(self,
                             data: Union[list, dict] = None,
                             remove: bool = False) -> List[int]:
        # cli.cache.SiteDB.search(cli.cache.Q.id == del_list[0])[0].doc_id
        if data:
            data = utils.listify(data)
            if not remove:
                return self.SiteDB.insert_multiple(data)
            else:
                doc_ids = []
                for qry in data:
                    # provided list of site_ids to remove
                    if isinstance(qry, (int, str)) and str(qry).isdigit():
                        doc_ids += [
                            self.SiteDB.get((self.Q.id == int(qry))).doc_id
                        ]
                    else:
                        # list of dicts with {search_key: value_to_search_for}
                        if len(qry.keys()) > 1:
                            raise ValueError(
                                f"cache.update_site_db remove Should only have 1 query not {len(qry.keys())}"
                            )
                        q = list(qry.keys())[0]
                        doc_ids += [
                            self.SiteDB.get((self.Q[q] == qry[q])).doc_id
                        ]
                return self.SiteDB.remove(doc_ids=doc_ids)
        else:
            resp = await self.central.get_all_sites()
            if resp.ok:
                resp.output = utils.listify(resp.output)
                # TODO time this to see which is more efficient
                # start = time.time()
                # upd = [self.SiteDB.upsert(site, cond=self.Q.id == site.get("id")) for site in site_resp.output]
                # upd = [item for in_list in upd for item in in_list]
                self.updated.append(self.central.get_all_sites)
                self.SiteDB.truncate()
                # print(f" site db Done: {time.time() - start}")
                return self.SiteDB.insert_multiple(resp.output)
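
    # Illustrative usage (not in the original source): removal accepts either numeric site ids
    # or single-key query dicts, e.g.:
    #   await cache.update_site_db(data=[123, 456], remove=True)
    #   await cache.update_site_db(data=[{"name": "branch-1"}], remove=True)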

    async def update_group_db(self,
                              data: Union[list, dict] = None,
                              remove: bool = False) -> List[int]:
        if data:
            data = utils.listify(data)
            if not remove:
                return self.GroupDB.insert_multiple(data)
            else:
                doc_ids = []
                for qry in data:
                    if len(qry.keys()) > 1:
                        raise ValueError(
                            f"cache.update_group_db remove Should only have 1 query not {len(qry.keys())}"
                        )
                    q = list(qry.keys())[0]
                    doc_ids += [self.GroupDB.get((self.Q[q] == qry[q])).doc_id]
                return self.GroupDB.remove(doc_ids=doc_ids)
        else:
            resp = await self.central.get_all_groups()
            if resp.ok:
                resp.output = utils.listify(resp.output)
                self.updated.append(self.central.get_all_groups)
                self.GroupDB.truncate()
                return self.GroupDB.insert_multiple(resp.output)

    async def update_template_db(self):
        groups = self.groups if self.central.get_all_groups in self.updated else None
        resp = await self.central.get_all_templates(groups=groups)
        if resp.ok:
            resp.output = utils.listify(resp.output)
            self.updated.append(self.central.get_all_templates)
            self.TemplateDB.truncate()
            return self.TemplateDB.insert_multiple(resp.output)

    def update_log_db(self, log_data: List[Dict[str, Any]]) -> bool:
        self.LogDB.truncate()
        return self.LogDB.insert_multiple(log_data)

    async def _check_fresh(self,
                           dev_db: bool = False,
                           site_db: bool = False,
                           template_db: bool = False,
                           group_db: bool = False):
        update_funcs = []
        if dev_db:
            update_funcs += [self.update_dev_db]
        if site_db:
            update_funcs += [self.update_site_db]
        if template_db:
            update_funcs += [self.update_template_db]
        if group_db:
            update_funcs += [self.update_group_db]
        async with ClientSession() as self.central.aio_session:
            if update_funcs:
                if await update_funcs[0]():
                    if len(update_funcs) > 1:
                        await asyncio.gather(*[f() for f in update_funcs[1:]])

            # update groups first so template update can use the result, and to trigger token_refresh if necessary
            elif await self.update_group_db():
                await asyncio.gather(self.update_dev_db(),
                                     self.update_site_db(),
                                     self.update_template_db())

    def check_fresh(
        self,
        refresh: bool = False,
        site_db: bool = False,
        dev_db: bool = False,
        template_db: bool = False,
        group_db: bool = False,
    ) -> None:
        if True in [site_db, dev_db, group_db, template_db]:
            refresh = True

        if (refresh or not config.cache_file.is_file()
                or not config.cache_file.stat().st_size > 0):
            #  or time.time() - config.cache_file.stat().st_mtime > 7200:
            start = time.time()
            print(typer.style("-- Refreshing Identifier mapping Cache --",
                              fg="cyan"),
                  end="")
            db_res = asyncio.run(
                self._check_fresh(dev_db=dev_db,
                                  site_db=site_db,
                                  template_db=template_db,
                                  group_db=group_db))
            if db_res and False in db_res:
                res_map = ["dev_db", "site_db", "template_db", "group_db"]
                res_map = ", ".join(
                    [db for idx, db in enumerate(res_map) if not db_res[idx]])
                log.error(
                    f"TinyDB returned error ({res_map}) during db update")
                self.central.spinner.fail(
                    f"Cache Refresh Returned an error updating ({res_map})")
            else:
                self.central.spinner.succeed(
                    f"Cache Refresh Completed in {round(time.time() - start, 2)} sec"
                )
            log.info(
                f"Cache Refreshed in {round(time.time() - start, 2)} seconds")
            # typer.secho(f"-- Cache Refresh Completed in {round(time.time() - start, 2)} sec --", fg="cyan")

    def handle_multi_match(
        self,
        match: List[CentralObject],
        query_str: str = None,
        query_type: str = "device",
        multi_ok: bool = False,
    ) -> List[Dict[str, Any]]:
        # typer.secho(f" -- Ambiguos identifier provided.  Please select desired {query_type}. --\n", color="cyan")
        typer.echo()
        if query_type == "site":
            fields = ("name", "city", "state", "type")
        elif query_type == "template":
            fields = ("name", "group", "model", "device_type", "version")
        else:  # device
            fields = ("name", "serial", "mac", "type")
        out = utils.output(
            [{k: d[k]
              for k in d.data if k in fields} for d in match],
            title=f"Ambiguos identifier. Select desired {query_type}.")
        menu = out.menu(data_len=len(match))

        if query_str:
            menu = menu.replace(query_str,
                                typer.style(query_str, fg="bright_cyan"))
            menu = menu.replace(
                query_str.upper(),
                typer.style(query_str.upper(), fg="bright_cyan"))
        typer.echo(menu)
        selection = ""
        valid = [str(idx + 1) for idx, _ in enumerate(match)]
        try:
            while selection not in valid:
                selection = typer.prompt(f"Select {query_type.title()}")
                if not selection or selection not in valid:
                    typer.secho(f"Invalid selection {selection}, try again.")
        except KeyboardInterrupt:
            raise typer.Abort()

        return [match.pop(int(selection) - 1)]

    def get_identifier(
        self,
        qry_str: str,
        qry_funcs: Sequence[str],
        device_type: str = None,
        group: str = None,
        multi_ok: bool = False,
        completion: bool = False,
    ) -> CentralObject:
        """Get Identifier when iden type could be one of multiple types.  i.e. device or group

        Args:
            qry_str (str): The query string provided by user.
            qry_funcs (Sequence[str]): Sequence of strings "dev", "group", "site", "template"
            device_type (str, optional): str indicating which device types are valid for device identifiers.
                Defaults to None.
            group (str, optional): applies to get_template_identifier, Only match if template is in this group.
                Defaults to None.
            multi_ok (bool, optional): DEPRECATED, NO LONGER USED
            completion (bool, optional): If function is being called for AutoCompletion purposes. Defaults to False.
                When called for completion it will fail silently and will return multiple when multiple matches are found.

        Raises:
            typer.Exit: If not ran for completion, and there is no match, exit with code 1.

        Returns:
            CentralObject
        """
        # TODO remove multi_ok once verified refs are removed
        match = None
        default_kwargs = {"retry": False, "completion": completion}
        for _ in range(0, 2):
            for q in qry_funcs:
                kwargs = default_kwargs.copy()
                if q == "dev":
                    kwargs["dev_type"] = device_type
                elif q == "template":
                    kwargs["group"] = group
                match: CentralObject = getattr(self,
                                               f"get_{q}_identifier")(qry_str,
                                                                      **kwargs)

                if match and not completion:
                    return match

            # No match found trigger refresh and try again.
            if not match and not completion:
                self.check_fresh(
                    dev_db=True if "dev" in qry_funcs else False,
                    site_db=True if "site" in qry_funcs else False,
                    template_db=True if "template" in qry_funcs else False,
                    group_db=True if "group" in qry_funcs else False,
                )

        if completion:
            return match

        if not match:
            typer.secho(
                f"Unable to find a matching identifier for {qry_str}, tried: {qry_funcs}",
                fg="red")
            raise typer.Exit(1)
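
    # Illustrative usage (not in the original source), given a populated Cache instance `cache`:
    #   obj = cache.get_identifier("lobby-ap", ("dev", "site"))  # tries device lookup, then site,
    #                                                            # refreshing the cache on a miss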

    def get_dev_identifier(
        self,
        query_str: Union[str, List[str], tuple],
        dev_type: str = None,
        ret_field: str = "serial",
        retry: bool = True,
        multi_ok: bool = True,
        completion: bool = False,
    ) -> CentralObject:

        retry = False if completion else retry
        # TODO: dev_type is currently not passed in or handled; an identifier lookup for
        # 'show switches' would also try to match APs ...  & (self.Q.type == dev_type)
        # TODO: refactor to a single test function usable by all identifier methods -- one search with a more involved test
        if isinstance(query_str, (list, tuple)):
            query_str = " ".join(query_str)

        match = None
        for _ in range(0, 2 if retry else 1):
            # Try exact match
            match = self.DevDB.search(
                (self.Q.name == query_str)
                | (self.Q.ip.test(lambda v: v.split("/")[0] == query_str))
                | (self.Q.mac == utils.Mac(query_str).cols)
                | (self.Q.serial == query_str))

            # retry with case insensitive name match if no match with original query
            if not match:
                match = self.DevDB.search(
                    (self.Q.name.test(lambda v: v.lower() == query_str.lower())
                     )
                    | self.Q.mac.test(lambda v: v.lower() == utils.Mac(
                        query_str).cols.lower())
                    | self.Q.serial.test(
                        lambda v: v.lower() == query_str.lower()))

            # retry name match swapping - for _ and _ for -
            if not match:
                if "-" in query_str:
                    match = self.DevDB.search(
                        self.Q.name.test(lambda v: v.lower() == query_str.
                                         lower().replace("-", "_")))
                elif "_" in query_str:
                    match = self.DevDB.search(
                        self.Q.name.test(lambda v: v.lower() == query_str.
                                         lower().replace("_", "-")))

            # Last Chance try to match name if it startswith provided value
            if not match:
                match = self.DevDB.search(
                    self.Q.name.test(
                        lambda v: v.lower().startswith(query_str.lower()))
                    | self.Q.serial.test(
                        lambda v: v.lower().startswith(query_str.lower())))
                if not match:
                    qry_mac = utils.Mac(query_str)
                    qry_mac_fuzzy = utils.Mac(query_str, fuzzy=True)
                    if qry_mac or len(qry_mac) == len(qry_mac_fuzzy):
                        match = self.DevDB.search(
                            self.Q.mac.test(lambda v: v.lower().startswith(
                                utils.Mac(query_str, fuzzy=completion).cols.
                                lower())))
            if retry and not match and self.central.get_all_devicesv2 not in self.updated:
                typer.secho(
                    f"No Match Found for {query_str}, Updating Device Cache",
                    fg="red")
                self.check_fresh(refresh=True, dev_db=True)
            if match:
                match = [CentralObject("dev", dev) for dev in match]

        all_match = None
        if dev_type:
            all_match = match
            match = [
                d for d in match if d.generic_type.lower() in "".join(
                    dev_type[0:len(d.generic_type)]).lower()
            ]

        if match:
            if completion:
                return match

            elif len(match) > 1:
                match = self.handle_multi_match(match,
                                                query_str=query_str,
                                                multi_ok=multi_ok)

            return match[0]
        elif retry:
            log.error(
                f"Unable to gather device {ret_field} from provided identifier {query_str}",
                show=True)
            if all_match:
                all_match = all_match[-1]
                log.error(
                    f"The Following device matched {all_match.name} excluded as {all_match.type} != {dev_type}",
                    show=True,
                )
            raise typer.Exit(1)
        # else:
        #     log.error(f"Unable to gather device {ret_field} from provided identifier {query_str}", show=True)

    def get_site_identifier(
        self,
        query_str: Union[str, List[str], tuple],
        ret_field: str = "id",
        retry: bool = True,
        multi_ok: bool = False,
        completion: bool = False,
    ) -> CentralObject:
        retry = False if completion else retry
        if isinstance(query_str, (list, tuple)):
            query_str = " ".join(query_str)

        match = None
        for _ in range(0, 2 if retry else 1):
            # try exact site match
            match = self.SiteDB.search(
                (self.Q.name == query_str)
                | (self.Q.id.test(lambda v: str(v) == query_str))
                | (self.Q.zipcode == query_str)
                | (self.Q.address == query_str)
                | (self.Q.city == query_str)
                | (self.Q.state == query_str))

            # retry with case insensitive name & address match if no match with original query
            if not match:
                match = self.SiteDB.search(
                    (self.Q.name.test(lambda v: v.lower() == query_str.lower())
                     )
                    | self.Q.address.test(lambda v: v.lower().replace(
                        " ", "") == query_str.lower().replace(" ", "")))

            # swap _ and - and case insensitive
            if not match:
                if "-" in query_str:
                    match = self.SiteDB.search(
                        self.Q.name.test(lambda v: v.lower() == query_str.
                                         lower().replace("-", "_")))
                elif "_" in query_str:
                    match = self.SiteDB.search(
                        self.Q.name.test(lambda v: v.lower() == query_str.
                                         lower().replace("_", "-")))

            # Last Chance try to match name if it startswith provided value
            if not match:
                match = self.SiteDB.search(
                    self.Q.name.test(
                        lambda v: v.lower().startswith(query_str.lower())))

            if retry and not match and self.central.get_all_sites not in self.updated:
                typer.secho(
                    f"No Match Found for {query_str}, Updating Site Cache",
                    fg="red")
                self.check_fresh(refresh=True, site_db=True)
            if match:
                match = [CentralObject("site", s) for s in match]
                break

        if match:
            if completion:
                return match

            if len(match) > 1:
                match = self.handle_multi_match(match,
                                                query_str=query_str,
                                                query_type="site",
                                                multi_ok=multi_ok)

            return match[0]

        elif retry:
            log.error(
                f"Unable to gather site {ret_field} from provided identifier {query_str}",
                show=True)
            raise typer.Exit(1)

    def get_group_identifier(
        self,
        query_str: str,
        ret_field: str = "name",
        retry: bool = True,
        multi_ok: bool = False,
        completion: bool = False,
    ) -> CentralObject:
        """Allows Case insensitive group match"""
        retry = False if completion else retry
        for _ in range(0, 2):
            # Exact match
            match = self.GroupDB.search((self.Q.name == query_str))

            # case insensitive
            if not match:
                match = self.GroupDB.search(
                    self.Q.name.test(lambda v: v.lower() == query_str.lower()))

            # case insensitive ignore -_
            if not match:
                if "_" in query_str or "-" in query_str:
                    match = self.GroupDB.search(
                        self.Q.name.test(lambda v: v.lower().strip("-_") ==
                                         query_str.lower().strip("_-")))
                #     match = self.GroupDB.search(
                #         self.Q.name.test(
                #             lambda v: v.lower() == query_str.lower().replace("_", "-")
                #         )
                #     )
                # elif "-" in query_str:
                #     match = self.GroupDB.search(
                #         self.Q.name.test(
                #             lambda v: v.lower() == query_str.lower().replace("-", "_")
                #         )
                #     )

            # case insensitive startswith
            if not match:
                match = self.GroupDB.search(
                    self.Q.name.test(
                        lambda v: v.lower().startswith(query_str.lower())))

            # case insensitive startswith ignore - _
            if not match:
                match = self.GroupDB.search(
                    self.Q.name.test(lambda v: v.lower().strip(
                        "-_").startswith(query_str.lower().strip("-_"))))

            if not match and retry and self.central.get_all_groups not in self.updated:
                typer.secho(
                    f"No Match Found for {query_str}, Updating group Cache",
                    fg="red")
                self.check_fresh(refresh=True, group_db=True)
            if match:
                match = [CentralObject("group", g) for g in match]
                break

        if match:
            if completion:
                return match

            if len(match) > 1:
                match = self.handle_multi_match(match,
                                                query_str=query_str,
                                                query_type="group",
                                                multi_ok=multi_ok)

            return match[0]

        elif retry:
            log.error(
                f"Central API CLI Cache unable to gather group data from provided identifier {query_str}",
                show=True)
            valid_groups = "\n".join(self.group_names)
            typer.secho(f"{query_str} appears to be invalid", fg="red")
            typer.secho(f"Valid Groups:\n--\n{valid_groups}\n--\n", fg="cyan")
            raise typer.Exit(1)
        else:
            if not completion:
                log.error(
                    f"Central API CLI Cache unable to gather group data from provided identifier {query_str}",
                    show=True)

    def get_template_identifier(
        self,
        query_str: str,
        ret_field: str = "name",
        group: str = None,
        retry: bool = True,
        multi_ok: bool = False,
        completion: bool = False,
    ) -> CentralObject:
        """Allows case insensitive template match by template name"""
        retry = False if completion else retry
        match = None
        for _ in range(0, 2 if retry else 1):
            # exact
            match = self.TemplateDB.search((self.Q.name == query_str))

            # case insensitive
            if not match:
                match = self.TemplateDB.search(
                    self.Q.name.test(lambda v: v.lower() == query_str.lower()))

            # case insensitive with -/_ swap
            if not match:
                if "_" in query_str:
                    match = self.TemplateDB.search(
                        self.Q.name.test(lambda v: v.lower() == query_str.
                                         lower().replace("_", "-")))
                elif "-" in query_str:
                    match = self.TemplateDB.search(
                        self.Q.name.test(lambda v: v.lower() == query_str.
                                         lower().replace("-", "_")))

            # startswith
            if not match:
                match = self.TemplateDB.search(
                    self.Q.name.test(
                        lambda v: v.lower().startswith(query_str.lower())))

            if retry and not match and self.central.get_all_templates not in self.updated:
                typer.secho(
                    f"No Match Found for {query_str}, Updating template Cache",
                    fg="red")
                self.check_fresh(refresh=True, template_db=True)
            if match:
                match = [CentralObject("template", tmplt) for tmplt in match]
                break

        if match:
            if completion:
                return match

            if len(match) > 1:
                if group:
                    match = [
                        d for d in match if d.group.lower() == group.lower()
                    ]

            if len(match) > 1:
                match = self.handle_multi_match(
                    match,
                    query_str=query_str,
                    query_type="template",
                    multi_ok=multi_ok,
                )

            return match[0]

        elif retry:
            log.error(
                f"Unable to gather template {ret_field} from provided identifier {query_str}",
                show=True)
            raise typer.Exit(1)
        else:
            if not completion:
                log.warning(
                    f"Unable to gather template {ret_field} from provided identifier {query_str}",
                    show=False)

    def get_log_identifier(self, query: str) -> str:
        if "audit_trail" in query:
            return query
        elif query == "":  # tab completion
            return [x["id"] for x in self.logs]

        try:
            match = self.LogDB.search(self.Q.id == int(query))
            if not match:
                log.warning(
                    f"Unable to gather log id from short index query {query}",
                    show=True)
                typer.echo(
                    "Short log_id aliases are built each time 'show logs' is ran."
                )
                typer.echo(
                    "  You can verify the cache by running (hidden command) 'show cache logs'"
                )
                typer.echo(
                    "  run 'show logs [OPTIONS]' then use the short index for details"
                )
                raise typer.Exit(1)
            else:
                return match[-1]["long_id"]

        except ValueError as e:
            log.exception(
                f"Exception in get_log_identifier {e.__class__.__name__}\n{e}")
            typer.secho(
                f"Exception in get_log_identifier {e.__class__.__name__}",
                fg="red")
            raise typer.Exit(1)
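
# A standalone, runnable sketch (not from the original source) of the TinyDB patterns the Cache
# class above relies on: a full refresh via truncate() + insert_multiple(), and case-insensitive
# lookups built with Query().<field>.test(...).
from tinydb import TinyDB, Query
from tinydb.storages import MemoryStorage

demo_db = TinyDB(storage=MemoryStorage)
demo_devices = demo_db.table("devices")
demo_q = Query()

# refresh: drop stale rows, then bulk-insert the latest data
demo_devices.truncate()
demo_devices.insert_multiple([
    {"name": "AP-Lobby", "serial": "CN111", "mac": "aa:bb:cc:dd:ee:01"},
    {"name": "SW-Core", "serial": "CN222", "mac": "aa:bb:cc:dd:ee:02"},
])

# exact match first, then fall back to a case-insensitive startswith test
query_str = "ap-lo"
hits = demo_devices.search(demo_q.name == query_str)
if not hits:
    hits = demo_devices.search(demo_q.name.test(lambda v: v.lower().startswith(query_str.lower())))
print([d["name"] for d in hits])  # -> ['AP-Lobby']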
Example #29
0
class Tracker(object):
    def __init__(self, path=None):
        """Constructor for class Tracker"""
        self.path = path if path is not None else os.environ['TRACKER_PATH']
        self.db = TinyDB(self.path)
        self.entry = Query()

    def save(self, entry):
        self.db.insert(entry.properties)
        self.deduplicate()
        self.label_recent()

    def filter(self, cond):
        """tr.filter(tr.entry.tag == "tag")"""
        return (self.db.search(cond))

    def uniq(self, property):
        return (set(self[property]))

    def label_recent(self):
        tags = self.uniq('tag')
        for t in tags:
            entries = self.filter(self.entry.tag == t)
            versions = [e['version'] for e in entries]
            most_recent = sort_versions(versions)[-1]
            self.db.update({'most_recent': True},
                           ((self.entry.tag == t) &
                            (self.entry.version == most_recent)))
            self.db.update({'most_recent': False},
                           ((self.entry.tag == t) &
                            (self.entry.version != most_recent)))

    def remove(self, cond):
        """tr.remove(tr.entry.tag == "tag")"""
        logger.info(f"Conditional for removing entries: {cond}")
        self.db.remove(cond)

    def deduplicate(self):
        vtags = [
            key for key, value in Counter(self['tag_version']).items()
            if value >= 2
        ]
        for tv in vtags:
            entries = self.filter(self.entry.tag_version == tv)
            times = [e['time'] for e in entries]
            most_recent = sorted(times)[-1]
            self.db.remove(((self.entry.tag_version == tv) &
                            (self.entry.time != most_recent)))

    def copy(self, path):
        copyfile(self.path, path)
        return (Tracker(path))

    def get_entry(self, entry_tag, version=None):
        if version is not None:
            entry = self.filter((self.entry.tag == entry_tag)
                                & (self.entry.version == version))[0]
        else:
            entry = self.filter((self.entry.tag == entry_tag)
                                & (self.entry.most_recent == True))[0]
        return (entry)

    def get_file(self, entry_tag, file_tag, version=None):
        """Get File dictionary given entry tag and file tag

        Parameters
        ----------
        entry_tag : str
            Entry tag in database
        file_tag : str
            File tag identifying specific file
        version : str, optional
            Version string if not going by most recent, by default None

        See Also
        --------
        get_entry, get_file_path

        Examples
        --------
        >>>tr.get_file("entry_tag", "file_tag")
        """
        entry = self.get_entry(entry_tag, version)['output_files']
        dli = DictList(entry)
        return (dli.filter_first(cond=lambda x: x['tag'] == file_tag))

    def get_file_path(self, *args, **kwargs):
        return (self.get_file(*args, **kwargs)['path'])

    def get_output_files(self, entry_tag, version=None):
        """Get list of output files from an entry.

        Parameters
        ----------
        entry_tag : str
            Entry tag in database
        version : str, optional
            Version string if not going by most recent, by default None

        See Also
        --------
        get_entry

        Examples
        --------
        >>>tr.get_output_files("entry_tag")
        """
        return (self.get_entry(entry_tag, version)['output_files'])

    def update(self):
        pass

    def to_pandas(self, tag=None, module=None, most_recent=True):
        df = pd.DataFrame.from_dict([row for row in self.db])
        if df.empty:
            return (df)
        df = df.sort_values('time', ascending=False)
        if tag:
            df = df[df.tag == tag]
        if module:
            df = df[df.module == module]
        if most_recent:
            df = df[df.most_recent]
        return (df)

    def explode(self, *args, **kwargs):
        df = self.to_pandas(*args, **kwargs)

        df = df[[
            'tag', 'category', 'module', 'description', 'version',
            'input_files', 'output_files', 'most_recent', 'time'
        ]]

        df0 = df.drop('output_files', axis=1).rename({'input_files': 'files'},
                                                     axis=1)
        df1 = df.drop('input_files', axis=1).rename({'output_files': 'files'},
                                                    axis=1)

        df0 = self._explode_files(df0)
        df1 = self._explode_files(df1)

        df = pd.concat([df0.assign(type='input'),
                        df1.assign(type='output')],
                       axis=0)
        df['basename'] = df['path'].apply(os.path.basename)
        df = df[[
            'tag', 'category', 'module', 'file_tag', 'description', 'type',
            'file_desc', 'basename', 'path', 'most_recent', 'index', 'time'
        ]]
        df = df.sort_values(
            ['time', 'category', 'module', 'tag', 'type', 'index'],
            ascending=[False, True, True, True, True, True])
        return (df)

    def _explode_files(self, df):
        df['files'] = df['files'].map(
            lambda l: [dict(x, **{'index': i}) for i, x in enumerate(l)])
        df = df.explode('files')
        df = df[df.files.notnull()]
        df['file_tag'] = df['files'].map(lambda x: x['tag'])
        df['path'] = df['files'].map(lambda x: x['path'])
        df['file_desc'] = df['files'].map(lambda x: x['description'])
        df['index'] = df['files'].map(lambda x: x['index'])
        df = df.drop('files', axis=1)
        return (df)

    def to_excel(self, path, **kwargs):
        df = self.to_pandas(**kwargs)
        df.to_excel(path)

    def kill(self):
        self.db.truncate()

    def __len__(self):
        return (len(self.db))

    def __getitem__(self, property):
        return ([row[property] for row in self.db])

    @property
    def entries(self):
        return (self.db.all())

    @property
    def table(self):
        return (self.to_pandas(tag=None, most_recent=False))

    @property
    def summary(self):
        df = self.table[[
            'category', 'module', 'tag', 'description', 'version', 'date',
            'time', 'most_recent'
        ]]
        return (df.sort_values('time', ascending=False))
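
# Hedged usage sketch (not part of the original source). It assumes an entry object exposing a
# .properties dict with the fields Tracker relies on ('tag', 'version', 'tag_version', 'time',
# 'output_files'), plus the module's sort_versions/DictList helpers (not shown here):
#   entry = SomeEntry(properties={
#       "tag": "align", "version": "1.0.0", "tag_version": "align_1.0.0",
#       "time": "2021-06-01T12:00:00",
#       "output_files": [{"tag": "bam", "path": "/tmp/out.bam", "description": "aligned reads"}],
#   })
#   tr = Tracker(path="/tmp/tracker.json")     # hypothetical path
#   tr.save(entry)                             # inserts, deduplicates, labels most recent
#   tr.get_file_path("align", "bam")           # -> '/tmp/out.bam'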
Example #30
0
def deleteAll():
    if request.method == 'GET':
        db = TinyDB('db.json')
        print(db)
        db.truncate()
        return json.dumps(db.all())
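
# Illustrative note (not in the original source): deleteAll() reads like a Flask view function.
# It assumes a surrounding `@app.route(..., methods=['GET'])` registration and
# `from flask import request`; it truncates the default table of db.json and returns the
# (now empty) contents as JSON.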