Example #1
def main(args):

    config = utils.parse_yaml(args.config)
    assert config, 'without config I cannot do anything!'

    if config.get('debug'):
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)

    port = int(config['port'])
    assert port > 0 and port < 65536, 'port must be in range (0, 65535]'

    host: str
    host = config.get('host', '0.0.0.0')

    tokens: "list[str]"
    tokens = config.get('tokens', [])

    valid_tokens: "list[bytes]"
    valid_tokens = []
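    # Hash the configured tokens with MD5; presumably LoopServer compares client-supplied tokens against these digests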
    for token in tokens:
        data = hashlib.md5(token.encode('utf-8')).digest()
        valid_tokens.append(data)

    loopserver = server.LoopServer(port,
                                   host,
                                   valid_tokens=valid_tokens,
                                   bufsize=config.get('bufsize', 1024))
    loopserver.run()
Example #2
 def lp_drop(self, pf, target):
     """ drop module of linchpin cli :
     still need to fix the linchpin_config and outputs,
     inventory_outputs paths"""
     pf = parse_yaml(pf)
     init_dir = os.getcwd()
     e_vars = {}
     e_vars['linchpin_config'] = self.get_config_path()
     e_vars['outputfolder_path'] = init_dir+"/outputs"
     e_vars['inventory_outputs_path'] = init_dir + "/inventories"
     e_vars['state'] = "absent"
     if target == "all":
         for key in set(pf.keys()).difference(self.excludes):
             e_vars['topology'] = self.find_topology(pf[key]["topology"],
                                                     pf)
             output = invoke_linchpin(self.base_path,
                                      e_vars,
                                      "TEARDOWN",
                                      console=True)
     else:
         print(pf[target])
         if pf.get(target, False):
             topology_path = self.find_topology(pf[target]["topology"],
                                                pf)
             e_vars['topology'] = topology_path
             if e_vars['topology'] is None:
                 print("Topology not found !!")
             output = invoke_linchpin(self.base_path,
                                      e_vars,
                                      "TEARDOWN",
                                      console=True)
Example #3
def train(args):
    gpuid = tuple(map(int, args.gpus.split(',')))
    debug = args.debug
    logger.info(
        "Start training in {} model".format('debug' if debug else 'normal'))
    num_bins, config_dict = parse_yaml(args.config)
    reader_conf = config_dict["spectrogram_reader"]
    loader_conf = config_dict["dataloader"]
    dcnnet_conf = config_dict["model"]

    logger.info("Training with {}".format(
        "IRM" if reader_conf["apply_abs"] else "PSM"))
    batch_size = loader_conf["batch_size"]
    logger.info(
        "Training in {}".format("per utterance" if batch_size == 1 else
                                '{} utterance per batch'.format(batch_size)))

    train_loader = uttloader(config_dict["train_scp_conf"]
                             if not debug else config_dict["debug_scp_conf"],
                             reader_conf,
                             loader_conf,
                             train=True)
    valid_loader = uttloader(config_dict["valid_scp_conf"]
                             if not debug else config_dict["debug_scp_conf"],
                             reader_conf,
                             loader_conf,
                             train=False)
    checkpoint = config_dict["trainer"]["checkpoint"]
    logger.info("Training for {} epoches -> {}...".format(
        args.num_epoches,
        "default checkpoint" if checkpoint is None else checkpoint))

    nnet = PITNet(num_bins, **dcnnet_conf)
    trainer = PITrainer(nnet, **config_dict["trainer"], gpuid=gpuid)
    trainer.run(train_loader, valid_loader, num_epoches=args.num_epoches)
Example #4
def main(country_iso3):

    logger.info(f'Creating graph for {country_iso3}')
    main_dir = os.path.join(MAIN_DIR, country_iso3)
    config = utils.parse_yaml(CONFIG_FILE)[country_iso3]

    # Make a graph
    G = nx.Graph()
    G.graph['country'] = country_iso3

    # Add exposure
    G = add_exposure(G, main_dir, country_iso3)

    # Add COVID cases
    G = add_covid(G, main_dir, country_iso3)

    # Add vulnerability
    G = add_vulnerability(G, main_dir, country_iso3)

    # Add contact matrix
    add_contact_matrix(G, config['contact_matrix'])

    # Write out
    data = nx.readwrite.json_graph.node_link_data(G)
    outdir = os.path.join(main_dir, OUTPUT_DIR)
    Path(outdir).mkdir(parents=True, exist_ok=True)
    outfile = os.path.join(main_dir, OUTPUT_DIR, OUTPUT_FILE.format(country_iso3))
    with open(outfile, 'w') as f:
        json.dump(data, f, indent=2)
    logger.info(f'Wrote out to {outfile}')
Example #5
def run(args):
    num_bins, config_dict = parse_yaml(args.config)
    dataloader_conf = config_dict["dataloader"]
    spectrogram_conf = config_dict["spectrogram_reader"]
    # Load cmvn
    dict_mvn = dataloader_conf["mvn_dict"]
    if dict_mvn:
        if not os.path.exists(dict_mvn):
            raise FileNotFoundError("Could not find mvn files")
        with open(dict_mvn, "rb") as f:
            dict_mvn = pickle.load(f)
    # default: True
    apply_log = dataloader_conf.get("apply_log", True)

    dcnet = PITNet(num_bins, **config_dict["model"])

    frame_length = spectrogram_conf["frame_length"]
    frame_shift = spectrogram_conf["frame_shift"]
    window = spectrogram_conf["window"]

    separator = Separator(dcnet, args.state_dict, cuda=args.cuda)

    utt_dict = parse_scps(args.wave_scp)
    num_utts = 0
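    # Separate every utterance listed in the wave scp and write one wav per estimated speaker (plus optional .mat masks)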
    for key, utt in utt_dict.items():
        try:
            samps, stft_mat = stft(utt,
                                   frame_length=frame_length,
                                   frame_shift=frame_shift,
                                   window=window,
                                   center=True,
                                   return_samps=True)
        except FileNotFoundError:
            print("Skip utterance {}... not found".format(key))
            continue
        print("Processing utterance {}".format(key))
        num_utts += 1
        norm = np.linalg.norm(samps, np.inf)
        spk_mask, spk_spectrogram = separator.seperate(stft_mat,
                                                       cmvn=dict_mvn,
                                                       apply_log=apply_log)

        for index, stft_mat in enumerate(spk_spectrogram):
            istft(os.path.join(args.dump_dir,
                               '{}.spk{}.wav'.format(key, index + 1)),
                  stft_mat,
                  frame_length=frame_length,
                  frame_shift=frame_shift,
                  window=window,
                  center=True,
                  norm=norm,
                  fs=8000,
                  nsamps=samps.size)
            if args.dump_mask:
                sio.savemat(
                    os.path.join(args.dump_dir,
                                 '{}.spk{}.mat'.format(key, index + 1)),
                    {"mask": spk_mask[index]})
    print("Processed {} utterance!".format(num_utts))
Example #6
def train(args):
    config_dict = parse_yaml(args.config)

    loader_config = config_dict["dataloader"]
    train_config = config_dict["trainer"]
    temp = config_dict["temp"]

    train_dataset = TasDataset(loader_config["train_path_npz"])
    valid_dataset = TasDataset(loader_config["valid_path_npz"])

    train_loader = DataLoader(train_dataset,
                              batch_size=loader_config["batch_size"],
                              shuffle=True,
                              num_workers=4,
                              drop_last=True,
                              pin_memory=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=loader_config["batch_size"],
                              shuffle=True,
                              num_workers=4,
                              drop_last=True,
                              pin_memory=True)

    tasnet = TasNET()
    trainer = TasNET_trainer(tasnet, **train_config)
    trainer.run(train_loader, valid_loader)
Example #7
    def __init__(self, config_path):

        contents = utils.read_file(config_path)
        self.config = utils.parse_yaml(contents)

        aws_config = self.config.get('aws', {})
        region = os.environ.get('REGION') or aws_config.get('region')
        aws_access_key_id = os.environ.get('AWS_ACCESS_KEY_ID')
        aws_secret_access_key = os.environ.get('AWS_SECRET_ACCESS_KEY')
        subnet_ids = aws_config.get('subnet_ids') or []
        security_group_ids = aws_config.get('security_group_ids') or []
        role_name = os.environ.get('LAMBDA_EXECUTION_ROLE_NAME') or aws_config.get('lambda_execution_role_name')

        general_config = self.config.get('general', {})
        timeout_time = int(os.environ.get('LAMBDA_TIMEOUT_TIME') or general_config.get('lambda_timeout_time') or 10)

        log.debug('region=%s, role_name=%s' % (region, role_name))
        log.debug('timeout_time=%s' % timeout_time)
        self.awslambda = self.setup_lambda(region,
                                           role_name,
                                           timeout_time,
                                           aws_access_key_id,
                                           aws_secret_access_key,
                                           subnet_ids=subnet_ids,
                                           security_group_ids=security_group_ids)
        self.kinesis = self.setup_kinesis(region, aws_access_key_id, aws_secret_access_key)
        self.cwlogs = self.setup_cloud_watch_logs(region, aws_access_key_id, aws_secret_access_key)
Example #8
def train(args):
    debug = args.debug
    logger.info(
        "Start training in {} model".format('debug' if debug else 'normal'))
    num_bins, config_dict = parse_yaml(args.config)
    reader_conf = config_dict["spectrogram_reader"]
    loader_conf = config_dict["dataloader"]
    dcnnet_conf = config_dict["dcnet"]

    batch_size = loader_conf["batch_size"]
    logger.info(
        "Training in {}".format("per utterance" if batch_size == 1 else
                                '{} utterance per batch'.format(batch_size)))

    train_loader = uttloader(
        config_dict["train_scp_conf"]
        if not debug else config_dict["debug_scp_conf"],
        reader_conf,
        loader_conf,
        train=True)
    valid_loader = uttloader(
        config_dict["valid_scp_conf"]
        if not debug else config_dict["debug_scp_conf"],
        reader_conf,
        loader_conf,
        train=False)
    checkpoint = config_dict["trainer"]["checkpoint"]
    logger.info("Training for {} epoches -> {}...".format(
        args.num_epoches, "default checkpoint"
        if checkpoint is None else checkpoint))

    dcnet = DCNet(num_bins, **dcnnet_conf)
    trainer = Trainer(dcnet, **config_dict["trainer"])
    trainer.run(train_loader, valid_loader, num_epoches=args.num_epoches)
Example #9
def train(args):
    config_dict = parse_yaml(args.config)

    loader_config = config_dict["dataloader"]
    train_config = config_dict["trainer"]
    temp = config_dict["temp"]

    train_dataset = TasDataset(loader_config["train_path_npz"])
    valid_dataset = TasDataset(loader_config["valid_path_npz"])

    train_loader = DataLoader(train_dataset,
                              batch_size=loader_config["batch_size"],
                              shuffle=True,
                              num_workers=4,
                              drop_last=True,
                              pin_memory=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=loader_config["batch_size"],
                              shuffle=False,
                              num_workers=4,
                              drop_last=True,
                              pin_memory=True)

    tasnet = TasNET(batch_size=loader_config["batch_size"])
    trainer = TasNET_trainer(tasnet,
                             batch_size=loader_config["batch_size"],
                             **train_config)

    if train_config['rerun_mode'] == False:
        trainer.run(train_loader, valid_loader)
    else:
        trainer.rerun(train_loader, valid_loader, temp["model_path"],
                      temp["epoch_done"])
Example #10
def train(args):
    num_bins, config_dict = parse_yaml(args.config)
    # reader_conf = config_dict["spectrogram_reader"]
    loader_conf = config_dict["dataloader"]
    dcnet_conf = config_dict["dcnet"]
    train_config = config_dict["trainer"]

    train_dataset = SpectrogramDataset(loader_conf["train_path_npz"])
    valid_dataset = SpectrogramDataset(loader_conf["valid_path_npz"])

    train_loader = DataLoader(train_dataset,
                              batch_size=loader_conf["batch_size"],
                              shuffle=True,
                              num_workers=4,
                              drop_last=True,
                              pin_memory=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=loader_conf["batch_size"],
                              shuffle=True,
                              num_workers=4,
                              drop_last=True,
                              pin_memory=True)

    chimera = chimeraNet(num_bins, **dcnet_conf)
    trainer = PerUttTrainer(chimera, args.alpha, **train_config)
    trainer.run(train_loader, valid_loader, num_epoches=args.num_epoches)
Example #11
def parse_rules(rules_file):
    with open(rules_file) as f:
        if rules_file.endswith(".yaml") or rules_file.endswith(".yml"):
            rule = parse_yaml(rules_file)
        else:
            rule = json.load(f)
    rule["_filename"] = os.path.basename(rules_file)
    return rule
Example #12
    def __init__(self, conf_path):

        #location of output files
        self.conf = parse_yaml(conf_path)
        self.task_dir = os.path.join(self.conf['outpath'], "task")
        #location of audio files
        self.audio_dir = os.path.join(self.conf['outpath'], "audio")
        #location of all the audio data
        self.data_dir = os.path.join(self.audio_dir, "data")
        #location of just the enrollment audio data
        self.enroll_dir = os.path.join(self.audio_dir, "enroll")
        #location of just the test audio data
        self.test_dir = os.path.join(self.audio_dir, "test")
Example #13
def train(args):
    debug = args.debug
    logger.info(
        "Start training in {} model".format('debug' if debug else 'normal'))
    num_bins, config_dict = parse_yaml(args.config)
    reader_conf = config_dict["spectrogram_reader"]
    loader_conf = config_dict["dataloader"]
    dcnnet_conf = config_dict["model"]
    state_dict = args.state_dict

    location = "cpu" if args.cpu else None

    logger.info("Training with {}".format("IRM" if reader_conf["apply_abs"]
                                          else "PSM"))
    batch_size = loader_conf["batch_size"]
    logger.info(
        "Training in {}".format("per utterance" if batch_size == 1 else
                                '{} utterance per batch'.format(batch_size)))

    train_loader = uttloader(
        config_dict["train_scp_conf"]
        if not debug else config_dict["debug_scp_conf"],
        reader_conf,
        loader_conf,
        train=True)
    valid_loader = uttloader(
        config_dict["valid_scp_conf"]
        if not debug else config_dict["debug_scp_conf"],
        reader_conf,
        loader_conf,
        train=False)
    checkpoint = config_dict["trainer"]["checkpoint"]
    logger.info("Training for {} epoches -> {}...".format(
        args.num_epoches, "default checkpoint"
        if checkpoint is None else checkpoint))

    nnet = PITNet(num_bins, **dcnnet_conf)

    # Optionally warm-start from an existing checkpoint
    if state_dict != "":
        if not os.path.exists(state_dict):
            raise ValueError("there is no path {}".format(state_dict))
        else:
            logger.info("load {}".format(state_dict))
            nnet.load_state_dict(th.load(state_dict, map_location=location))

    trainer = PITrainer(nnet, **config_dict["trainer"])
    trainer.run(train_loader, valid_loader, num_epoches=args.num_epoches, start=args.start)
Example #14
 def lp_rise(self, pf, target):
     pf = parse_yaml(pf)
     init_dir = os.getcwd()
     e_vars = {}
     e_vars['linchpin_config'] = self.get_config_path()
     e_vars['outputfolder_path'] = init_dir+"/outputs"
     e_vars['inventory_outputs_path'] = init_dir+"/inventories"
     e_vars['state'] = "present"
     if target == "all":
         for key in set(pf.keys()).difference(self.excludes):
             e_vars['topology'] = self.find_topology(pf[key]["topology"],
                                                     pf)
             if e_vars['topology'] is None:
                 print("Topology not found !!")
                 break
             if 'layout' in pf[key]:
                 layout_path = self.find_layout(pf[key]["layout"], pf)
                 e_vars['inventory_layout_file'] = layout_path
                 if e_vars['inventory_layout_file'] is None:
                     print("Layout not found !!")
                     break
                 print(e_vars)
             output = invoke_linchpin(self.base_path,
                                      e_vars,
                                      "PROVISION",
                                      console=True)
     else:
         if pf.get(target, False):
             topology_path = self.find_topology(pf[target]["topology"],
                                                pf)
             e_vars['topology'] = topology_path
             if e_vars['topology'] is None:
                 print("Topology not found !!")
             if 'layout' in pf[target]:
                 layout_path = self.find_layout(pf[target]["layout"], pf)
                 e_vars['inventory_layout_file'] = layout_path
                 if e_vars['inventory_layout_file'] is None:
                     print("Layout not found !!")
                 print(e_vars)
             output = invoke_linchpin(self.base_path,
                                      e_vars,
                                      "PROVISION",
                                      console=True)
         else:
             raise KeyError('Target not found in PinFile')
Example #15
 def __init__(self, conf_path):
     """
     This method parses the YAML configuration file, which is used to
     initialize the member variables.
     Args:
         conf_path (String): path of the YAML configuration file
     """
     
     #location of output files
     self.conf = parse_yaml(conf_path)
     self.task_dir = os.path.join(self.conf['outpath'], "task")
     #location of audio files
     self.audio_dir = os.path.join(self.conf['outpath'], "audio")
     #location of all the audio data
     self.data_dir = os.path.join(self.audio_dir, "data")
     #location of just the enrollment audio data
     self.enroll_dir = os.path.join(self.audio_dir, "enroll")
     #location of just the test audio data
     self.test_dir = os.path.join(self.audio_dir, "test")
Example #16
def run(args):
    num_bins, conf_dict = parse_yaml(args.train_conf)
    reader = SpectrogramReader(args.wave_scp, **conf_dict["spectrogram_reader"])
    mean = np.zeros(num_bins)
    std = np.zeros(num_bins)
    num_frames = 0
    # D(X) = E(X^2) - E(X)^2
    for _, spectrogram in tqdm.tqdm(reader):
        num_frames += spectrogram.shape[0]
        mean += np.sum(spectrogram, 0)
        std += np.sum(spectrogram**2, 0)
    mean = mean / num_frames
    std = np.sqrt(std / num_frames - mean**2)
    with open(args.cmvn_dst, "wb") as f:
        cmvn_dict = {"mean": mean, "std": std}
        pickle.dump(cmvn_dict, f)
    print("Totally processed {} frames".format(num_frames))
    print("Global mean: {}".format(mean))
    print("Global std: {}".format(std))
Example #17
def run(args):
    num_bins, conf_dict = parse_yaml(args.train_conf)
    reader = SpectrogramReader(args.wave_scp,
                               **conf_dict["spectrogram_reader"])
    mean = np.zeros(num_bins)
    std = np.zeros(num_bins)
    num_frames = 0
    # D(X) = E(X^2) - E(X)^2
    for _, spectrogram in tqdm.tqdm(reader):
        num_frames += spectrogram.shape[0]
        mean += np.sum(spectrogram, 0)
        std += np.sum(spectrogram**2, 0)
    mean = mean / num_frames
    std = np.sqrt(std / num_frames - mean**2)
    with open(args.cmvn_dst, "wb") as f:
        cmvn_dict = {"mean": mean, "std": std}
        pickle.dump(cmvn_dict, f)
    print("Totally processed {} frames".format(num_frames))
    print("Global mean: {}".format(mean))
    print("Global std: {}".format(std))
Example #18
    def __init__(self, conf_path):
        #parse the YAML configuration file
        self.conf = parse_yaml(conf_path)
        self.audio_dir = os.path.join(self.conf['outpath'], "audio") #input dir
        self.feat_dir = os.path.join(self.conf['outpath'], "feat")
        # Number of parallel threads
        self.NUM_THREADS = cpu_count()

        self.FEAUTRES = self.conf['features']
        self.FILTER_BANK = self.conf['filter_bank']
        self.FILTER_BANK_SIZE = self.conf['filter_bank_size']
        self.LOWER_FREQUENCY = self.conf['lower_frequency']
        self.HIGHER_FREQUENCY = self.conf['higher_frequency']
        self.VAD = self.conf['vad']
        self.SNR_RATIO = self.conf['snr_ratio'] if self.VAD=="snr" else None
        # cepstral coefficients
        self.WINDOW_SIZE = self.conf['window_size']
        self.WINDOW_SHIFT = self.conf['window_shift']
        self.CEPS_NUMBER = self.conf['cepstral_coefficients']
        # reset unnecessary ones based on given configuration
        self.review_member_variables()
Example #19
 def lp_drop(self, lpf, target):
     """ drop module of linchpin cli :
     still need to fix the linchpin_config and outputs,
     inventory_outputs paths"""
     lpf = parse_yaml(lpf)
     init_dir = os.getcwd()
     e_vars = {}
     e_vars['linchpin_config'] = self.get_config_path()
     e_vars['inventory_outputs_path'] = init_dir + "/inventory"
     e_vars['state'] = "absent"
     if target == "all":
         for key in set(lpf.keys()).difference(self.excludes):
             e_vars['topology'] = self.find_topology(lpf[key]["topology"],
                                                     lpf)
             topo_name = lpf[key]["topology"].strip(".yml").strip(".yaml")
             output_path = init_dir + "/outputs/" + topo_name + ".output"
             e_vars['topology_output_file'] = output_path
             output = invoke_linchpin(self.base_path,
                                      e_vars,
                                      "TEARDOWN",
                                      console=True)
     else:
         print(lpf[target])
         if lpf.get(target, False):
             topology_path = self.find_topology(lpf[target]["topology"],
                                                lpf)
             e_vars['topology'] = topology_path
             if e_vars['topology'] is None:
                 print("Topology not found !!")
             topo_name = os.path.splitext(lpf[target]["topology"])[0]  # drop the .yml/.yaml extension
             output_path = init_dir + "/outputs/" + topo_name + ".output"
             e_vars['topology_output_file'] = output_path
             output = invoke_linchpin(self.base_path,
                                      e_vars,
                                      "TEARDOWN",
                                      console=True)
Example #20
def run(args):
    num_bins, config_dict = parse_yaml(args.config)
    # Load cmvn
    dict_mvn = config_dict["dataloader"]["mvn_dict"]
    if dict_mvn:
        if not os.path.exists(dict_mvn):
            raise FileNotFoundError("Could not find mvn files")
        with open(dict_mvn, "rb") as f:
            dict_mvn = pickle.load(f)

    dcnet = DCNet(num_bins, **config_dict["dcnet"])

    frame_length = config_dict["spectrogram_reader"]["frame_length"]
    frame_shift = config_dict["spectrogram_reader"]["frame_shift"]
    window = config_dict["spectrogram_reader"]["window"]

    cluster = DeepCluster(dcnet,
                          args.dcnet_state,
                          args.num_spks,
                          pca=args.dump_pca,
                          cuda=args.cuda)

    utt_dict = parse_scps(args.wave_scp)
    num_utts = 0
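    # Run deep clustering on each utterance and write per-speaker wavs, optional masks and the PCA matrix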
    for key, utt in utt_dict.items():
        try:
            samps, stft_mat = stft(utt,
                                   frame_length=frame_length,
                                   frame_shift=frame_shift,
                                   window=window,
                                   center=True,
                                   return_samps=True)
        except FileNotFoundError:
            print("Skip utterance {}... not found".format(key))
            continue
        print("Processing utterance {}".format(key))
        num_utts += 1
        norm = np.linalg.norm(samps, np.inf)
        pca_mat, spk_mask, spk_spectrogram = cluster.seperate(stft_mat,
                                                              cmvn=dict_mvn)

        for index, stft_mat in enumerate(spk_spectrogram):
            istft(os.path.join(args.dump_dir,
                               '{}.spk{}.wav'.format(key, index + 1)),
                  stft_mat,
                  frame_length=frame_length,
                  frame_shift=frame_shift,
                  window=window,
                  center=True,
                  norm=norm,
                  fs=8000,
                  nsamps=samps.size)
            if args.dump_mask:
                sio.savemat(
                    os.path.join(args.dump_dir,
                                 '{}.spk{}.mat'.format(key, index + 1)),
                    {"mask": spk_mask[index]})
        if args.dump_pca:
            sio.savemat(os.path.join(args.dump_dir, '{}.mat'.format(key)),
                        {"pca_matrix": pca_mat})
    print("Processed {} utterance!".format(num_utts))
Example #21
def main(country_iso3, download_covid=False):
    # Get config file
    config = utils.parse_yaml(CONFIG_FILE)[country_iso3]

    # Get input covid file
    input_dir = os.path.join(DIR_PATH, INPUT_DIR, country_iso3)
   
    # Download the latest COVID file and read it in
    if download_covid:
        get_covid_data(config['covid'], country_iso3, input_dir)
    df_covid = pd.read_csv('{}/{}'.format(os.path.join(input_dir, COVID_DIR),
                                          config['covid']['filename']),
                           header=config['covid']['header'],
                           skiprows=config['covid']['skiprows'])
    # convert to standard HLX
    if 'hlx_dict' in config['covid']:
        df_covid=df_covid.rename(columns=config['covid']['hlx_dict'])

    # in some files the province is given explicitly
    df_covid[HLX_TAG_ADM1_NAME] = df_covid[HLX_TAG_ADM1_NAME].str.replace(' Province', '')
    if 'replace_dict' in config['covid']:
        df_covid[HLX_TAG_ADM1_NAME] = df_covid[HLX_TAG_ADM1_NAME].replace(config['covid']['replace_dict'])
    
    # convert to float
    # TODO check conversions
    if df_covid[HLX_TAG_TOTAL_CASES].dtype == 'object':
        df_covid[HLX_TAG_TOTAL_CASES]=df_covid[HLX_TAG_TOTAL_CASES].str.replace(',','')
    df_covid[HLX_TAG_TOTAL_CASES]=pd.to_numeric(df_covid[HLX_TAG_TOTAL_CASES],errors='coerce')
    if df_covid[HLX_TAG_TOTAL_DEATHS].dtype == 'object':
        df_covid[HLX_TAG_TOTAL_DEATHS] = df_covid[HLX_TAG_TOTAL_DEATHS].str.replace('-', '')
    df_covid[HLX_TAG_TOTAL_DEATHS] = pd.to_numeric(df_covid[HLX_TAG_TOTAL_DEATHS], errors='coerce')

    df_covid.fillna(0,inplace=True)
    
    # Get exposure file
    try:
        exposure_file=f'{DIR_PATH}/{EXP_DIR.format(country_iso3)}/{EXP_FILE.format(country_iso3)}'
        exposure_gdf=gpd.read_file(exposure_file)
    except Exception:
        logger.error(f'Cannot get exposure file for {country_iso3}, COVID file not generated')
        return
    
    # add pcodes
    ADM1_names = dict()
    for k, v in exposure_gdf.groupby('ADM1_EN'):
        ADM1_names[k] = v.iloc[0,:].ADM1_PCODE
    df_covid[HLX_TAG_ADM1_PCODE]= df_covid[HLX_TAG_ADM1_NAME].map(ADM1_names)
    if df_covid[HLX_TAG_ADM1_PCODE].isnull().sum() > 0:
        logger.info('missing PCODE for the following admin units: %s',
                    df_covid[df_covid[HLX_TAG_ADM1_PCODE].isnull()])
    #recalculate total for each ADM1 unit
    gender_age_groups = list(itertools.product(GENDER_CLASSES, AGE_CLASSES))
    gender_age_group_names = ['{}_{}'.format(gender_age_group[0], gender_age_group[1]) for gender_age_group in
                              gender_age_groups]

    # TODO fields should depend on country
    output_df_covid=pd.DataFrame(columns=[HLX_TAG_ADM1_PCODE,
                                          HLX_TAG_ADM2_PCODE,
                                          HLX_TAG_DATE,
                                          HLX_TAG_TOTAL_CASES,
                                          HLX_TAG_TOTAL_DEATHS])

    # make a loop over reported cases and downscale ADM1 to ADM2
    # print(df_covid.sum())
    for _, row in df_covid.iterrows():
        adm2_pop_fractions=get_adm2_to_adm1_pop_frac(row[HLX_TAG_ADM1_PCODE],exposure_gdf,gender_age_group_names)
        adm1pcode=row[HLX_TAG_ADM1_PCODE]
        date=row[HLX_TAG_DATE]
        adm1cases=row[HLX_TAG_TOTAL_CASES]
        adm1deaths=row[HLX_TAG_TOTAL_DEATHS]
        adm2cases=[v*adm1cases for v in adm2_pop_fractions.values()]
        adm2deaths=[v*adm1deaths for v in adm2_pop_fractions.values()]
        adm2pcodes=[v for v in adm2_pop_fractions.keys()]
        raw_data = {HLX_TAG_ADM1_PCODE:adm1pcode,
                    HLX_TAG_ADM2_PCODE:adm2pcodes,
                    HLX_TAG_DATE:date,
                    HLX_TAG_TOTAL_CASES:adm2cases,
                    HLX_TAG_TOTAL_DEATHS:adm2deaths}
        output_df_covid=output_df_covid.append(pd.DataFrame(raw_data),ignore_index=True)
    
    # cross-check: the total must match
    if abs(output_df_covid[HLX_TAG_TOTAL_CASES].sum() -
           df_covid[HLX_TAG_TOTAL_CASES].sum()) > 10:
        logger.warning("The case totals of the input and output files don't match")

    # Write to file
    output_df_covid['created_at'] = str(datetime.datetime.now())
    output_df_covid['created_by'] = getpass.getuser()
    output_csv = get_output_filename(country_iso3)
    logger.info(f'Writing to file {output_csv}')
    output_df_covid.to_csv(f'{DIR_PATH}/{output_csv}',index=False)
Example #22
def test(args):
    config_dict = parse_yaml(args.config)

    loader_config = config_dict["dataloader"]
    train_config = config_dict["trainer"]
    test_config = config_dict['test']
    data_config = config_dict["data_generator"]
    temp = config_dict["temp"]

    test_path  = test_config['test_load_path']
    test_save_path = test_config["test_save_path"]
    sr = data_config["sr"]
    N_L = data_config["N_L"]
    test_len_time = test_config["test_len_time"]

    #find test dirs
    test_dirs = find_files(os.path.join(test_path,'mix/'))
    test_dirs.sort()
    #load Tasnet model
    tasnet = TasNET(batch_size=test_config["test_batch_size"])
    tasnet.to(device)
    tasnet.load_state_dict(torch.load(test_config["test_model_path"]))
    tasnet.eval()

    logger.info("Testing...")
    #initialize
    num_test = 0
    tot = 0
    low = 0
    sdr_list = []    

    #Start test 
    with torch.no_grad():
        for test_dir in test_dirs:
            name = test_dir.split('/')[-1]
            speech1_dir = os.path.join(test_path,'s1/'+name)
            speech2_dir = os.path.join(test_path,'s2/'+name)

            #load mix, s1 and s2 data
            mix, _ = librosa.load(test_dir,sr)
            real1, _ = librosa.load(speech1_dir,sr)
            real2, _ = librosa.load(speech2_dir,sr)

            #save the mix data in target dir
            save_dir_mix = os.path.join(test_save_path, "mix/" + name)
            librosa.output.write_wav(save_dir_mix, mix, sr)
            
            #process data before the Tasnet
            len_mix = len(mix)
            mix = make_same_length(mix, N_L)
            mix = np.reshape(mix, [1,-1,N_L])

            #Separate mix audio with Tasnet
            mix = torch.from_numpy(mix)
            if torch.cuda.is_available():
                mix = mix.cuda()
            mix = Variable(mix)
            speech1,speech2 = tasnet(mix)

            # move the output tensors back to the CPU and convert to numpy
            wave1 = speech1.to(torch.device("cpu"))
            wave2 = speech2.to(torch.device("cpu"))
            wave1 = wave1.view(-1,)
            wave2 = wave2.view(-1,)
            wave1 = zero_mean(wave1[:len_mix].numpy())/np.max(wave1[:len_mix].numpy())
            wave2 = zero_mean(wave2[:len_mix].numpy())/np.max(wave2[:len_mix].numpy())

            #Calculate the SDR with bss tools
            wave = [wave1,wave2]
            estimate = np.array(wave)
            real = [real1,real2]
            reference = np.array(real)
            sdr,sir,sar,_ = bss_eval_sources(estimate,reference) 
            sdr_list.append(np.mean(sdr))

            #Count the number of SDR lower than 5 and calculate the mean SDR
            if np.mean(sdr) < 5:
                low +=1
            num_test += 1
            tot += sdr
            mean = np.mean(tot)/(num_test)

            #Save the separated audio in the target dir
            save_dir1 = os.path.join(test_save_path, "s1/" + name)
            save_dir2 = os.path.join(test_save_path, "s2/" + name)
            librosa.output.write_wav(save_dir1, wave1, sr)
            librosa.output.write_wav(save_dir2, wave2, sr)
            
            if num_test%10 == 0:
                logger.info("The current SDR was {}/{}".format(mean, num_test))
                logger.info("SDR lower than 5 were {}/{}".format(low,num_test))

    #Print the SDR in the figure
    logger.info("Testing for all {} waves have done!".format(num_test))
    logger.info("The total mean SDR is {}".format(mean))
    xData = np.arange(1, len(sdr_list)+1, 1)
    sdr_list.sort()  
    yData = sdr_list
    plt.figure(num=1, figsize=(8, 6))
    plt.title('SDR of test samples', size=14)
    plt.xlabel('index', size=14)
    plt.ylabel('SDR', size=14)
    print(yData)
    plt.plot(xData, yData, color='b', linestyle='--', marker='o')
    plt.savefig('plot.png', format='png')
Example #23
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from utils import parse_yaml, get_user_state_from_database, get_current_grade_page, content_changed, \
    create_user_in_database, set_user_password_to_database, set_user_name_to_database, get_all_registered_users, \
    get_current_grades_as_images, init_sqlite_table
from RepeatedFunction import RepeatedFunction
import telebot
import os

# Init variables
yaml_path = "config/config.yaml"

# Initialize base objects
yaml_object = parse_yaml(yaml_path)
bot = telebot.TeleBot(yaml_object['telegram_token'])


def send_update():
    user_list = get_all_registered_users()
    for chat_id in user_list:
        grade_page = get_current_grade_page(chat_id)
        content_has_changed = content_changed(grade_page, chat_id)
        if content_has_changed:
            bot.send_message(chat_id, "New grades are online!")


@bot.message_handler(commands=['start', 'help'])
def send_welcome(message):
    bot.reply_to(
        message, "Hello! Use /enterdata to enter your user information \n "
        "You will receive a notification when your grade table changes. \n"
Example #24
def main(country_iso3, download_worldpop=False):

    # Get config file
    config = utils.parse_yaml(CONFIG_FILE)[country_iso3]

    # Get input boundary shape file
    input_dir = os.path.join(DIR_PATH, INPUT_DIR, country_iso3)
    input_shp = os.path.join(input_dir, SHAPEFILE_DIR,
                             config['admin']['directory'],
                             f'{config["admin"]["directory"]}.shp')
    ADM2boundaries = gpd.read_file(input_shp)

    # Download the worldpop data
    if download_worldpop:
        get_worldpop_data(country_iso3, input_dir)

    # gender and age groups
    gender_age_groups = list(itertools.product(GENDER_CLASSES, AGE_CLASSES))
    for gender_age_group in gender_age_groups:
        gender_age_group_name = f'{gender_age_group[0]}_{gender_age_group[1]}'
        logger.info(f'analysing gender age group {gender_age_group_name}')
        input_tif_file = os.path.join(
            input_dir, WORLDPOP_DIR, WORLDPOP_FILENAMES['sadd'].format(
                country_iso3=country_iso3.lower(),
                gender=gender_age_group[0],
                age=gender_age_group[1]))
        zs = zonal_stats(input_shp, input_tif_file, stats='sum')
        total_pop = [district_zs.get('sum') for district_zs in zs]
        ADM2boundaries[gender_age_group_name] = total_pop

    # total population for cross check
    logger.info('adding total population')
    input_tiff_pop = os.path.join(
        input_dir, WORLDPOP_DIR,
        WORLDPOP_FILENAMES['pop'].format(country_iso3=country_iso3.lower()))
    zs = zonal_stats(input_shp, input_tiff_pop, stats='sum')
    total_pop = [district_zs.get('sum') for district_zs in zs]
    ADM2boundaries['tot_pop_WP'] = total_pop

    # total population UNadj for cross check
    logger.info('adding total population UN adjusted')
    input_tiff_pop_unadj = os.path.join(
        input_dir, WORLDPOP_DIR,
        WORLDPOP_FILENAMES['unadj'].format(country_iso3=country_iso3.lower()))
    zs = zonal_stats(input_shp, input_tiff_pop_unadj, stats='sum')
    total_pop = [district_zs.get('sum') for district_zs in zs]
    ADM2boundaries['tot_pop_UN'] = total_pop

    # total from disaggregated
    logger.info('scaling SADD data to match UN Adjusted population estimates')
    gender_age_group_names = [
        '{}_{}'.format(gender_age_group[0], gender_age_group[1])
        for gender_age_group in gender_age_groups
    ]
    for index, row in ADM2boundaries.iterrows():
        tot_UN = row['tot_pop_UN']
        tot_sad = row[gender_age_group_names].sum()
        try:
            ADM2boundaries.loc[index,
                               gender_age_group_names] *= tot_UN / tot_sad
        except ZeroDivisionError:
            region_name = row[f'ADM2_{config["admin"]["language"]}']
            logger.warning(
                f'The sum across all genders and ages for admin region {region_name} is 0'
            )

    if 'pop_co' in config:
        print('Further scaling SADD data to match CO estimates')
        # scaling at the ADM1 level to match figures used by Country Office instead of UN stats
        input_pop_co_filename = os.path.join(input_dir, CO_DIR,
                                             config['pop_co']['filename'])
        df_operational_figures = pd.read_excel(input_pop_co_filename,
                                               usecols='A,D')
        df_operational_figures['Province'] = (
            df_operational_figures['Province'].replace(
                config['pop_co']['province_names']))
        # build a dictionary mapping ADM1 names to pcodes and add the pcode column
        ADM1_names = dict()
        for k, v in ADM2boundaries.groupby('ADM1_EN'):
            ADM1_names[k] = v.iloc[0, :].ADM1_PCODE
        df_operational_figures['ADM1_PCODE'] = df_operational_figures[
            'Province'].map(ADM1_names)
        if (df_operational_figures['ADM1_PCODE'].isnull().sum() > 0):
            print(
                'missing PCODE for: ', df_operational_figures[
                    df_operational_figures['ADM1_PCODE'].isnull()])
        # get total by ADM1
        tot_co_adm1 = df_operational_figures.groupby(
            'ADM1_PCODE').sum()['Estimated Population - 2020']
        tot_sad_adm1 = ADM2boundaries.groupby(
            'ADM1_PCODE')[gender_age_group_names].sum().sum(axis=1)
        for index, row in ADM2boundaries.iterrows():
            adm1_pcode = row['ADM1_PCODE']
            pop_co = tot_co_adm1.get(adm1_pcode)
            pop_sad = tot_sad_adm1.get(adm1_pcode)
            ADM2boundaries.loc[index,
                               gender_age_group_names] *= pop_co / pop_sad

    ADM2boundaries['tot_sad'] = ADM2boundaries.loc[:, gender_age_group_names].sum(axis=1)

    # adding manually Kochi nomads
    if 'kochi' in config:
        logger.info('Adding Kuchi')
        ADM1_kuchi = config['kochi']['adm1']
        # total population in these provinces
        pop_in_kuchi_ADM1 = ADM2boundaries[ADM2boundaries['ADM1_PCODE'].isin(
            ADM1_kuchi)]['tot_sad'].sum()
        for row_index, row in ADM2boundaries.iterrows():
            if row['ADM1_PCODE'] in ADM1_kuchi:
                tot_kuchi_in_ADM2 = 0
                for gender_age_group in gender_age_groups:
                    # population weighted
                    gender_age_group_name = f'{gender_age_group[0]}_{gender_age_group[1]}'
                    kuchi_pp = config['kochi']['total'] * (
                        row[gender_age_group_name] / pop_in_kuchi_ADM1)
                    ADM2boundaries.loc[row_index, gender_age_group_name] = row[
                        gender_age_group_name] + kuchi_pp
                    tot_kuchi_in_ADM2 += kuchi_pp
                ADM2boundaries.loc[row_index, 'kuchi'] = tot_kuchi_in_ADM2
                comment = f'Added in total {tot_kuchi_in_ADM2} Kuchi nomads to WorldPop estimates'
                ADM2boundaries.loc[row_index, 'comment'] = comment

    # Write to file
    ADM2boundaries['created_at'] = str(datetime.datetime.now())
    ADM2boundaries['created_by'] = getpass.getuser()
    output_geojson = get_output_filename(country_iso3)
    logger.info(f'Writing to file {output_geojson}')
    utils.write_to_geojson(output_geojson, ADM2boundaries)
Example #25
from style import lili
from style import h1
from style import h2
from style import a
from style import p
from style import newline
from utils import parse_yaml
from utils import name2dir
from utils import dir2name
from utils import name2link
from utils import find_subdirectories
from utils import sort

# TODO conference badges

config = parse_yaml("config.yaml")
f = open(config["filename"], "w")

# Introduction ########################################################
f.write(h1(config["title"]))
for badge in config["badge"]:
    f.write(badge)
    newline(f)

newline(f)
f.write(config["description"])
newline(f, iter=2)

# Table of Contents ###################################################
f.write(h2("Table of Contents"))
table_of_contents = parse_yaml("data.yaml")
Example #26
def main(args):

    config = utils.parse_yaml(args.config)
    assert config, 'without config I cannot do anything!'

    if config.get('debug'):
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)

    lastcolon = config['server'].rfind(':')
    assert lastcolon != -1, 'port must be set'

    server = config['server'][:lastcolon]
    port = int(config['server'][lastcolon + 1:])
    assert port > 0 and port < 65536, 'port must be in range (0, 65535]'

    token = hashlib.md5(config['token'].encode('utf-8')).digest()

    child_list = []
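    # One forwarding child is appended below for every valid instance; malformed entries are skipped with a warning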

    for instance in config.get('instances', []):
        instance: dict
        try:
            mode = getattr(models.tunnelType, instance['mode'])
        except:
            logging.warning(
                "Config parse warning: instance %s has incorrect mode." %
                (str(instance)))
            continue

        remote_port = instance.get('remote_port')
        try:
            assert remote_port and remote_port > 0 and remote_port < 65536
        except:
            logging.warning(
                "Config parse warning: instance %s port not be in range (0, 65535]"
                % (str(instance)))
            continue

        domain = utils.punycode_encode(instance.get('domain', ''))
        if mode == models.tunnelType.http or mode == models.tunnelType.ssl:
            try:
                assert domain != ''
            except:
                logging.warning(
                    "Config parse warning: instance %s domain not set while using SLB"
                    % (str(instance)))

        arg_proc = {
            'host': str(instance.get('local_host')),
            'port': int(str('0' if instance.get('local_port') is None
                            else instance.get('local_port'))),
            'remote_port': remote_port,
            'mode': mode,
            'token': token,
            'domain': domain,
            'static': instance.get('static'),
            'bufsize': int(config.get('bufsize', 1024))
        }

        child_list.append(client.new_proc(server, port, token, arg_proc))

    logging.info("Netfwd client connecting to server [%s]:%d" % (server, port))
    logging.info("Copyright (C) 2021 Victor Huang <*****@*****.**>")
    logging.info("")
    logging.info("Added %d child services." % len(child_list))

    while child_list:
        # Iterate over a copy so finished children can be removed safely
        for child in child_list[:]:
            if not child.is_alive():
                child_list.remove(child)
Example #27
def main(country_iso3, download_ghs=False):

    # Get config file
    config = utils.parse_yaml(CONFIG_FILE)[country_iso3]

    # Get input boundary shape file
    input_dir = os.path.join(DIR_PATH, INPUT_DIR, country_iso3)
    input_shp = os.path.join(input_dir, SHAPEFILE_DIR,
                             config['admin']['directory'],
                             f'{config["admin"]["directory"]}.shp')
    boundaries = gpd.read_file(input_shp).to_crs(GHS_CRS)

    # Download the tiles and read them in
    if download_ghs:
        get_ghs_data('SMOD', config['ghs'], country_iso3, input_dir)
        get_ghs_data('POP', config['ghs'], country_iso3, input_dir)
    ghs_smod = rasterio.open(
        os.path.join(input_dir, GHS_DIR,
                     OUTPUT_GHS['SMOD'].format(country_iso3=country_iso3)))
    ghs_pop = rasterio.open(
        os.path.join(input_dir, GHS_DIR,
                     OUTPUT_GHS['POP'].format(country_iso3=country_iso3)))

    # adding urban/rural disaggregation data using JRC GHSL input
    logger.info("Calculating urban population fraction")
    boundaries['frac_urban'] = boundaries['geometry'].apply(
        lambda x: calc_frac_urban(x, ghs_smod, ghs_pop))

    # Get food insecurity
    logger.info("Getting food insecurity")
    boundaries = add_food_insecurity(config['ipc'], input_dir, boundaries,
                                     config['admin']['language'])

    # Get solid fuels
    if 'solid_fuels' in config:
        logger.info("Getting Solid Fuels data")
        boundaries = add_factor_urban_rural(boundaries, 'fossil_fuels',
                                            config['solid_fuels'])
    else:
        logger.info(
            f'Solid fuels data not available for country {country_iso3}')

    # Get handwashing facilities
    if 'handwashing_facilities' in config:
        logger.info("Getting Handwashing facilities data")
        boundaries = add_factor_urban_rural(boundaries,
                                            'handwashing_facilities',
                                            config['handwashing_facilities'])
    else:
        logger.info(
            f'Handwashing facilities data not available for country {country_iso3}'
        )

    # Get raised blood pressure
    if 'raised_blood_pressure' in config:
        logger.info("Getting Raised Blood Pressure data")
        add_factor_18plus(boundaries, config['raised_blood_pressure'],
                          'raised_blood_pressure', country_iso3)
    else:
        logger.info(
            f'Raised blood pressure data not available for country {country_iso3}'
        )

    # Get raised blood pressure
    if 'diabetes' in config:
        logger.info("Getting diabetes data")
        add_factor_18plus(boundaries, config['diabetes'], 'diabetes',
                          country_iso3)
    else:
        logger.info(f'Diabetes data not available for country {country_iso3}')

    # Get smoking
    if 'smoking' in config:
        logger.info("Getting smoking data")
        add_factor_18plus(boundaries, config['smoking'], 'smoking',
                          country_iso3)
    else:
        logger.info(f'Smoking data not available for country {country_iso3}')

    # Write out results
    output_dir = os.path.join(DIR_PATH, OUTPUT_DIR.format(country_iso3))
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    output_geojson = os.path.join(
        output_dir, OUTPUT_GEOJSON.format(country_iso3=country_iso3))
    logger.info(f"Saving results to {output_geojson}")
    utils.write_to_geojson(output_geojson, boundaries.to_crs(SHP_CRS))
Example #28
# Dependencies
import CoinClass
import utils

# Main function of our program
if __name__ == '__main__':

    # Load of the Configuration file
    conf = utils.parse_yaml('conf.yaml')

    # Class initializer and function run
    min_coin = CoinClass.CoinProblem(cfg=conf)
    min_coin.cash_change()
Example #29
 def get_config(self):
     config_path = self.get_config_path()
     config = parse_yaml(config_path)
     return config
Example #30
def main(argv=None):
    if argv is None:
        argv = sys.argv
    try:
        try:
            opts, args = getopt.getopt(argv[1:], "shc", ["send", "help"])
        except getopt.error, msg:
            raise Usage(msg)

        # option processing
        send = False
        for option, value in opts:
            if option in ("-s", "--send"):
                send = True
            if option in ("-h", "--help"):
                raise Usage(help_message)

        # Parse configuration
        config = parse_yaml()
        for key in REQRD:
            if key not in config.keys():
                raise Exception(
                    'Required parameter %s not in yaml config file!' % (key,))

        participants = config['PARTICIPANTS']
        couples = config['COUPLES']
        if len(participants) < 2:
            raise Exception('Not enough participants specified.')

        # Mail parsing
        f = open('templates/mail.html', 'r')
        mail_html = ""
        while 1:
            line = f.readline()
            if not line:
                break
            mail_html += line

        f.close()

        givers = []
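        # Build Person objects; recording each participant's partner presumably lets create_pairs avoid matching couples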
        for person in participants:
            name, email = re.match(r'([^<]*)<([^>]*)>', person).groups()
            name = name.strip()
            partner = None
            for couple in couples:
                names = [n.strip() for n in couple.split(',')]
                if name in names:
                    # is part of this couple
                    for member in names:
                        if name != member:
                            partner = member
            person = Person(name, email, partner)
            givers.append(person)

        recievers = givers[:]
        pairs = create_pairs(givers, recievers)
        if not send:
            print """
                    Test pairings:

                    %s

                    To send out emails with new pairings,
                    call with the --send argument:

                    $ python secret_santa.py --send

            """ % ("\n".join([str(p) for p in pairs]))

        for pair in pairs:

            if send:
                to = "%s <%s>" % (pair.giver.name, pair.giver.email)
                mail = HtmlMail(
                    config['SUBJECT'], config['FROM'], to, config['USERNAME'],
                    config['PASSWORD'])

                mail.send(
                    parse_email(config['TEMPLATE']).format(
                        config['SUBJECT'], pair.giver.name, pair.reciever.name,
                        config['LIMIT'], config['DEATHLINE'])
                )
                print "Emailed %s <%s>" % (pair.giver.name, pair.giver.email)
Example #31
    def _validate_docs(self):
        doc_info = self._get_docs()
        deprecated = False
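        # Validate the DOCUMENTATION block first, then EXAMPLES, RETURN and ANSIBLE_METADATA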
        if not bool(doc_info['DOCUMENTATION']['value']):
            self.reporter.error(path=self.object_path,
                                code=301,
                                msg='No DOCUMENTATION provided')
        else:
            doc, errors, traces = parse_yaml(
                doc_info['DOCUMENTATION']['value'],
                doc_info['DOCUMENTATION']['lineno'], self.name,
                'DOCUMENTATION')
            for error in errors:
                self.reporter.error(path=self.object_path, code=302, **error)
            for trace in traces:
                self.reporter.trace(path=self.object_path, tracebk=trace)
            if not errors and not traces:
                with CaptureStd():
                    try:
                        get_docstring(self.path, verbose=True)
                    except AssertionError:
                        fragment = doc['extends_documentation_fragment']
                        self.reporter.error(
                            path=self.object_path,
                            code=303,
                            msg='DOCUMENTATION fragment missing: %s' %
                            fragment)
                    except Exception:
                        self.reporter.trace(path=self.object_path,
                                            tracebk=traceback.format_exc())
                        self.reporter.error(
                            path=self.object_path,
                            code=304,
                            msg='Unknown DOCUMENTATION error, see TRACE')

                if 'options' in doc and doc['options'] is None and doc.get(
                        'extends_documentation_fragment'):
                    self.reporter.error(
                        path=self.object_path,
                        code=304,
                        msg=
                        ('DOCUMENTATION.options must be a dictionary/hash when used '
                         'with DOCUMENTATION.extends_documentation_fragment'))

                if self.object_name.startswith('_') and not os.path.islink(
                        self.object_path):
                    deprecated = True
                    if 'deprecated' not in doc or not doc.get('deprecated'):
                        self.reporter.error(
                            path=self.object_path,
                            code=318,
                            msg=
                            'Module deprecated, but DOCUMENTATION.deprecated is missing'
                        )

                if os.path.islink(self.object_path):
                    # This module has an alias, which we can tell as it's a symlink
                    # Rather than checking for `module: $filename` we need to check against the true filename
                    self._validate_docs_schema(
                        doc,
                        doc_schema(
                            os.readlink(self.object_path).split('.')[0]),
                        'DOCUMENTATION', 305)
                else:
                    # This is the normal case
                    self._validate_docs_schema(
                        doc, doc_schema(self.object_name.split('.')[0]),
                        'DOCUMENTATION', 305)

                self._check_version_added(doc)
                self._check_for_new_args(doc)

        if not bool(doc_info['EXAMPLES']['value']):
            self.reporter.error(path=self.object_path,
                                code=310,
                                msg='No EXAMPLES provided')
        else:
            _, errors, traces = parse_yaml(doc_info['EXAMPLES']['value'],
                                           doc_info['EXAMPLES']['lineno'],
                                           self.name,
                                           'EXAMPLES',
                                           load_all=True)
            for error in errors:
                self.reporter.error(path=self.object_path, code=311, **error)
            for trace in traces:
                self.reporter.trace(path=self.object_path, tracebk=trace)

        if not bool(doc_info['RETURN']['value']):
            if self._is_new_module():
                self.reporter.error(path=self.object_path,
                                    code=312,
                                    msg='No RETURN provided')
            else:
                self.reporter.warning(path=self.object_path,
                                      code=312,
                                      msg='No RETURN provided')
        else:
            data, errors, traces = parse_yaml(doc_info['RETURN']['value'],
                                              doc_info['RETURN']['lineno'],
                                              self.name, 'RETURN')
            if data:
                for ret_key in data:
                    self._validate_docs_schema(data[ret_key],
                                               return_schema(data[ret_key]),
                                               'RETURN.%s' % ret_key, 319)

            for error in errors:
                self.reporter.error(path=self.object_path, code=313, **error)
            for trace in traces:
                self.reporter.trace(path=self.object_path, tracebk=trace)

        if not bool(doc_info['ANSIBLE_METADATA']['value']):
            self.reporter.error(path=self.object_path,
                                code=314,
                                msg='No ANSIBLE_METADATA provided')
        else:
            metadata = None
            if isinstance(doc_info['ANSIBLE_METADATA']['value'], ast.Dict):
                metadata = ast.literal_eval(
                    doc_info['ANSIBLE_METADATA']['value'])
            else:
                metadata, errors, traces = parse_yaml(
                    doc_info['ANSIBLE_METADATA']['value'].s,
                    doc_info['ANSIBLE_METADATA']['lineno'], self.name,
                    'ANSIBLE_METADATA')
                for error in errors:
                    self.reporter.error(path=self.object_path,
                                        code=315,
                                        **error)
                for trace in traces:
                    self.reporter.trace(path=self.object_path, tracebk=trace)

            if metadata:
                self._validate_docs_schema(metadata,
                                           metadata_1_1_schema(deprecated),
                                           'ANSIBLE_METADATA', 316)

        return doc_info
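
For orientation, the shape of the doc_info mapping consumed above (and returned at the end of _validate_docs) can be inferred from how the code indexes it. The illustration below is an assumption based only on that usage, not on the validator's actual construction code; the line numbers are placeholders.

# Hypothetical shape of doc_info, inferred from the lookups in _validate_docs.
# 'value' holds the raw block text (or, for ANSIBLE_METADATA, an ast node) and
# 'lineno' is the line where that block starts in the module file.
doc_info = {
    'DOCUMENTATION':    {'value': '<YAML text of the DOCUMENTATION block>', 'lineno': 10},
    'EXAMPLES':         {'value': '<YAML text, possibly multi-document>', 'lineno': 40},
    'RETURN':           {'value': '<YAML text describing return values>', 'lineno': 60},
    'ANSIBLE_METADATA': {'value': '<ast.Dict node or a string node>', 'lineno': 5},
}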
Example #32
0
def run(args):
    num_bins, config_dict = parse_yaml(args.config)
    # Load cmvn
    dict_mvn = config_dict["dataloader"]["mvn_dict"]
    if dict_mvn:
        if not os.path.exists(dict_mvn):
            raise FileNotFoundError("Could not find mvn files")
        with open(dict_mvn, "rb") as f:
            dict_mvn = pickle.load(f)

    dcnet = DCNet(num_bins, **config_dict["dcnet"])

    frame_length = config_dict["spectrogram_reader"]["frame_length"]
    frame_shift = config_dict["spectrogram_reader"]["frame_shift"]
    window = config_dict["spectrogram_reader"]["window"]

    cluster = DeepCluster(
        dcnet,
        args.dcnet_state,
        args.num_spks,
        pca=args.dump_pca,
        cuda=args.cuda)

    utt_dict = parse_scps(args.wave_scp)
    num_utts = 0
    for key, utt in utt_dict.items():
        try:
            samps, stft_mat = stft(
                utt,
                frame_length=frame_length,
                frame_shift=frame_shift,
                window=window,
                center=True,
                return_samps=True)
        except FileNotFoundError:
            print("Skip utterance {}... not found".format(key))
            continue
        print("Processing utterance {}".format(key))
        num_utts += 1
        # Peak (infinity-norm) amplitude of the mixture, used below to rescale the outputs
        norm = np.linalg.norm(samps, np.inf)
        # Run the network plus clustering to get per-speaker masks and spectrograms
        pca_mat, spk_mask, spk_spectrogram = cluster.seperate(
            stft_mat, cmvn=dict_mvn)

        # Reconstruct and write one wav (and optionally a mask) per estimated speaker
        for index, stft_mat in enumerate(spk_spectrogram):
            istft(
                os.path.join(args.dump_dir, '{}.spk{}.wav'.format(
                    key, index + 1)),
                stft_mat,
                frame_length=frame_length,
                frame_shift=frame_shift,
                window=window,
                center=True,
                norm=norm,
                fs=8000,
                nsamps=samps.size)
            if args.dump_mask:
                sio.savemat(
                    os.path.join(args.dump_dir, '{}.spk{}.mat'.format(
                        key, index + 1)), {"mask": spk_mask[index]})
        if args.dump_pca:
            sio.savemat(
                os.path.join(args.dump_dir, '{}.mat'.format(key)),
                {"pca_matrix": pca_mat})
    print("Processed {} utterance!".format(num_utts))
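
Example #32 reads every option it needs off an args namespace. A minimal sketch of the kind of argparse front-end that would drive it is shown below; this is an assumption, not the project's actual CLI. Only the destination attribute names (config, dcnet_state, wave_scp, dump_dir, num_spks, dump_mask, dump_pca, cuda) are taken from what run() actually accesses, and the flag spellings and defaults are illustrative.

import argparse

if __name__ == "__main__":
    # Hypothetical CLI wrapper for run(); attribute names mirror run()'s usage.
    parser = argparse.ArgumentParser(
        description="Separate speakers from mixed utterances with a trained DCNet")
    parser.add_argument("--config", required=True,
                        help="YAML config consumed via parse_yaml")
    parser.add_argument("--dcnet-state", dest="dcnet_state", required=True,
                        help="checkpoint of the trained DCNet")
    parser.add_argument("--wave-scp", dest="wave_scp", required=True,
                        help="scp file listing the mixture wav files")
    parser.add_argument("--dump-dir", dest="dump_dir", default="separated",
                        help="directory for the separated wav/mat outputs")
    parser.add_argument("--num-spks", dest="num_spks", type=int, default=2)
    parser.add_argument("--dump-mask", dest="dump_mask", action="store_true")
    parser.add_argument("--dump-pca", dest="dump_pca", action="store_true")
    parser.add_argument("--cuda", action="store_true")
    run(parser.parse_args())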
Example #33
0
    def get_config(self):
        config_path = self.get_config_path()
        config = parse_yaml(config_path)
        return config
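
Example #33 treats parse_yaml as a plain loader: it takes a path and returns the parsed document. A minimal, hypothetical implementation compatible with that call site might look like the sketch below; the real helpers behind these examples clearly differ (some return a tuple, some a plain dict), so this is only an assumption for illustration.

import yaml

def parse_yaml(path):
    """Hypothetical loader: read a YAML file and return its contents."""
    with open(path, "r") as handle:
        return yaml.safe_load(handle)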
Example #34
0
    def _validate_docs(self):
        doc_info = self._get_docs()
        deprecated = False
        if not bool(doc_info['DOCUMENTATION']['value']):
            self.reporter.error(
                path=self.object_path,
                code=301,
                msg='No DOCUMENTATION provided'
            )
        else:
            doc, errors, traces = parse_yaml(
                doc_info['DOCUMENTATION']['value'],
                doc_info['DOCUMENTATION']['lineno'],
                self.name, 'DOCUMENTATION'
            )
            for error in errors:
                self.reporter.error(
                    path=self.object_path,
                    code=302,
                    **error
                )
            for trace in traces:
                self.reporter.trace(
                    path=self.object_path,
                    tracebk=trace
                )
            if not errors and not traces:
                with CaptureStd():
                    try:
                        get_docstring(self.path, verbose=True)
                    except AssertionError:
                        fragment = doc['extends_documentation_fragment']
                        self.reporter.error(
                            path=self.object_path,
                            code=303,
                            msg='DOCUMENTATION fragment missing: %s' % fragment
                        )
                    except Exception:
                        self.reporter.trace(
                            path=self.object_path,
                            tracebk=traceback.format_exc()
                        )
                        self.reporter.error(
                            path=self.object_path,
                            code=304,
                            msg='Unknown DOCUMENTATION error, see TRACE'
                        )

                if 'options' in doc and doc['options'] is None and doc.get('extends_documentation_fragment'):
                    self.reporter.error(
                        path=self.object_path,
                        code=304,
                        msg=('DOCUMENTATION.options must be a dictionary/hash when used '
                             'with DOCUMENTATION.extends_documentation_fragment')
                    )

                if self.object_name.startswith('_') and not os.path.islink(self.object_path):
                    deprecated = True
                    if 'deprecated' not in doc or not doc.get('deprecated'):
                        self.reporter.error(
                            path=self.object_path,
                            code=318,
                            msg='Module deprecated, but DOCUMENTATION.deprecated is missing'
                        )

                if os.path.islink(self.object_path):
                    # This module has an alias, which we can tell as it's a symlink
                    # Rather than checking for `module: $filename` we need to check against the true filename
                    self._validate_docs_schema(doc, doc_schema(os.readlink(self.object_path).split('.')[0]), 'DOCUMENTATION', 305)
                else:
                    # This is the normal case
                    self._validate_docs_schema(doc, doc_schema(self.object_name.split('.')[0]), 'DOCUMENTATION', 305)

                self._check_version_added(doc)
                self._check_for_new_args(doc)

        if not bool(doc_info['EXAMPLES']['value']):
            self.reporter.error(
                path=self.object_path,
                code=310,
                msg='No EXAMPLES provided'
            )
        else:
            _, errors, traces = parse_yaml(doc_info['EXAMPLES']['value'],
                                           doc_info['EXAMPLES']['lineno'],
                                           self.name, 'EXAMPLES', load_all=True)
            for error in errors:
                self.reporter.error(
                    path=self.object_path,
                    code=311,
                    **error
                )
            for trace in traces:
                self.reporter.trace(
                    path=self.object_path,
                    tracebk=trace
                )

        if not bool(doc_info['RETURN']['value']):
            if self._is_new_module():
                self.reporter.error(
                    path=self.object_path,
                    code=312,
                    msg='No RETURN provided'
                )
            else:
                self.reporter.warning(
                    path=self.object_path,
                    code=312,
                    msg='No RETURN provided'
                )
        else:
            data, errors, traces = parse_yaml(doc_info['RETURN']['value'],
                                              doc_info['RETURN']['lineno'],
                                              self.name, 'RETURN')
            if data:
                for ret_key in data:
                    self._validate_docs_schema(data[ret_key], return_schema(data[ret_key]), 'RETURN.%s' % ret_key, 319)

            for error in errors:
                self.reporter.error(
                    path=self.object_path,
                    code=313,
                    **error
                )
            for trace in traces:
                self.reporter.trace(
                    path=self.object_path,
                    tracebk=trace
                )

        if not bool(doc_info['ANSIBLE_METADATA']['value']):
            self.reporter.error(
                path=self.object_path,
                code=314,
                msg='No ANSIBLE_METADATA provided'
            )
        else:
            metadata = None
            if isinstance(doc_info['ANSIBLE_METADATA']['value'], ast.Dict):
                metadata = ast.literal_eval(
                    doc_info['ANSIBLE_METADATA']['value']
                )
            else:
                metadata, errors, traces = parse_yaml(
                    doc_info['ANSIBLE_METADATA']['value'].s,
                    doc_info['ANSIBLE_METADATA']['lineno'],
                    self.name, 'ANSIBLE_METADATA'
                )
                for error in errors:
                    self.reporter.error(
                        path=self.object_path,
                        code=315,
                        **error
                    )
                for trace in traces:
                    self.reporter.trace(
                        path=self.object_path,
                        tracebk=trace
                    )

            if metadata:
                self._validate_docs_schema(metadata, metadata_1_1_schema(deprecated),
                                           'ANSIBLE_METADATA', 316)

        return doc_info
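
The validator examples above rely on a very different parse_yaml: one that never raises, but returns a (data, errors, traces) triple that can be fed straight into self.reporter. The sketch below only mirrors the call signature and return contract visible in the code above; it is an assumption, not Ansible's actual implementation, and the keys inside each error dict are guesses.

import traceback
import yaml

def parse_yaml(value, lineno, module_name, name, load_all=False):
    """Hypothetical loader returning (data, errors, traces) instead of raising."""
    data, errors, traces = None, [], []
    loader = yaml.safe_load_all if load_all else yaml.safe_load
    try:
        data = loader(value)
        if load_all:
            data = list(data)  # drain the generator so YAML errors surface here
    except yaml.MarkedYAMLError as exc:
        # Report the error relative to the module file by offsetting with lineno
        line = lineno + exc.problem_mark.line if exc.problem_mark else lineno
        errors.append({'msg': '%s in %s is not valid YAML' % (name, module_name),
                       'line': line})
        traces.append(traceback.format_exc())
    except yaml.YAMLError:
        errors.append({'msg': '%s in %s is not valid YAML' % (name, module_name),
                       'line': lineno})
        traces.append(traceback.format_exc())
    return data, errors, traces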