Example #1
    def train_subnet(self, ph_keep, a_path):
        na = str(ph_keep)[1:-1].replace(' ', '').replace('\'',
                                                         '').replace(',', '+')
        #net = FeedForwardNet(hidden=[10], tf_name='Sigmoid')
        dataset = open_shelve(a_path + 'dataset_' + na + '.ds')
        #net.fit(x=dataset['x'], y=dataset['y'], x_val=dataset['x_val'], y_val=dataset['y_val'], learning_rate=0.07, n_epoch=200, req_acc=1.0, batch_size=10)
        dataset.close()
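
The filename mangling above is terse; purely as an illustration (not part of the original project), this is what 'na' works out to for a sample phoneme list:

# Illustration only: the str()/replace() chain used in train_subnet above.
ph_keep = ['a', 'e', 'sil']
na = str(ph_keep)[1:-1].replace(' ', '').replace('\'', '').replace(',', '+')
print(na)  # -> a+e+sil, so the shelve opened above is a_path + 'dataset_a+e+sil.ds'
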
Example #2
    def create_dataset(self, ph_keep, a_path):
        na = str(ph_keep)[1:-1].replace(' ', '').replace('\'',
                                                         '').replace(',', '+')
        dataset_new = open_shelve(a_path + 'dataset_' + na + '.ds')
        tmp_x = list()
        tmp_y = list()
        for x, y in self.net.t_data:
            if y in ph_keep:
                tmp_y.append(y)
                tmp_x.append(x)
        dataset_new['x'] = tmp_x
        dataset_new['y'] = tmp_y

        tmp_x = list()
        tmp_y = list()
        for x, y in self.net.v_data:
            if y in ph_keep:
                tmp_y.append(y)
                tmp_x.append(x)
        dataset_new['x_val'] = tmp_x
        dataset_new['y_val'] = tmp_y
        dataset_new.close()
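
In these snippets, open_shelve is simply shelve.open from the standard library (see the import in Example #15). Below is a minimal, self-contained sketch of the flat key layout used here, with toy values standing in for real features:

from shelve import open as open_shelve  # same alias the examples use

# Toy dataset shelve with the keys create_dataset writes ('x', 'y', 'x_val', 'y_val').
dataset = open_shelve('dataset_toy.ds', 'c')  # 'c': open read/write, create if missing
dataset['x'] = [[0.0, 1.0], [1.0, 0.0]]       # hypothetical feature vectors
dataset['y'] = ['a', 'e']                     # hypothetical phoneme labels
dataset['x_val'] = [[0.5, 0.5]]
dataset['y_val'] = ['sil']
dataset.close()

dataset = open_shelve('dataset_toy.ds', 'c')
print('training samples: ' + str(len(dataset['x'])))
dataset.close()
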
Example #3
                        '--name_appendix',
                        type=str,
                        default='',
                        help='Dataset filename appendix')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_arguments()
    destination = 'dataset_mnist' + args.name_appendix + '.ds'

    print_message(message='Loading Yann LeCun\'s MNIST data...')
    with open_gzip('../../../data/data_mnist/mnist.pkl.gz', 'rb') as f:
        data_train, data_val, data_test = load_cpickle(f)

    dataset = open_shelve(destination, 'c')
    class_counter = dict()
    if args.n_samples == -1:
        print_message(message='Got MNIST dataset: ' + str(len(data_train[0])) +
                      ' : ' + str(len(data_val[0])) + ' : ' +
                      str(len(data_test[0])) + ', saving...')
        dataset['x'] = [reshape(x, (784, 1)) for x in data_train[0]]
        dataset['y'] = data_train[1]
    else:
        print_message(message='Got MNIST dataset: ' +
                      str(args.n_samples * 10) + ' : ' +
                      str(len(data_val[0])) + ' : ' + str(len(data_test[0])) +
                      ', saving...')
        tmp_x = list()
        tmp_y = list()
        for x, y in zip(data_train[0], data_train[1]):
Example #4
    def search(self, number_of_test_queries, stop_on_result_mismatch,
               stop_on_crash):
        if exists(self.query_shelve_path):
            # Ensure a clean shelve will be created
            remove(self.query_shelve_path)

        start_time = time()
        impala_sql_writer = SqlWriter.create(dialect=IMPALA)
        reference_sql_writer = SqlWriter.create(
            dialect=self.reference_connection.db_type)
        query_result_comparator = QueryResultComparator(
            self.impala_connection, self.reference_connection)
        query_generator = QueryGenerator()
        query_count = 0
        queries_resulted_in_data_count = 0
        mismatch_count = 0
        query_timeout_count = 0
        known_error_count = 0
        impala_crash_count = 0
        last_error = None
        repeat_error_count = 0
        with open(self.query_log_path, 'w') as impala_query_log:
            impala_query_log.write('--\n' '-- Starting new run\n' '--\n')
            while number_of_test_queries > query_count:
                query = query_generator.create_query(self.common_tables)
                impala_sql = impala_sql_writer.write_query(query)
                if 'FULL OUTER JOIN' in impala_sql and self.reference_connection.db_type == MYSQL:
                    # Not supported by MySQL
                    continue

                query_count += 1
                LOG.info('Running query #%s', query_count)
                impala_query_log.write(impala_sql + ';\n')
                result = query_result_comparator.compare_query_results(query)
                if result.query_resulted_in_data:
                    queries_resulted_in_data_count += 1
                if result.error:
                    # TODO: These first two come from psycopg2, the postgres driver. Maybe we should
                    #       try a different driver? Or maybe the usage of the driver isn't correct.
                    #       Anyhow ignore these failures.
                    if 'division by zero' in result.error \
                        or 'out of range' in result.error \
                        or 'Too much data' in result.error:
                        LOG.debug('Ignoring error: %s', result.error)
                        query_count -= 1
                        continue

                    if result.is_known_error:
                        known_error_count += 1
                    elif result.query_timed_out:
                        query_timeout_count += 1
                    else:
                        mismatch_count += 1
                        with closing(open_shelve(
                                self.query_shelve_path)) as query_shelve:
                            query_shelve[str(query_count)] = query

                    print('---Impala Query---\n')
                    print(
                        impala_sql_writer.write_query(query, pretty=True) +
                        '\n')
                    print('---Reference Query---\n')
                    print(
                        reference_sql_writer.write_query(query, pretty=True) +
                        '\n')
                    print('---Error---\n')
                    print(result.error + '\n')
                    print('------\n')

                    if 'Could not connect' in result.error \
                        or "Couldn't open transport for" in result.error:
                        # if stop_on_crash:
                        #   break
                        # Assume Impala crashed and try restarting
                        impala_crash_count += 1
                        LOG.info('Restarting Impala')
                        call([
                            join(getenv('IMPALA_HOME'),
                                 'bin/start-impala-cluster.py'),
                            '--log_dir=%s' % getenv('LOG_DIR', "/tmp/")
                        ])
                        self.impala_connection.reconnect()
                        query_result_comparator.impala_cursor = self.impala_connection.create_cursor(
                        )
                        result = query_result_comparator.compare_query_results(
                            query)
                        if result.error:
                            LOG.info('Restarting Impala')
                            call([
                                join(getenv('IMPALA_HOME'),
                                     'bin/start-impala-cluster.py'),
                                '--log_dir=%s' % getenv('LOG_DIR', "/tmp/")
                            ])
                            self.impala_connection.reconnect()
                            query_result_comparator.impala_cursor = self.impala_connection.create_cursor(
                            )
                        else:
                            break

                    if stop_on_result_mismatch and \
                        not (result.is_known_error or result.query_timed_out):
                        break

                    if last_error == result.error \
                        and not (result.is_known_error or result.query_timed_out):
                        repeat_error_count += 1
                        if repeat_error_count == self.ABORT_ON_REPEAT_ERROR_COUNT:
                            break
                    else:
                        last_error = result.error
                        repeat_error_count = 0
                else:
                    if result.query_resulted_in_data:
                        LOG.info('Results matched (%s rows)',
                                 result.impala_row_count)
                    else:
                        LOG.info('Query did not produce meaningful data')
                    last_error = None
                    repeat_error_count = 0

            return SearchResults(query_count, queries_resulted_in_data_count,
                                 mismatch_count, query_timeout_count,
                                 known_error_count, impala_crash_count,
                                 time() - start_time)
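
Mismatched queries are persisted in a shelve keyed by the query counter (the closing(open_shelve(...)) block above). A small sketch, with a hypothetical path, of how such a shelve could be inspected after a run:

from contextlib import closing
from shelve import open as open_shelve

query_shelve_path = 'mismatched_queries.shelve'  # hypothetical; the class uses self.query_shelve_path
with closing(open_shelve(query_shelve_path)) as query_shelve:
    # Keys are the stringified query counters written during the search.
    for query_number in sorted(query_shelve.keys(), key=int):
        print('query #' + query_number + ': ' + repr(query_shelve[query_number]))
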
Example #5
def parse_arguments():
    parser = argparse.ArgumentParser(description='Classifies testing data.')
    parser.add_argument('-c', '--clf', type=str, required=True,
                        help='Classifier filename to classify with')
    parser.add_argument('-ds', '--dataset', type=str, required=True,
                        help='Dataset to classify on')
    return parser.parse_args()

if __name__ == '__main__':

    args = parse_arguments()
    clf_dir = '../cache/trained/'+args.clf+'.net'

    ''' Loading the classifier and testing data '''
    clf = open_shelve(clf_dir, 'c')
    nn_classifier = clf['net']
    structure = nn_classifier[0]
    weights = nn_classifier[1]
    biases = nn_classifier[2]
    labels = nn_classifier[3]
    dataset_dir = args.dataset
    print '\n\n ## Classification : training parameters:', clf['training_params']
    clf.close()

    net = NeuralNet(program=None, name=str(structure), structure=structure)
    net.weights = weights
    net.biases = biases
    net.labels = labels
    net.map_params()
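
The classifier shelve read above stores the network as a four-element list under 'net' plus a 'training_params' entry. A hedged sketch of the matching write side, with toy placeholders for the attributes Example #5 unpacks:

from shelve import open as open_shelve

# Toy stand-ins for net.structure, net.weights, net.biases and net.labels.
structure, weights, biases, labels = [2, 3, 2], [], [], ['yes', 'no']
clf = open_shelve('clf_toy.net', 'c')
clf['net'] = [structure, weights, biases, labels]
clf['training_params'] = {'learning_rate': 0.3, 'n_epoch': 10}  # hypothetical values
clf.close()
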
Example #6
    print_param(description='Number of experiment observations',
                param_str=str(args.n_obs))
    print_param(description='Initial number of hidden neurons',
                param_str=str(args.hidden_structure))
    print_param(description='Required accuracy', param_str=str(args.req_acc))

    params_str = '_hs' + str(args.hidden_structure) + '_ra' + str(
        args.req_acc).replace('.', '') + '_no' + str(args.n_obs)
    if args.generate:
        stats_data = list()
        for i_obs in range(1, args.n_obs + 1):
            print_message(message='MNIST experiment, observation ' +
                          str(i_obs) + '/' + str(args.n_obs))
            net = FeedForwardNet(hidden=args.hidden_structure,
                                 tf_name='Sigmoid')
            dataset = open_shelve('../examples/mnist/dataset_mnist_1K.ds', 'c')
            net.fit(x=dataset['x'],
                    y=dataset['y'],
                    x_val=dataset['x_val'],
                    y_val=dataset['y_val'],
                    learning_rate=0.3,
                    n_epoch=10,
                    req_acc=1.0,
                    batch_size=10)
            res = net.evaluate(x=dataset['x_test'], y=dataset['y_test'])
            print_message(message='Evaluation on test data after training:')
            print_param(description='Accuracy', param_str=str(res[1]))
            print_param(description='Error', param_str=str(res[0]))
            if net.learning.stats['t_acc'][-1] < args.req_acc:
                print 'Skipping observation'
                continue
Example #7
    print_param(description='Border size (strictness)', param_str=str(args.border_size))
    print_param(description='Context size', param_str=str(args.context_size))
    print_param(description='Number of MEL filters', param_str=str(args.n_filters))
    print_param(description='Number of records', param_str=str(args.n_records))
    print_param(description='Number of samples', param_str=str(args.n_samples))
    print_param(description='Maximum number of other phonemes', param_str=str(args.max_rest))
    print_param(description='Phonemes as classes', param_str=str(args.phonemes) if args.phonemes else 'all')
    print_param(description='Data split (train/val/test)', param_str=str(args.data_split))
    print_param(description='Dataset destination file name', param_str=destination)
    
    mlf = dict()
    features = dict()
    samples = dict()
    data = {'x': list(), 'y': list(), 'x_val': list(), 'y_val': list(), 'x_test': list(), 'y_test': list(), 
            'record_keys': list(), 'record_keys_val': list(), 'record_keys_test': list()}
    get_speech_data()

    print_message(message='Saving dataset...')
    dataset = open_shelve(destination, 'c', protocol=2)
    for key, value in data.items():
        dataset[key] = value
    dataset['features'] = args.feature_filename
    dataset['alignments'] = args.alignment_filename
    dataset['border_size'] = str(args.border_size)
    dataset['context_size'] = str(args.context_size)
    dataset['n_filters'] = str(args.n_filters)
    dataset['n_records'] = str(args.n_records)
    dataset.close()
    print_message(message='Dataset dumped as '+destination)
    print len(features['spk-1-1.wav']), features['spk-1-1.wav'][0]
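
The dataset written above carries its preprocessing parameters alongside the data (protocol=2 only selects the pickle protocol for the stored values). A short sketch, with a hypothetical filename, of reading those metadata keys back:

from shelve import open as open_shelve

dataset = open_shelve('dataset_speech_example.ds', 'c')  # hypothetical filename
for meta_key in ('features', 'alignments', 'border_size',
                 'context_size', 'n_filters', 'n_records'):
    if meta_key in dataset:
        print(meta_key + ' = ' + str(dataset[meta_key]))
dataset.close()
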
Example #8
                                 X_val=dataset['x']['validation'], y_val=dataset['y']['validation'],
                                 req_acc=req_accuaracy)


if __name__ == '__main__':

    args = parse_arguments()
    net_dir = '../cache/trained/'+args.net+'.net'
    destination = '../cache/pruned/'+args.net+'_p_'+args.name_appendix+'.net'
    learning_rate = args.learning_rate
    max_epochs = args.max_iter
    n_stable = args.n_stable
    dataset_dir = args.dataset

    ''' Loading the classifier and testing data '''
    net_file = open_shelve(net_dir, 'c')
    nn_classifier = net_file['net']
    structure = nn_classifier[0]
    weights = nn_classifier[1]
    biases = nn_classifier[2]
    labels = nn_classifier[3]
    print '\n\n ##Net loaded. Training parameters:', net_file['training_params']
    net_file.close()

    net = NeuralNet(program=None, name=str(structure), structure=structure)
    net.weights = weights
    net.biases = biases
    net.labels = labels
    net.map_params()

    dataset = open_shelve('../cache/datasets/'+dataset_dir+'.ds', 'c')
get_ipython().magic(u'matplotlib qt')
from numpy import newaxis, delete, zeros, concatenate, unique

# In[2]:

net = FeedForwardNet(hidden=[], tf_name='Sigmoid')
net_sz = FeedForwardNet(hidden=[], tf_name='Sigmoid')

# In[3]:

net.load('../examples/speech/net_speech_pruned.net')
net_sz.load('../examples/speech/net_speech_sz_pruned.net')

# In[4]:

dataset = open_shelve(
    '../examples/speech/dataset_speech_bs2_cs5_ds811_nr500.ds', 'c')
dataset_sz = open_shelve('../examples/speech/dataset_speech_sz.ds', 'c')

# In[7]:

net.t_data = net.prepare_data(x=dataset['x'], y=dataset['y'])
net.v_data = net.prepare_data(x=dataset['x_val'], y=dataset['y_val'])
test_data = net.prepare_data(x=dataset['x_test'], y=dataset['y_test'])
print 'full:', len(net.t_data), len(net.v_data), len(test_data)

net_sz.t_data = net_sz.prepare_data(x=dataset_sz['x'], y=dataset_sz['y'])
net_sz.v_data = net_sz.prepare_data(x=dataset_sz['x_val'],
                                    y=dataset_sz['y_val'])
test_data_sz = net_sz.prepare_data(x=dataset_sz['x_test'],
                                   y=dataset_sz['y_test'])
print 'sz:', len(net_sz.t_data), len(net_sz.v_data), len(test_data_sz)
def map_f5reads_2_taxann(f5_path, tsv_taxann_lst, tax_annot_res_dir):
    # Function performs mapping of all reads stored in input FAST5 files
    #     to existing TSV files containing taxonomic annotation info.
    #
    # It creates a DBM index file.
    #
    # :param f5_path: path to current FAST5 file;
    # :type f5_path: str;
    # :param tsv_taxann_lst: list of paths to TSV files that contain taxonomic annotation;
    # :type tsv_taxann_lst: list<str>;
    # :param tax_annot_res_dir: path to directory containing taxonomic annotation;
    # :type tax_annot_res_dir: str;

    index_dirpath = os.path.join(
        tax_annot_res_dir,
        index_name)  # name of directory that will contain indices

    # File validation:
    #   RuntimeError will be raised if FAST5 file is broken.
    try:
        # File existence checking is performed while parsing CL arguments.
        # Therefore, this if-statement will trigger only if f5_path's file is not a valid HDF5 file.
        if not h5py.is_hdf5(f5_path):
            raise RuntimeError("file is not of HDF5 (i.e. not FAST5) format")
        # end if

        f5_file = h5py.File(f5_path, 'r')

        for _ in f5_file:
            break
        # end for
    except RuntimeError as runterr:
        printlog_error_time("FAST5 file is broken")
        printlog_error("Reading the file `{}` crashed.".format(
            os.path.basename(f5_path)))
        printlog_error("Reason: {}".format(str(runterr)))
        printlog_error("Omitting this file...")
        print()
        return
    # end try

    readids_to_seek = list(fast5_readids(f5_file))
    idx_dict = dict()  # dictionary for index

    # This saving is needed to compare with 'len(readids_to_seek)'
    #    after all TSV will be looked through in order to
    #    determine if some reads miss taxonomic annotation.
    len_before = len(readids_to_seek)

    # Iterate over TSV taxonomic annotation files
    for tsv_taxann_fpath in tsv_taxann_lst:

        with open(tsv_taxann_fpath, 'r') as taxann_file:

            # Get all read IDs in current TSV
            readids_in_tsv = list(
                map(lambda l: l.split('\t')[0], taxann_file.readlines()))

            # Iterate over all remaining reads in the current FAST5
            #    ('reversed' is necessary because we remove items from the list in this loop)
            for readid in reversed(readids_to_seek):
                fmt_id = fmt_read_id(readid)[1:]
                if fmt_id in readids_in_tsv:
                    # The read is annotated in this TSV -- record it in the dict (written to the index later)
                    try:
                        idx_dict[tsv_taxann_fpath].append(
                            "read_" + fmt_id)  # append to existing list
                    except KeyError:
                        idx_dict[tsv_taxann_fpath] = ["read_" + fmt_id
                                                      ]  # create a new list
                    finally:
                        readids_to_seek.remove(readid)
                    # end try
                # end if
            # end for
        # end with
        if len(readids_to_seek) == 0:
            break
        # end if
    # end for

    # If nothing has changed after all TSV files have been checked, some reads are missing
    #     taxonomic annotation. Their IDs are written to the 'missing_reads_lst.txt' file.
    if len(readids_to_seek) == len_before:
        printlog_error_time("reads from FAST5 file not found")
        printlog_error("FAST5 file: `{}`".format(f5_path))
        printlog_error("Some reads have not undergone taxonomic annotation.")
        missing_log = "missing_reads_lst.txt"
        printlog_error("List of missing reads are in following file:")
        printlog_error("{}".format(missing_log))
        with open(missing_log, 'w') as missing_logfile:
            missing_logfile.write(
                "Missing reads from file '{}':\n\n".format(f5_path))
            for readid in readids_to_seek:
                missing_logfile.write(fmt_read_id(readid) + '\n')
            # end for
        try:
            for path in glob(os.path.join(index_dirpath, '*')):
                os.unlink(path)
            # end for
            os.rmdir(index_dirpath)
        except OSError as oserr:
            printlog_error(
                "Error occured while removing index directory: {}".format(
                    oserr))
        finally:
            platf_depend_exit(3)
        # end try
    # end if

    try:
        # Open the index shelve for read/write, creating it if needed ('c' flag)
        with open_shelve(os.path.join(index_dirpath, index_name),
                         'c') as index_f5_2_tsv:
            # Update index
            index_f5_2_tsv[f5_path] = idx_dict
        # end with
    except OSError as oserr:
        printlog_error_time("Error: cannot create index file `{}`"\
            .format(os.path.join(index_dirpath, index_name)))
        printlog_error(str(oserr))
        platf_depend_exit(1)
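
The index written here maps each FAST5 path to a dict of TSV paths, each holding a list of 'read_<id>' strings. A hedged sketch (hypothetical path) of walking that structure back out of the shelve:

from contextlib import closing
from shelve import open as open_shelve

index_path = 'f5_2_tsv_index'  # hypothetical; the script uses os.path.join(index_dirpath, index_name)
with closing(open_shelve(index_path, 'c')) as index_f5_2_tsv:
    for f5_path in index_f5_2_tsv.keys():
        for tsv_path, read_ids in index_f5_2_tsv[f5_path].items():
            print(f5_path + ': ' + str(len(read_ids)) + ' reads annotated in ' + tsv_path)
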
Example #11
  def search(self, number_of_test_queries, stop_on_result_mismatch, stop_on_crash):
    if exists(self.query_shelve_path):
      # Ensure a clean shelve will be created
      remove(self.query_shelve_path)

    start_time = time()
    impala_sql_writer = SqlWriter.create(dialect=IMPALA)
    reference_sql_writer = SqlWriter.create(
        dialect=self.reference_connection.db_type)
    query_result_comparator = QueryResultComparator(
        self.impala_connection, self.reference_connection)
    query_generator = QueryGenerator()
    query_count = 0
    queries_resulted_in_data_count = 0
    mismatch_count = 0
    query_timeout_count = 0
    known_error_count = 0
    impala_crash_count = 0
    last_error = None
    repeat_error_count = 0
    with open(self.query_log_path, 'w') as impala_query_log:
      impala_query_log.write(
         '--\n'
         '-- Starting new run\n'
         '--\n')
      while number_of_test_queries > query_count:
        query = query_generator.create_query(self.common_tables)
        impala_sql = impala_sql_writer.write_query(query)
        if 'FULL OUTER JOIN' in impala_sql and self.reference_connection.db_type == MYSQL:
          # Not supported by MySQL
          continue

        query_count += 1
        LOG.info('Running query #%s', query_count)
        impala_query_log.write(impala_sql + ';\n')
        result = query_result_comparator.compare_query_results(query)
        if result.query_resulted_in_data:
          queries_resulted_in_data_count += 1
        if result.error:
          # TODO: These first two come from psycopg2, the postgres driver. Maybe we should
          #       try a different driver? Or maybe the usage of the driver isn't correct.
          #       Anyhow ignore these failures.
          if 'division by zero' in result.error \
              or 'out of range' in result.error \
              or 'Too much data' in result.error:
            LOG.debug('Ignoring error: %s', result.error)
            query_count -= 1
            continue

          if result.is_known_error:
            known_error_count += 1
          elif result.query_timed_out:
            query_timeout_count += 1
          else:
            mismatch_count += 1
            with closing(open_shelve(self.query_shelve_path)) as query_shelve:
              query_shelve[str(query_count)] = query

          print('---Impala Query---\n')
          print(impala_sql_writer.write_query(query, pretty=True) + '\n')
          print('---Reference Query---\n')
          print(reference_sql_writer.write_query(query, pretty=True) + '\n')
          print('---Error---\n')
          print(result.error + '\n')
          print('------\n')

          if 'Could not connect' in result.error \
              or "Couldn't open transport for" in result.error:
            # if stop_on_crash:
            #   break
            # Assume Impala crashed and try restarting
            impala_crash_count += 1
            LOG.info('Restarting Impala')
            call([join(getenv('IMPALA_HOME'), 'bin/start-impala-cluster.py'),
                            '--log_dir=%s' % getenv('LOG_DIR', "/tmp/")])
            self.impala_connection.reconnect()
            query_result_comparator.impala_cursor = self.impala_connection.create_cursor()
            result = query_result_comparator.compare_query_results(query)
            if result.error:
              LOG.info('Restarting Impala')
              call([join(getenv('IMPALA_HOME'), 'bin/start-impala-cluster.py'),
                              '--log_dir=%s' % getenv('LOG_DIR', "/tmp/")])
              self.impala_connection.reconnect()
              query_result_comparator.impala_cursor = self.impala_connection.create_cursor()
            else:
              break

          if stop_on_result_mismatch and \
              not (result.is_known_error or result.query_timed_out):
            break

          if last_error == result.error \
              and not (result.is_known_error or result.query_timed_out):
            repeat_error_count += 1
            if repeat_error_count == self.ABORT_ON_REPEAT_ERROR_COUNT:
              break
          else:
            last_error = result.error
            repeat_error_count = 0
        else:
          if result.query_resulted_in_data:
            LOG.info('Results matched (%s rows)', result.impala_row_count)
          else:
            LOG.info('Query did not produce meaningful data')
          last_error = None
          repeat_error_count = 0

      return SearchResults(
          query_count,
          queries_resulted_in_data_count,
          mismatch_count,
          query_timeout_count,
          known_error_count,
          impala_crash_count,
          time() - start_time)
def bin_fast5_file(f5_path, tax_annot_res_dir, sens, min_qual, min_qlen,
                   min_pident, min_coverage, no_trash):
    # Function bins FAST5 file with untwisting.
    #
    # :param f5_path: path to FAST5 file meant to be processed;
    # :type f5_path: str;
    # :param tax_annot_res_dir: path to directory containing taxonomic annotation;
    # :type tax_annot_res_dir: str;
    # :param sens: binning sensitivity;
    # :type sens: str;
    # :param min_qual: threshold for quality filter;
    # :type min_qual: float;
    # :param min_qlen: threshold for length filter;
    # :type min_qlen: int (or None, if this filter is disabled);
    # :param min_pident: threshold for alignment identity filter;
    # :type min_pident: float (or None, if this filter is disabled);
    # :param min_coverage: threshold for alignment coverage filter;
    # :type min_coverage: float (or None, if this filter is disabled);
    # :param no_trash: logical value. True if user does NOT want to output trash files;
    # :type no_trash: bool;

    outdir_path = os.path.dirname(
        logging.getLoggerClass().root.handlers[0].baseFilename)

    seqs_pass = 0  # counter for sequences, which pass filters
    QL_seqs_fail = 0  # counter for too short or too low-quality sequences
    align_seqs_fail = 0  # counter for sequences that align to their best hit with too low identity or coverage

    srt_file_dict = dict()

    index_dirpath = os.path.join(
        tax_annot_res_dir,
        index_name)  # name of directory that will contain indices

    # Make filter for quality and length
    QL_filter = get_QL_filter(f5_path, min_qual, min_qlen)
    # Configure path to trash file
    if not no_trash:
        QL_trash_fpath = get_QL_trash_fpath(
            f5_path,
            outdir_path,
            min_qual,
            min_qlen,
        )
    else:
        QL_trash_fpath = None
    # end if

    # Make filter for identity and coverage
    align_filter = get_align_filter(min_pident, min_coverage)
    # Configure path to this trash file
    if not no_trash:
        align_trash_fpath = get_align_trash_fpath(f5_path, outdir_path,
                                                  min_pident, min_coverage)
    else:
        align_trash_fpath = None
    # end if

    # File validation:
    #   RuntimeError will be raised if FAST5 file is broken.
    try:
        # File existence checking is performed while parsing CL arguments.
        # Therefore, this if-statement will trigger only if f5_path's file is not a valid HDF5 file.
        if not h5py.is_hdf5(f5_path):
            raise RuntimeError("file is not of HDF5 (i.e. not FAST5) format")
        # end if

        from_f5 = h5py.File(f5_path, 'r')

        for _ in from_f5:
            break
        # end for
    except RuntimeError as runterr:
        printlog_error_time("FAST5 file is broken")
        printlog_error("Reading the file `{}` crashed.".format(
            os.path.basename(f5_path)))
        printlog_error("Reason: {}".format(str(runterr)))
        printlog_error("Omitting this file...")
        print()
        # Return zeroes -- inc_val won't be incremented and this file will be omitted
        return (0, 0, 0)
    # end try

    # singleFAST5 and multiFAST5 files should be processed in different ways
    # "Raw" group always in singleFAST5 root and never in multiFAST5 root
    if "Raw" in from_f5.keys():
        f5_cpy_func = copy_single_f5
    else:
        f5_cpy_func = copy_read_f5_2_f5
    # end if

    readids_to_seek = list(from_f5.keys())  # list of not-binned-yet read IDs

    # Fill the list 'readids_to_seek'
    for read_name in fast5_readids(from_f5):
        # Get rid of "read_"
        readids_to_seek.append(sys.intern(read_name))
    # end for

    # Walk through the index
    index_f5_2_tsv = open_shelve(os.path.join(index_dirpath, index_name), 'r')

    if not f5_path in index_f5_2_tsv.keys():
        printlog_error_time(
            "Source FAST5 file `{}` not found in index".format(f5_path))
        printlog_error("Try to rebuild index")
        platf_depend_exit(1)
    # end if

    for tsv_path in index_f5_2_tsv[f5_path].keys():

        read_names = index_f5_2_tsv[f5_path][tsv_path]
        taxonomy_path = os.path.join(tax_annot_res_dir, "taxonomy",
                                     "taxonomy.tsv")
        resfile_lines = configure_resfile_lines(tsv_path, sens, taxonomy_path)

        for read_name in read_names:
            try:
                hit_names, *vals_to_filter = resfile_lines[sys.intern(
                    fmt_read_id(read_name)[1:])]
            except KeyError:
                printlog_error_time("Error: missing taxonomic annotation info for read `{}`"\
                    .format(fmt_read_id(read_name)[1:]))
                printlog_error(
                    "It is stored in `{}` FAST5 file".format(f5_path))
                printlog_error(
                    "Try to make new index file (press ENTER on corresponding prompt)."
                )
                printlog_error(
                    "Or, if does not work for you, make sure that taxonomic annotation info \
for this read is present in one of TSV files generated by `barapost-prober.py` and `barapost-local.py`."
                )
                index_f5_2_tsv.close()
                platf_depend_exit(1)
            # end try

            if not QL_filter(vals_to_filter):
                # Get name of result FASTQ file to write this read in
                if QL_trash_fpath not in srt_file_dict.keys():
                    srt_file_dict = update_file_dict(srt_file_dict,
                                                     QL_trash_fpath)
                # end if
                f5_cpy_func(from_f5, read_name, srt_file_dict[QL_trash_fpath])
                QL_seqs_fail += 1
            elif not align_filter(vals_to_filter):
                # Get name of result FASTQ file to write this read in
                if align_trash_fpath not in srt_file_dict.keys():
                    srt_file_dict = update_file_dict(srt_file_dict,
                                                     align_trash_fpath)
                # end if
                f5_cpy_func(from_f5, read_name,
                            srt_file_dict[align_trash_fpath])
                align_seqs_fail += 1
            else:
                for hit_name in hit_names.split(
                        "&&"
                ):  # there can be multiple hits for single query sequence
                    # Get name of result FASTQ file to write this read in
                    binned_file_path = os.path.join(
                        outdir_path, "{}.fast5".format(hit_name))
                    if binned_file_path not in srt_file_dict.keys():
                        srt_file_dict = update_file_dict(
                            srt_file_dict, binned_file_path)
                    # end if
                    f5_cpy_func(from_f5, read_name,
                                srt_file_dict[binned_file_path])
                # end for
                seqs_pass += 1
            # end if
        # end for

    from_f5.close()
    index_f5_2_tsv.close()

    # Close all binned files
    for file_obj in filter(lambda x: not x is None, srt_file_dict.values()):
        file_obj.close()
    # end for

    sys.stdout.write('\r')
    printlog_info_time("File `{}` is binned.".format(
        os.path.basename(f5_path)))
    printn(" Working...")

    return (seqs_pass, QL_seqs_fail, align_seqs_fail)
Example #13
    return parser.parse_args()

if __name__ == '__main__':
    args = parse_arguments()
    print_message(message='EXAMPLE: XOR dataset')
    print_param(description='Number of experiment observations', param_str=str(args.n_obs))
    print_param(description='Initial number of hidden neurons', param_str=str(args.hidden_structure))
    print_param(description='Required accuracy', param_str=str(args.req_acc))

    params_str = '_hs'+str(args.hidden_structure)+'_ra'+str(args.req_acc).replace('.', '')+'_no'+str(args.n_obs)
    if args.generate:
        stats_data = list()
        for i_obs in range(1, args.n_obs+1):
            print_message(message='XOR experiment, observation '+str(i_obs)+'/'+str(args.n_obs))
            net = FeedForwardNet(hidden=args.hidden_structure, tf_name='Sigmoid')
            dataset = open_shelve('../examples/xor/dataset_xor.ds', 'c')
            net.fit(x=dataset['x'], y=dataset['y'], x_val=dataset['x_val'], y_val=dataset['y_val'], learning_rate=0.4,
                    n_epoch=50, req_acc=1.0)
            res = net.evaluate(x=dataset['x_test'], y=dataset['y_test'])
            print_message(message='Evaluation on test data after training:')
            print_param(description='Accuracy', param_str=str(res[1]))
            print_param(description='Error', param_str=str(res[0]))
            if net.learning.stats['t_acc'][-1] < 0.9:
                print 'Skipping observation'
                continue
            net.prune(req_acc=args.req_acc, req_err=0.05, n_epoch=50, levels=args.levels)
            res = net.evaluate(x=dataset['x_test'], y=dataset['y_test'])
            print_message(message='Evaluation on test data after pruning:')
            print_param(description='Accuracy', param_str=str(res[1]))
            print_param(description='Error', param_str=str(res[0]))
            stats_data.append(net.opt['pruning'].stats)
Example #14
                param_str=str(args.n_obs))
    print_param(description='Initial number of hidden neurons',
                param_str=str(args.hidden_structure))
    print_param(description='Required accuracy', param_str=str(args.req_acc))

    params_str = '_hs' + str(args.hidden_structure) + '_ra' + str(
        args.req_acc).replace('.', '') + '_no' + str(args.n_obs)
    if args.generate:
        stats_data = list()
        for i_obs in range(1, args.n_obs + 1):
            print_message(message='SPEECH experiment, observation ' +
                          str(i_obs) + '/' + str(args.n_obs))
            net = FeedForwardNet(hidden=args.hidden_structure,
                                 tf_name='Sigmoid')
            dataset = open_shelve(
                '../examples/speech/dataset_speech_bs2_cs5_nf40_ds811_nr200.ds',
                'c')
            net.fit(x=dataset['x'],
                    y=dataset['y'],
                    x_val=dataset['x_val'],
                    y_val=dataset['y_val'],
                    learning_rate=0.07,
                    n_epoch=50,
                    req_acc=0.7,
                    batch_size=10)
            res = net.evaluate(x=dataset['x_test'], y=dataset['y_test'])
            print_message(message='Evaluation on test data after training:')
            print_param(description='Accuracy', param_str=str(res[1]))
            print_param(description='Error', param_str=str(res[0]))
            if net.learning.stats['t_acc'][-1] < args.req_acc:
                print 'Skipping observation'
Example #15
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from kitt_net import FeedForwardNet
from shelve import open as open_shelve
import numpy as np
np.set_printoptions(threshold=np.nan)

if __name__ == '__main__':
    net = FeedForwardNet(hidden=[50], tf_name='Sigmoid')
    dataset = open_shelve('../examples/speech/dataset_speech_10K_bs2_cs5.ds')
    net.fit(x=dataset['x'],
            y=dataset['y'],
            x_val=dataset['x_val'],
            y_val=dataset['y_val'],
            learning_rate=0.01,
            n_epoch=400,
            req_acc=1.0,
            batch_size=10,
            dump_name='../examples/speech/net_speech_10K_bs2_cs5.net')
    #print 'Net structure to be dumped:', net.structure, '| Number of synapses:', net.count_synapses()
    #net.dump('../examples/speech/net_speech_5000.net')
    #net.prune(req_acc=0.6, req_err=0.05, n_epoch=5, levels=(75, 50, 30, 20, 10, 7, 5, 3, 2, 1, 0))
    #print 'Net structure to be dumped:', net.structure, '| Number of synapses:', net.count_synapses()
    #net.dump('../examples/speech/net_speech_5000_pruned.net')
    dataset.close()
Example #16
                        help='App. to the filename')
    return parser.parse_args()

if __name__ == '__main__':

    args = parse_arguments()
    dataset_dir = '../cache/datasets/'+args.task+'/'+args.dataset+'.ds'
    learning_rate = args.learning_rate
    n_iter = args.n_iter

    destination_name = args.dataset+'&'+str(learning_rate)+'_'+str(n_iter)+'_'+str(args.structure)+'_'+args.name_appendix
    destination = '../cache/trained/kitt_'+destination_name+'.net'

    ''' Loading dataset and training '''
    print '\n\n ## Loading dataset', args.dataset, '...'
    dataset = open_shelve(dataset_dir, 'r')

    net_structure = [len(dataset['x']['training'][0])]+args.structure+[len(np.unique(dataset['y']['training']))]
    net = NeuralNet(program=None, name=str(net_structure), structure=net_structure)
    net.learning = BackPropagation(program=None, net=net, learning_rate=learning_rate, n_iter=n_iter)

    print '\n\n ## Fitting the training data...'
    acc_list, err_list, time_list = net.fit(X=dataset['x']['training'], y=dataset['y']['training'],
                                            X_val=dataset['x']['validation'], y_val=dataset['y']['validation'])

    ''' Getting results on a testing set '''
    print '\n\n ## Testing...'
    y_pred = net.predict(dataset['x']['testing'])

    c_accuracy = accuracy_score(y_true=np.array(dataset['y']['testing']), y_pred=y_pred)
    c_report = classification_report(np.array(dataset['y']['testing']), y_pred)
    the_y['validation'] = va_d[1]
    the_x['testing'] = [np.reshape(x, (784, 1)) for x in te_d[0]]
    the_y['testing'] = te_d[1]
    return the_x, the_y


def vectorized_result(j, output_neurons):
    e = np.zeros((output_neurons, 1))
    e[j] = 1.0
    return e

if __name__ == '__main__':
    args = parse_arguments()
    destination = '../cache/datasets/mnist/'+args.destination_name+'.ds'

    ''' Loading data '''
    print '\n\n ## Loading data...'
    x, y = load_data_wrapper('../cache/downloads')
    print 'Got dataset:', len(x['training']), ':', len(x['validation']), ':', len(x['testing'])

    ''' Saving dataset '''
    print '\n\n ## Saving dataset as', destination

    dataset = open_shelve(destination, 'c')
    dataset['x'] = x
    dataset['y'] = y
    dataset['size'] = (len(x['training']), len(x['validation']), len(x['testing']))
    dataset.close()

    print 'Dataset dumped.'
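
Unlike the flat layout used elsewhere, this MNIST shelve nests the splits inside the 'x' and 'y' dicts and records a 'size' tuple. A minimal sketch, with a hypothetical path, of reading it back:

from shelve import open as open_shelve

dataset = open_shelve('dataset_mnist_toy.ds', 'c')  # hypothetical name
if 'size' in dataset:
    print('training/validation/testing sizes: ' + str(dataset['size']))
    print('first training label: ' + str(dataset['y']['training'][0]))
dataset.close()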