Example #1
    def get_data(self):
        # main data
        categories = sorted(glob('/media/ssd/data/train/*'))
        cat_names = [x.split('/')[-1] for x in categories]
        cat_index = {k: i for i, k in enumerate(cat_names)}

        acc = defaultdict(list)

        # voting based pseudo labels
        df = pd.read_csv(VOTING)
        banned = self.exclude_bad_predictions()
        df['is_banned'] = df['fname'].apply(lambda x: x in banned)
        df = df[~df['is_banned']]
        df = df[df.votes >= 5].sort_values('best_camera').reset_index()[[
            'fname', 'best_camera'
        ]]

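        # distribute the pseudo-labelled samples across the 5 folds round-robin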
        for i, row in df.iterrows():
            fold = i % 5
            k = row['fname']
            v = row['best_camera']
            f = join('data/test/', k)
            y_idx = cat_index[v]
            acc[fold].append((f, y_idx))
        logger.info(f'{len(df)} samples come from the pseudo labels dataset')

        return acc, cat_names, cat_index
Example #2
def clean():
    import src.utils.path as pth
    import src.utils.logger as log
    import src.parser.toml as tml

    from shutil import rmtree

    log.info("Cleaning project")

    pth.__remove_file(tml.value('json', section='data', subkey='fname'))
    pth.__remove_file(tml.value('numpy', section='data', subkey='fname'))

    dnames = tml.value('dnames', section='demo')
    if pth.__exists(dnames['input']):
        rmtree(dnames['input'])
    if pth.__exists(dnames['output']):
        rmtree(dnames['output'])
    pth.__remove_file(tml.value('fx_name', section='demo'))

    dnames = tml.value('dnames', section='neuralnet')
    if pth.__exists(dnames['predicted_labels']):
        rmtree(dnames['predicted_labels'])
    if pth.__exists(dnames['expected_labels']):
        rmtree(dnames['expected_labels'])
    if pth.__exists(dnames['original_data']):
        rmtree(dnames['original_data'])
Example #3
def load_checkpoint(checkpoint_fpath, model, optimizer):
    checkpoint = torch.load(checkpoint_fpath)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    logger.info(
        "Checkpoint loaded successfully from {}".format(checkpoint_fpath))
    return model, optimizer, checkpoint['epoch']
Example #4
 def add_session(self, session):
     """
     Adds a session to the session list.
     """
     self._sessions.insert(0, session)
     logger.info('Sessions active: %d' %
                 len(self.get_sessions(return_unlogged=True)))
Example #5
def add(vpn_opts: ClientOpts, server_opts: ServerOpts, auth_opts: AuthOpts,
        account: str, is_default: bool, dns_prefix: str, no_connect: bool):
    is_connect = not no_connect
    executor = VPNClientExecutor(vpn_opts).require_install().probe()
    hostname = dns_prefix or executor.generate_host_name(
        server_opts.hub, auth_opts.user, log_lvl=logger.TRACE)
    acc = AccountInfo(server_opts.hub, account, hostname, is_default)
    logger.info(f'Setup VPN Client with VPN account [{acc.account}]...')
    executor.tweak_network_per_account(acc.account, hostname)
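    # ordered vpncmd sub-commands: create the account, add the auth settings, then connect if requested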
    setup_cmd = {
        'AccountCreate':
        f'{acc.account} /SERVER:{server_opts.server} /HUB:{acc.hub} /USERNAME:{auth_opts.user} /NICNAME:{acc.account}'
    }
    setup_cmd = {**setup_cmd, **auth_opts.setup(acc.account)}
    setup_cmd = setup_cmd if not is_connect else {
        **setup_cmd,
        **{
            'AccountConnect': acc.account
        }
    }
    if acc.is_default or is_connect:
        executor.do_disconnect_current(log_lvl=logger.DEBUG)
    executor.exec_command(['NicCreate', 'AccountDisconnect', 'AccountDelete'],
                          acc.account,
                          silent=True)
    executor.exec_command(setup_cmd)
    executor.storage.create_or_update(acc, _connect=is_connect)
    if acc.is_default:
        executor.do_switch_default_acc(acc.account)
    executor.lease_vpn_service(account=acc.account,
                               is_enable=acc.is_default,
                               is_restart=acc.is_default and is_connect,
                               is_lease_ip=not acc.is_default and is_connect)
    logger.done()
Example #6
def key_data_to_xlsx(df, xlsx_filename):
    # Save dataframe to xlsx, with formatting for readability

    n_rows = df.shape[0]
    writer = pd.ExcelWriter(xlsx_filename, engine='xlsxwriter')
    pandas.io.formats.excel.header_style = None
    df.to_excel(writer, sheet_name='debug', index_label='html_file_hyperlink')
    wrap_format = writer.book.add_format({
        'text_wrap': True,
        'align': 'left',
        'valign': 'top'
    })
    headers_format = writer.book.add_format({'text_wrap': True, 'bold': True})
    hyperlink_format = writer.book.add_format({
        'text_wrap': True,
        'align': 'left',
        'valign': 'top',
        'font_color': 'Blue',
        'underline': True
    })
    debug_sheet = writer.sheets['debug']
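    # freeze the header row and the hyperlink column so they stay visible while scrolling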
    debug_sheet.freeze_panes(1, 1)
    debug_sheet.set_column('A:A', 20, hyperlink_format)
    debug_sheet.set_column('B:F', 60, wrap_format)
    debug_sheet.set_row(0, None, headers_format)
    for ii in range(1, n_rows + 2):
        debug_sheet.set_row(ii, 300)
    # fix for LibreOffice calc not showing hyperlinks properly: https://stackoverflow.com/questions/32205927/xlsxwriter-and-libreoffice-not-showing-formulas-result
    writer.save()
    logger.info('Finished summary output to XLSX: %s' % writer.path)
Example #7
def __import(server_opts: ServerOpts, hub_password: str, vpn_opts: ToolOpts,
             group: str, certs_file: str, output_opts: OutputOpts):
    executor = VPNAuthExecutor(vpn_opts, server_opts, hub_password)
    data = JsonHelper.read(certs_file, strict=False)
    tmp_dir = FileHelper.tmp_dir('vpn_auth')
    command_file = FileHelper.touch(tmp_dir.joinpath('vpncmd.txt'))
    vpn_acc = {}
    for k, v in data.items():
        cert_file = tmp_dir.joinpath(f'{k}.cert')
        FileHelper.write_file(cert_file, v['cert_key'])
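        # vpncmd batch for this user: register the CA certificate, create the user, bind the signed cert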
        commands = [
            f'CAAdd /{cert_file}',
            f'UserCreate {k} /GROUP:{group or "none"} /RealName:none /Note:none',
            f'UserSignedSet {k} /CN:{v["fqdn"]} /SERIAL:{v["serial_number"]}'
        ]
        vpn_acc[k] = {
            'vpn_server': server_opts.host,
            'vpn_port': server_opts.port,
            'vpn_hub': server_opts.hub,
            'vpn_account': server_opts.hub,
            'vpn_auth_type': 'cert',
            'vpn_user': k,
            'vpn_cert_key': v['cert_key'],
            'vpn_private_key': v['private_key'],
        }
        FileHelper.write_file(command_file,
                              '\n'.join(commands) + '\n',
                              append=True)
    executor.exec_command(f'/IN:{command_file}', log_lvl=logger.INFO)
    logger.sep(logger.INFO)
    out = output_opts.make_file(
        f'{server_opts.hub}-{output_opts.to_file("json")}')
    logger.info(f'Export VPN accounts to {out}...')
    JsonHelper.dump(out, vpn_acc)
    logger.done()
Example #8
def demo():
    import src.utils.logger as log
    import src.utils.path as pth
    import src.parser.toml as tml
    from src.utils.tools import download, extract

    # Downloading data from URLs and extracting downloaded files

    dry_url = tml.value('urls', section='demo', subkey='dry')
    fx_url = tml.value('urls', section='demo', subkey='fx')

    dry_dpath = tml.value('dnames', section='demo', subkey='input')
    fx_fname = tml.value('fx_name', section='demo')

    log.info("Downloading and extracting dataset and fx")

    fx_fpath = download(fx_url)
    pth.__rename_file(fx_fpath, fx_fname)

    if not pth.__exists(dry_dpath):
        dry_fpath = download(dry_url)
        extract(dry_fpath, dry_dpath)
    else:
        log.warning("\"{0}\" already exists, skipping dataset download".format(dry_dpath))

    run(dry_dpath, fx_fname, tml.value('dnames', section='demo', subkey='output'))
Example #9
File: main.py Project: TheWall9/DRHGCN
def test_fn(model, val_loader, save_file_format=None):
    device = model.device
    state = model.training
    model.eval()
    scores, labels, edges = [], [], []
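    # score every validation batch, keeping the outputs on the CPU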
    for batch in val_loader:
        batch = move_data_to_device(batch, device)
        output = model.step(batch)
        label, score = output["label"], output["predict"]
        edge = batch.interaction_pair[:, batch.valid_mask.reshape(-1)]
        scores.append(score.detach().cpu())
        labels.append(label.cpu())
        edges.append(edge.cpu())
    model.train(state)
    scores = torch.cat(scores).numpy()
    labels = torch.cat(labels).numpy()
    edges = torch.cat(edges, dim=1).numpy()
    eval_start_time_stamp = time.time()
    metric = metric_fn.evaluate(predict=scores, label=labels)
    eval_end_time_stamp = time.time()
    logger.info(f"eval time cost: {eval_end_time_stamp - eval_start_time_stamp}")
    if save_file_format is not None:
        save_file = save_file_format.format(aupr=metric["aupr"], auroc=metric["auroc"])
        scio.savemat(save_file, {"row": edges[0],
                                 "col": edges[1],
                                 "score": scores,
                                 "label": labels})
        logger.info(f"save time cost: {time.time() - eval_end_time_stamp}")
    return scores, labels, edges, metric
Example #10
    def load_dataset(self, dataloader: DataLoader) -> None:
        logger.info("loading dataset")
        # get shape of one batch data
        batch_shape = next(iter(dataloader))[0].shape
        self._init_mask(batch_shape)

        self._dataloader = dataloader
Example #11
    def process(self, png_ready=True):
        coords, batch = [], []

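        # slide a window over the image, consuming each full batch of patches as it accumulates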
        while self.i <= self.w and self.j <= self.h:
            if len(batch) < self.batch_size:
                patch, coord = self.get_patch()
                batch.append(patch)
                coords.append(coord)

                self.i += self.overlay
                if self.i > self.w:
                    self.i = 0
                    self.j += self.overlay

            else:
                coords, batch = self.consume_batch(coords, batch)

        self.consume_batch(coords, batch)
        if self.verbose:
            logger.info(f"The image was combined from {self.counter} patches")

        if png_ready:
            img = (self.pred_mask / self.norm_mask) * 255.
            img = img.astype('uint8')
            return img
        return self.pred_mask / self.norm_mask
Example #12
def angle_and_depth(gazenet,
                    color_image_path: str,
                    depth_image_path: str,
                    is_product: bool = False):
    """
    Using frames at which the subject is looking straight at the shelf, get:
    - gaze vector
    - distance from shelf

    Arguments:
        gazenet {Gazenet} -- Face angle model
        config    {dict}  -- dict of paths to color and depth images

    Returns:
        [tuple[float], float] -- Tuple of yaw, pitch and roll and float of median depth
    """
    np_img = cv2.imread(color_image_path)
    face_bbox = face_detect(np_img)
    min_x, min_y, max_x, max_y = face_bbox

    if is_product:
        logger.info("Product bounding box: " + str(face_bbox))
    else:
        logger.info("Shelf bounding box: " + str(face_bbox))

    yaw, pitch, roll = gazenet.image_to_euler_angles(np_img, face_bbox)

    # get median depth
    depth_df = pd.read_csv(depth_image_path, index_col=0)  # DataFrame.from_csv was removed from pandas
    median_depth = np.median(depth_df.iloc[min_y:max_y, min_x:max_x].values)

    return (yaw, pitch, roll), median_depth, (min_x, min_y, max_x, max_y)
Example #13
def fit_once(model,
             model_name,
             loss,
             train,
             val,
             stage,
             n_fold,
             start_epoch,
             initial=False):
    logger.info(f'Stage {stage} started: loss {loss}, fold {n_fold}')
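    # fixed-length epochs: 500 training steps and 100 validation steps each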
    steps_per_epoch = 500
    validation_steps = 100

    model.compile(optimizer=SGD(lr=0.01 if initial else 0.001,
                                clipvalue=4,
                                momentum=.9,
                                nesterov=True),
                  loss=loss,
                  metrics=['accuracy'])
    history = model.fit_generator(
        train,
        epochs=500,
        steps_per_epoch=steps_per_epoch,
        validation_data=val,
        workers=8,
        max_queue_size=32,
        use_multiprocessing=False,
        validation_steps=validation_steps,
        callbacks=get_callbacks(model_name, loss, stage, n_fold),
        initial_epoch=start_epoch,
    )
    return model, max(history.epoch)
Example #14
    def check_db():
        """Check if database exists"""
        # get all databases
        all_dbs_list = InfluxService.db_client.get_list_database()

        # check if current database exists and if return warning message
        if InfluxService.cnf.INFLUX_DB not in [
                str(x['name']) for x in all_dbs_list
        ]:
            try:
                app_logger.warning("Database {0} does not exist".format(
                    InfluxService.cnf.INFLUX_DB))
            except exceptions.InfluxDBClientError as e:
                app_logger.error(str(e))
            except exceptions.InfluxDBServerError as e1:
                app_logger.error(str(e1))
        else:
            try:
                app_logger.info("Using db {0}".format(
                    InfluxService.cnf.INFLUX_DB))
                InfluxService.db_client.switch_database(
                    InfluxService.cnf.INFLUX_DB)
            except exceptions.InfluxDBClientError as e:
                app_logger.error(str(e))
            except exceptions.InfluxDBServerError as e1:
                app_logger.error(str(e1))
Example #15
def install(vpn_opts: ClientOpts, svc_opts: UnixServiceOpts,
            auto_startup: bool, auto_dnsmasq: bool, dnsmasq: bool,
            auto_connman_dhcp: bool, force: bool):
    executor = VPNClientExecutor(vpn_opts).probe(log_lvl=logger.INFO)
    dns_resolver = executor.device.dns_resolver
    if not dnsmasq and not dns_resolver.is_connman():
        logger.error('Only support dnsmasq as DNS resolver in first version')
        sys.exit(ErrorCode.NOT_YET_SUPPORTED)
    if executor.is_installed(silent=True):
        if force:
            logger.warn(
                'VPN service is already installed. Try to remove then reinstall...'
            )
            executor.do_uninstall(keep_vpn=False, keep_dnsmasq=True)
        else:
            logger.error('VPN service is already installed')
            sys.exit(ErrorCode.VPN_ALREADY_INSTALLED)
    if dnsmasq and not dns_resolver.is_dnsmasq_available(
    ) and not dns_resolver.is_connman():
        executor.device.install_dnsmasq(auto_dnsmasq)
    logger.info(
        f'Installing VPN client into [{vpn_opts.vpn_dir}] and register service[{svc_opts.service_name}]...'
    )
    executor.do_install(svc_opts, auto_startup, auto_connman_dhcp)
    logger.done()
Example #16
def __dns(vpn_opts: ClientOpts, nic: str, reason: str, new_nameservers: str,
          old_nameservers: str, debug: bool):
    logger.info(f'Discover DNS with {reason}::{nic}...')
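    # map the dhclient reason string onto the DHCPReason enum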
    _reason = DHCPReason[reason]
    if not vpn_opts.is_vpn_nic(nic):
        logger.warn(f'NIC[{nic}] does not belong to VPN service')
        sys.exit(0)
    executor = VPNClientExecutor(
        vpn_opts, adhoc_task=True).require_install().probe(silent=True,
                                                           log_lvl=logger.INFO)
    current = executor.storage.get_current(info=True)
    if not current:
        current = executor.storage.find(executor.opts.nic_to_account(nic))
        if not current:
            logger.warn('No VPN account found')
            sys.exit(ErrorCode.VPN_ACCOUNT_NOT_FOUND)
    if executor.opts.nic_to_account(nic) != current.account:
        logger.warn(f'NIC[{nic}] does not match the current VPN account')
        sys.exit(ErrorCode.VPN_ACCOUNT_NOT_MATCH)
    if debug:
        now = datetime.now().isoformat()
        FileHelper.write_file(
            FileHelper.tmp_dir().joinpath('vpn_dns'),
            append=True,
            content=
            f"{now}::{reason}::{nic}::{new_nameservers}::{old_nameservers}\n")
    executor.device.dns_resolver.resolve(executor.vpn_service, _reason,
                                         current.hub, new_nameservers,
                                         old_nameservers)
Example #17
def __last(file):
    # Log the last line without reading the whole file: seek to the end,
    # then scan backwards for the previous newline.
    with open(file, 'rb') as f:
        f.seek(-2, os.SEEK_END)
        while f.read(1) != b'\n':
            try:
                f.seek(-2, os.SEEK_CUR)
            except OSError:  # no earlier newline: single-line file
                f.seek(0)
                break
        last_line = f.readline().decode()
        logger.info(last_line.strip())
Example #18
def save_checkpoint(state, is_best, checkpoint_dir, best_model_dir):
    f_path = checkpoint_dir / 'skipgram_embeddings_checkpoint.pt'
    logger.info("Saving checkpoint to {}".format(f_path))
    torch.save(state, f_path)
    if is_best:
        best_fpath = best_model_dir / 'skipgram_embeddings_best_model.pt'
        logger.info("Saving checkpoint as best model")
        shutil.copyfile(f_path, best_fpath)
Example #20
def get_glue_config(cfg: CN, model_args, name: str):
    num_labels = glue_tasks_num_labels[name]
    logger.info(f"Num {name} Labels: \t {num_labels}")

    return (AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=name
    ),)
Example #21
File: telnet.py Project: gtaylor/dott
    def connectionMade(self):
        """
        What to do when we get a connection.
        """

        self.session = Session(self)
        logger.info('New connection: %s' % self)
        self._session_manager.add_session(self.session)
        self.session.after_session_connect_event()
Example #22
def out(data, overwrite=False, output='-'):
    if output == '-':
        for d in data:
            logger.info(d)
    else:
        with open(output, 'w+' if overwrite else 'a+') as f:
            for d in data:
                f.write(d + '\n')
        logger.success(f'Output: {output}')
Example #23
def speaker_data_to_xlsx(df, xlsx_filename):
    # save speaker stats to xlsx
    writer = pd.ExcelWriter(xlsx_filename, engine='xlsxwriter')
    pandas.io.formats.excel.header_style = None
    df.to_excel(writer, sheet_name='speaker_stats')
    debug_sheet = writer.sheets['speaker_stats']
    debug_sheet.freeze_panes(1, 1)
    debug_sheet.set_column('B:C', 55)
    writer.save()
    logger.info('Finished summary output to XLSX: %s' % writer.path)
Example #24
def delete(vpn_opts: ClientOpts, accounts):
    logger.info(
        f'Delete VPN account [{accounts}] and stop/disable VPN service if it\'s a current VPN connection...'
    )
    if accounts is None or len(accounts) == 0:
        logger.error('Must provide at least one account')
        sys.exit(ErrorCode.INVALID_ARGUMENT)
    VPNClientExecutor(vpn_opts).require_install().probe(
        log_lvl=logger.INFO).do_delete(accounts)
    logger.done()
Example #25
def add_trust_server(vpn_opts: ClientOpts, account: str, cert_key: str):
    logger.info("Enable Trust VPN Server on VPN client...")
    VPNClientExecutor(vpn_opts,
                      adhoc_task=True).require_install().probe().exec_command({
                          'AccountServerCertSet':
                          f'{account} /LOADCERT:{cert_key}',
                          'AccountServerCertEnable':
                          account
                      })
    logger.done()
Example #26
File: telnet.py Project: gtaylor/dott
    def connectionLost(self, reason):
        """
        Execute this when a client abruptly loses their connection.

        :param basestring reason: A short reason as to why they disconnected.
        """

        self.session.after_session_disconnect_event()
        logger.info('Disconnected: %s, %s' % (self, reason))
        self.disconnectClient()
Example #27
    def _get_same_img(self, id_, landmark_id):
        subset = self.data[(self.data['landmark_id'] == landmark_id) & (self.data['id'] != id_)]
        n = subset.shape[0]
        if not n:
            logger.info(f'There are no other images with landmark_id {landmark_id}')
            return self._get_other_img(id_, landmark_id)

        idx = np.random.randint(0, n)
        row = subset.iloc[idx]
        return row['id'], row['landmark_id'], 1
Example #29
    def remove_session(self, session):
        """
        Removes a session from the session list.
        """

        try:
            self._sessions.remove(session)
            logger.info('Sessions active: %d' % len(self.get_sessions()))
        except ValueError:
            # the session was already removed. Probably garbage collected.
            return
Example #31
 def restore_config(self, backup_dir: Path, keep_backup: bool):
     logger.info(
         f'Restore VPN configuration [{backup_dir}] to [{self.opts.vpn_dir}]...'
     )
     FileHelper.copy(backup_dir.joinpath(self.opts.VPN_CONFIG_FILE),
                     self.opts.config_file,
                     force=True)
     FileHelper.copy(backup_dir.joinpath(self.opts.RUNTIME_FOLDER),
                     self.opts.runtime_dir,
                     force=True)
     FileHelper.rm(backup_dir, force=not keep_backup)
Example #32
 def reset_hook(self, vpn_nameserver_hook_conf: Path):
     logger.info(f'Reset VPN DNS config file...')
     if FileHelper.is_writable(vpn_nameserver_hook_conf):
         FileHelper.write_file(vpn_nameserver_hook_conf,
                               mode=0o644,
                               content='')
         FileHelper.create_symlink(vpn_nameserver_hook_conf,
                                   self._dnsmasq_vpn_hook_cfg,
                                   force=True)
     else:
         FileHelper.rm(self._dnsmasq_vpn_hook_cfg)
Example #33
 def remove(self, svc_opts: UnixServiceOpts, force: bool = False):
     service_fqn = self.to_service_fqn(svc_opts.service_dir,
                                       svc_opts.service_name)
     self.stop(svc_opts.service_name)
     self.disable(svc_opts.service_name)
     if force and FileHelper.is_exists(service_fqn):
         logger.info(f'Remove System service [{svc_opts.service_name}]...')
         FileHelper.rm(service_fqn)
     SystemHelper.exec_command("systemctl daemon-reload",
                               silent=True,
                               log_lvl=logger.INFO)
Example #34
File: session.py Project: gtaylor/dott
    def login(self, account):
        """
        After the user has authenticated, this actually logs them in. Attaches
        the Session to the account's default PlayerObject instance.
        """

        # set the session properties
        self.account = account
        self.conn_time = datetime.datetime.now()  # datetime.time() here would always record midnight
        self.interactive_shell = None

        logger.info("Logged in: %s" % self.account.username)

        controlled_id = self.account.currently_controlling_id
        object_sessions = self._session_manager.get_sessions_for_object_id(
            controlled_id)

        if not self._mud_service.is_connected_to_mud_server():
            # Proxy is not connected to MUD server, we can't go any further
            # with events done at time of connection.
            return

        # This command runs on the MUD server letting it know that the player
        # is logging in. If no PlayerObject exists for this account yet,
        # one will be created. The ID of the object that the player controls
        # will always be returned, whether new or old.
        results = yield self._mud_service.proxyamp.callRemote(
            OnSessionConnectToObjectCmd,
            account_id=self.account.id,
            # A -1 value means a new object will be created for the player
            # to control. Most likely their first time logging in.
            controlling_id=self.account.currently_controlling_id or -1,
            username=self.account.username,
        )
        # The ID of the object that the account controls in-game.
        controlled_id = results['object_id']
        if self.account.currently_controlling_id != controlled_id:
            self.account.currently_controlling_id = controlled_id
            yield self.account.save()

        if len(object_sessions) == 1:
            # This is the only Session controlling the object it is associated
            # with. Trigger the 'at connect' event on the object.
            yield self._mud_service.proxyamp.callRemote(
                NotifyFirstSessionConnectedOnObjectCmd,
                object_id=controlled_id,
            )

        self.execute_command('look')
Example #35
File: db_io.py Project: gtaylor/dott
    def load_objects_into_store(self, loader_func):
        """
        Loads all of the objects from the DB into RAM.

        :param function loader_func: The function to run on the instantiated
            BaseObject sub-classes.
        """

        logger.info("Loading objects into store.")

        results = yield self._db.runQuery(self.BASE_OBJECT_SELECT)

        for row in results:
            # Given an object ID and a JSON str, load this object into the store.
            loader_func(self.instantiate_object_from_row(row))
Example #36
    def _workForTable(self, table, fieldsSchema):
        inputConfig = self.ctx.getInputConfig()
        numOfDocsToGen = inputConfig["includeTables"][table]["seedSize"]
        logger.info("Will generate {} documents for {}..".format(numOfDocsToGen, table))

        # The following code:
        # (1) creates the given number of documents in batches, using the given schema of the table and its fields
        # (2) consults the field schema for which seeder function to call
        # (3) yields each batch as a list of dicts
        numOfDocsWorked = 0
        diff = numOfDocsToGen - numOfDocsWorked
        while diff > 0:
            localBatchCount = min(DataGen.kDocBatchCount, diff)
            numOfDocsWorked += localBatchCount
            diff = numOfDocsToGen - numOfDocsWorked
            docs = []
            while localBatchCount > 0:
                doc = {}
                for f, fSchema in fieldsSchema.items():
                    doc[f] = self.seeder.callSeederFunc(fSchema["seeder"], fSchema["seederArgs"])
                docs.append(doc)
                localBatchCount -= 1
            yield {"docs": docs, "table": table}
        cache.emptyCache()
Example #37
 def testMysqlSeeder(self):
     logger.info("Initializing mysql integration testing components..")
     ctx = Context(self.conn, self.inputConfig)
     orderInfo, schemaForDatagen = SchemaBuilder(ctx).getSchemaForDataGen()
     logger.debug("Schema for data generation:\n{}".format(json.dumps(schemaForDatagen)))
     logger.debug("Will be worked in order:\n{}".format(json.dumps(orderInfo)))
     writer = Writer(ctx)
     dataGen = DataGen(ctx)
     for results in dataGen.generate(schemaForDatagen, orderInfo):
         logger.info("Writing {} documents into {}..".format(len(results["docs"]), results["table"]))
         writer.doBulkWrite(results["table"], results["docs"])
     logger.info("Finally, Done with it!")
Example #39
 @classmethod
 def setUpClass(cls):
     logger.debug("Setting up class..")
     cls.inputConfig = {
         "engine":        "mysql",
         "host":          "localhost",
         "user":          "******",
         "database":      "jseeder",
         "password":      "******",
         "port":          3306,
         "includeTables": {
             "users": {
                 "seedSize":        10,
                 "excludeFields":   ["middle_name"],
                 "inclusionPolicy": "all", # "all"/"none" - Include all/ none fields, default - "none"
                 "includeFields":   {
                     "first_name": {
                         "seeder":     "j.fromList",
                         "seederArgs":  {
                             "l": ["jitendra", "kumar", "ojha"],
                             "inSerial": True
                         }
                     },
                     "last_name": {
                         "seeder":     "j.fromList",
                         "seederArgs":  {
                             "l": ["jitendra", "kumar", "ojha"],
                             "inSerial": True
                         }
                     },
                     "fav_num": {
                         "seeder":     "j.fromBetween",
                         "seederArgs":  {
                             "i": 0,
                             "j": 3,
                             "inSerial": False
                         }
                     },
                     "city_id": {
                         "seederArgs": {
                             "inSerial": True,
                             "offset":   3,
                             "limit":    5
                         }
                     }
                 }
             },
             "cities": {
                 "seedSize":        10,
                 "inclusionPolicy": "all",
                 "includeFields":   {
                     "name": {
                         "seeder": "j.fromList",
                         "seederArgs":  {
                             "l": ["Bangalore", "Patna"],
                             "inSerial": True
                         }
                     }
                 }
             }
         }
     }
     logger.info("Using following input config:\n{}".format(json.dumps(self.inputConfig)))
     self.conn = MySQLdb.connect(
         self.inputConfig["host"],
         self.inputConfig["user"],
         self.inputConfig["password"],
         self.inputConfig["database"],
         self.inputConfig["port"]
     )
     logger.info("Creating required test tables..")
     self.cursor = self.conn.cursor()
     sql = """CREATE TABLE cities (
             id INT PRIMARY KEY AUTO_INCREMENT NOT NULL,
             name  VARCHAR(20) NOT NULL)"""
     cls.cursor.execute(sql)
     sql = """CREATE TABLE users (
             id INT PRIMARY KEY AUTO_INCREMENT NOT NULL,
             first_name  VARCHAR(20) NOT NULL,
             middle_name VARCHAR(20),
             last_name  VARCHAR(20),
             fav_num INT,
             city_id INT,
             CONSTRAINT fk_users_cities_city_id_id FOREIGN KEY (city_id) REFERENCES cities(id))"""
     cls.cursor.execute(sql)
Example #40
 @classmethod
 def tearDownClass(cls):
     logger.debug("Tearing down class..")
     logger.info("Dropping all test tables..")
     cls.cursor.execute("DROP TABLE users")
     cls.cursor.execute("DROP TABLE cities")
Example #41
        logger.error(str(e))
        sys.exit(1)
    inputConfig = parseYamlFile(inputFile)

    if inputConfig["engine"] == "mysql":
        from src.schema_builders.mysql import MysqlSchemaBuilder as SchemaBuilder
        from src.writers.mysql import MysqlWriter as Writer
        from src.contexts.mysql import MysqlContext as Context
        conn = MySQLdb.connect(
            inputConfig["host"],
            inputConfig["user"],
            inputConfig["password"],
            inputConfig["database"],
            inputConfig["port"]
        )
        ctx = Context(conn, inputConfig)

    else:
        logger.error("Engine - {} not supported".format(inputConfig["engine"]))
        sys.exit(1)

    orderInfo, schemaForDatagen = SchemaBuilder(ctx).getSchemaForDataGen()
    logger.debug("Schema for data generation:\n{}".format(json.dumps(schemaForDatagen)))
    logger.debug("Will be worked in order:\n{}".format(json.dumps(orderInfo)))
    writer = Writer(ctx)
    dataGen = DataGen(ctx)
    for results in dataGen.generate(schemaForDatagen, orderInfo):
        logger.info("Writing {} documents into {}..".format(len(results["docs"]), results["table"]))
        writer.doBulkWrite(results["table"], results["docs"])
    logger.info("Finally, Done with it!")