Example #1
def predict_main():
    # parse the command line arguments
    parser = argparsers.predict_argsparser()
    args = parser.parse_args()

    # check if the output directory exists
    if not os.path.exists(args.output_dir):
        logging.error("Directory {} does not exist".format(args.output_dir))

        return

    if args.automate_filenames:
        # create a new directory using current date/time to store the
        # predictions and logs
        date_time_str = local_datetime_str(args.time_zone)
        pred_dir = '{}/{}'.format(args.output_dir, date_time_str)
        os.mkdir(pred_dir)
    elif os.path.isdir(args.output_dir):
        pred_dir = args.output_dir
    else:
        logging.error("Directory does not exist {}.".format(args.output_dir))
        return

    # filename to write debug logs
    logfname = "{}/predict.log".format(pred_dir)

    # set up the loggers
    logger.init_logger(logfname)

    # make sure the input_data json file exists
    if not os.path.isfile(args.input_data):
        raise quietexception.QuietException(
            "File not found: {} OR you may have accidentally "
            "specified a directory path.".format(args.input_data))

    # load the json file
    with open(args.input_data, 'r') as inp_json:
        try:
            #: dictionary of tasks for training
            input_data = json.loads(inp_json.read())
        except json.decoder.JSONDecodeError:
            raise quietexception.QuietException(
                "Unable to load json file {}. Valid json expected. "
                "Check the file for syntax errors.".format(args.input_data))

    logging.info("INPUT DATA -\n{}".format(input_data))

    # predict
    logging.info("Loading {}".format(args.model))
    with CustomObjectScope(
        {'MultichannelMultinomialNLL': MultichannelMultinomialNLL}):

        predict(args, input_data, pred_dir)
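
The example above builds its timestamped output directory with a local_datetime_str helper that is not shown here. Below is a minimal sketch of what such a helper could look like, assuming it only needs to produce a filesystem-safe timestamp for a given time zone name; the project's actual implementation may differ.

import datetime

import pytz


def local_datetime_str(time_zone):
    # convert the current time to the requested time zone,
    # e.g. "US/Pacific", and format it so it is safe to use in paths
    tz = pytz.timezone(time_zone)
    now = datetime.datetime.now(tz)
    return now.strftime("%Y-%m-%d_%H-%M-%S")
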
Example #2
def predict_main():
    # parse the command line arguments
    parser = argparsers.fastpredict_argsparser()
    args = parser.parse_args()

    # check if the output directory exists
    if not os.path.exists(args.output_dir):
        logging.error("Directory {} does not exist".format(args.output_dir))

        return

    if args.automate_filenames:
        # create a new directory using current date/time to store the
        # predictions and logs
        date_time_str = local_datetime_str(args.time_zone)
        pred_dir = '{}/{}'.format(args.output_dir, date_time_str)
        os.mkdir(pred_dir)
    elif os.path.isdir(args.output_dir):
        pred_dir = args.output_dir
    else:
        logging.error("Directory does not exist {}.".format(args.output_dir))
        return

    # filename to write debug logs
    logfname = "{}/predict.log".format(pred_dir)

    # set up the loggers
    logger.init_logger(logfname)

    # predict
    logging.info("Loading {}".format(args.model))
    with CustomObjectScope({
            'MultichannelMultinomialNLL': MultichannelMultinomialNLL,
            'tf': tf,
            'CustomMeanSquaredError': CustomMeanSquaredError,
            'AttributionPriorModel': AttributionPriorModel,
            'CustomModel': CustomModel
    }):

        predict(args, pred_dir)
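
In both predict_main variants the CustomObjectScope wrapper matters because Keras has to resolve custom losses, metrics, and model classes by name when a saved model is deserialized. A hedged sketch of that pattern, assuming predict() loads an HDF5 model somewhere inside the scope ("model.h5" is a placeholder path, and the import path of MultichannelMultinomialNLL depends on the project layout):

from tensorflow.keras.models import load_model
from tensorflow.keras.utils import CustomObjectScope


def load_trained_model(model_path):
    # the names registered in the scope map the strings stored in the
    # HDF5 file back to the custom Python objects used at training time
    with CustomObjectScope(
            {'MultichannelMultinomialNLL': MultichannelMultinomialNLL}):
        return load_model(model_path)


model = load_trained_model("model.h5")
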
Example #3
def train_and_validate_ksplits(input_data,
                               model_arch_name,
                               model_arch_params_json,
                               output_params,
                               genome_params,
                               batch_gen_params,
                               hyper_params,
                               parallelization_params,
                               splits,
                               bias_input_data=None,
                               bias_model_arch_params_json=None,
                               adjust_bias_model_logcounts=False,
                               is_background_model=False,
                               mnll_loss_sample_weight=1.0,
                               mnll_loss_background_sample_weight=0.0):
    """
        Train and validate on one or more train/val splits
        
        Args:
            input_data (str): path to the tasks json file
            
            model_arch_name (str): name of the model definition 
                function in the model_archs module
                
            model_arch_params_json (str): path to json file containing
                model architecture params
            
            output_params (dict): dictionary containing output 
                parameters
            
            genome_params (dict): dictionary containing genome
                parameters
            
            batch_gen_params (dict): dictionary containing batch
                generation parameters
            
            hyper_params (dict): dictionary containing 
                training & validation hyper parameters
            
            parallelization_params (dict): dictionary containing
                parameters for parallelization options
            
            splits (dict): dictionary containing the train & 
                validation splits, keyed by split index
            
            bias_input_data (str): path to the bias tasks json file

            bias_model_arch_params_json (str): path to json file 
                containing bias model architecture params

            adjust_bias_model_logcounts (boolean): True if the weights
                of the final Dense layer that predicts the logcounts
                should be adjusted when training a bias model for 
                chromatin accessibility

            is_background_model (boolean): True if a background model
                is to be trained using 'background_loci' samples from
                the input json
                
            mnll_loss_sample_weight (float): weight for each (foreground)
                training sample for computing mnll loss

            mnll_loss_background_sample_weight (float): weight for each
                background sample for computing mnll loss
    """

    # list of chromosomes after removing the excluded chromosomes
    chroms = set(genome_params['chroms']).difference(
        set(genome_params['exclude_chroms']))

    # list of models from all of the splits
    models = []

    # run training for each validation/test split
    num_splits = len(list(splits.keys()))
    for i in range(num_splits):

        if output_params['automate_filenames']:
            # create a new directory using current date/time to store the
            # model, the loss history and logs
            date_time_str = local_datetime_str(output_params['time_zone'])
            model_dir = '{}/{}_split{:03d}'.format(output_params['output_dir'],
                                                   date_time_str, i)
            os.mkdir(model_dir)
            split_tag = None
        elif os.path.isdir(output_params['output_dir']):
            model_dir = output_params['output_dir']
            split_tag = "split{:03d}".format(i)
        else:
            logging.error("Directory does not exist {}.".format(
                output_params['output_dir']))
            return

        # filename to write debug logs
        logfname = '{}/trainer.log'.format(model_dir)
        # set up logger for main process
        logger.init_logger(logfname)

        # train & validation chromosome split
        if 'val' not in splits[str(i)]:
            logging.error("KeyError: 'val' required for split {}".format(i))
            return
        val_chroms = splits[str(i)]['val']
        # if 'train' key is present
        if 'train' in splits[str(i)]:
            train_chroms = splits[str(i)]['train']
        # if 'test' key is present but train is not
        elif 'test' in splits[str(i)]:
            test_chroms = splits[str(i)]['test']
            # take the set difference of the whole list of
            # chroms with the union of val and test
            train_chroms = list(
                chroms.difference(set(val_chroms + test_chroms)))
        else:
            # take the set difference of the whole list of
            # chroms with val
            train_chroms = list(chroms.difference(val_chroms))

        logging.info("Split #{}".format(i))
        logging.info("Train: {}".format(train_chroms))
        logging.info("Val: {}".format(val_chroms))

        # Start training for the split in a separate process.
        # This ensures that all resources are freed when the process
        # terminates and are available for training the next split.
        # Mitigates the problem where training subsequent splits
        # becomes considerably slower.
        logging.debug("Split {}: Creating training process".format(i))
        p = mp.Process(target=train_and_validate,
                       args=[
                           input_data, model_arch_name, model_arch_params_json,
                           output_params, genome_params, batch_gen_params,
                           hyper_params, parallelization_params, train_chroms,
                           val_chroms, model_dir, bias_input_data,
                           bias_model_arch_params_json,
                           adjust_bias_model_logcounts, is_background_model,
                           mnll_loss_sample_weight,
                           mnll_loss_background_sample_weight, split_tag
                       ])
        p.start()

        # wait for the process to finish
        p.join()
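
Note that splits is consumed as a dictionary keyed by the split index as a string, where 'val' is required and 'train'/'test' are optional. A hypothetical splits dictionary consistent with the indexing above (the chromosome names are just placeholders):

splits = {
    "0": {"val": ["chr8", "chr10"], "test": ["chr1"]},
    "1": {"val": ["chr2", "chr9"], "train": ["chr3", "chr4", "chr5"]},
    # a split with only 'val': the train chroms become the full
    # chromosome set minus 'exclude_chroms' minus the val chroms
    "2": {"val": ["chr7", "chr11"]}
}
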
Example #4
def train_and_validate(input_data,
                       model_arch_name,
                       model_arch_params_json,
                       output_params,
                       genome_params,
                       batch_gen_params,
                       hyper_params,
                       parallelization_params,
                       train_chroms,
                       val_chroms,
                       model_dir,
                       bias_input_data=None,
                       bias_model_arch_params_json=None,
                       adjust_bias_model_logcounts=False,
                       is_background_model=False,
                       mnll_loss_sample_weight=1.0,
                       mnll_loss_background_sample_weight=0.0,
                       suffix_tag=None):
    """
        Train and validate on a single train and validation set
        
        Note: the list & description for each of the required keys
            in all of the json parameter files passed to this 
            function can be found here:
            http://
        
        Args:
            input_data (str): path to the tasks json file
            
            model_arch_name (str): name of the model definition 
                function in the model_archs module
                
            model_arch_params_json (str): path to json file containing
                model architecture params

            output_params (dict): dictionary containing output 
                parameters
            
            genome_params (dict): dictionary containing genome
                parameters
            
            batch_gen_params (dict): dictionary containing batch
                generation parameters
            
            hyper_params (dict): dictionary containing 
                training & validation hyper parameters
            
            parallelization_params (dict): dictionary containing
                parameters for parallelization options
            
            train_chroms (list): list of training chromosomes
            
            val_chroms (list): list of validation chromosomes
            
            model_dir (str): the path to the output directory
            
            bias_input_data (str): path to the bias tasks json file

            bias_model_arch_params_json (str): path to json file 
                containing bias model architecture params
                
            adjust_bias_model_logcounts (boolean): True if you need to
                adjust the weights of the final Dense layer that 
                predicts the logcounts when training a bias model for 
                chromatin accessibility
            
            is_background_model (boolean): True if a background model
                is to be trained using 'background_loci' samples from
                the input json
                
            mnll_loss_sample_weight (float): weight for each (foreground)
                training sample for computing mnll loss

            mnll_loss_background_sample_weight (float): weight for each
                background sample for computing mnll loss
                
            suffix_tag (str): optional tag to add as a suffix to files
                (model, log, history & config params files) created in
                the model directory
         
         Returns:
             keras.models.Model
             
    """

    # make sure the input_data json file exists
    if not os.path.isfile(input_data):
        raise NoTracebackException("File not found: {} ".format(input_data))

    # load the json file
    with open(input_data, 'r') as inp_json:
        try:
            tasks = json.loads(inp_json.read())
            # since the json has keys as strings, we convert the
            # top level keys to int so we can use them later for
            # indexing
            #: dictionary of tasks for training
            tasks = {int(k): v for k, v in tasks.items()}
        except json.decoder.JSONDecodeError:
            raise NoTracebackException(
                "Unable to load json file {}. Valid json expected. "
                "Check the file for syntax errors.".format(input_data))

    # make sure the params json file exists
    if not os.path.isfile(model_arch_params_json):
        raise NoTracebackException(
            "File not found: {} ".format(model_arch_params_json))

    # load the params json file
    with open(model_arch_params_json, 'r') as inp_json:
        try:
            model_arch_params = json.loads(inp_json.read())
        except json.decoder.JSONDecodeError:
            raise NoTracebackException(
                "Unable to load json file {}. Valid json expected. "
                "Check the file for syntax errors.".format(
                    model_arch_params_json))

    if bias_input_data is not None:
        # load the bias json file
        with open(bias_input_data, 'r') as inp_json:
            try:
                bias_tasks = json.loads(inp_json.read())
                # since the json has keys as strings, we convert the
                # top level keys to int so we can use them later for
                # indexing
                #: dictionary of tasks for training
                bias_tasks = {int(k): v for k, v in bias_tasks.items()}
            except json.decoder.JSONDecodeError:
                raise NoTracebackException(
                    "Unable to load json file {}. Valid json expected. "
                    "Check the file for syntax errors.".format(
                        bias_input_data))

    if bias_model_arch_params_json is not None:
        # make sure the bias params json file exists
        if not os.path.isfile(bias_model_arch_params_json):
            raise NoTracebackException(
                "File not found: {} ".format(bias_model_arch_params_json))

        # load the bias params json file
        with open(bias_model_arch_params_json, 'r') as inp_json:
            try:
                bias_model_arch_params = json.loads(inp_json.read())
            except json.decoder.JSONDecodeError:
                raise NoTracebackException(
                    "Unable to load json file {}. Valid json expected. "
                    "Check the file for syntax errors.".format(
                        bias_model_arch_params_json))

    # filename to write debug logs
    if suffix_tag is not None:
        logfname = '{}/trainer_{}.log'.format(model_dir, suffix_tag)
    else:
        logfname = '{}/trainer.log'.format(model_dir)

    # we need to initialize the logger for each process
    logger.init_logger(logfname)

    # parameters that are specific to the training batch generation
    # process
    train_batch_gen_params = batch_gen_params
    train_batch_gen_params['mode'] = 'train'

    # parameters that are specific to the validation batch generation
    # process. For validation we don't use jitter, reverse complement
    # augmentation, or negative sampling
    val_batch_gen_params = copy.deepcopy(batch_gen_params)
    val_batch_gen_params['max_jitter'] = 0
    val_batch_gen_params['rev_comp_aug'] = False
    val_batch_gen_params['negative_sampling_rate'] = 0.0
    val_batch_gen_params['mode'] = 'val'

    # get the corresponding batch generator class for this model
    sequence_generator_class_name = generators.find_generator_by_name(
        batch_gen_params['sequence_generator_name'])
    logging.info("SEQGEN Class Name: {}".format(sequence_generator_class_name))
    BatchGenerator = getattr(generators, sequence_generator_class_name)

    # instantiate the batch generator class for training
    train_gen = BatchGenerator(
        input_data,
        train_batch_gen_params,
        genome_params['reference_genome'],
        genome_params['chrom_sizes'],
        train_chroms,
        num_threads=parallelization_params['threads'],
        batch_size=hyper_params['batch_size'],
        background_only=is_background_model,
        foreground_weight=mnll_loss_sample_weight,
        background_weight=mnll_loss_background_sample_weight)

    # instantiate the batch generator class for validation
    val_gen = BatchGenerator(
        input_data,
        val_batch_gen_params,
        genome_params['reference_genome'],
        genome_params['chrom_sizes'],
        val_chroms,
        num_threads=parallelization_params['threads'],
        batch_size=hyper_params['batch_size'],
        background_only=is_background_model,
        foreground_weight=mnll_loss_sample_weight,
        background_weight=mnll_loss_background_sample_weight)

    # we need to calculate the number of training steps and
    # validation steps in each epoch, fit/evaluate requires this
    # to determine the end of an epoch
    train_steps = train_gen.len()
    val_steps = val_gen.len()

    # we may have to reduce the --threads sometimes
    # if the peak file has very few peaks, so we need to
    # check if these numbers will be 0
    logging.info("TRAINING STEPS - {}".format(train_steps))
    logging.info("VALIDATION STEPS - {}".format(val_steps))

    # get an instance of the model
    logging.debug("New {} model".format(model_arch_name))
    get_model = getattr(archs, model_arch_name)
    if model_arch_name == "BPNet_ATAC_DNase":
        model = get_model(tasks,
                          bias_tasks,
                          model_arch_params,
                          bias_model_arch_params,
                          name_prefix="main")
    else:
        model = get_model(tasks, model_arch_params, name_prefix="main")

    # print out the model summary
    model.summary()

    # compile the model
    logging.debug("Compiling model")
    logging.info("loss weights - {}".format(model_arch_params['loss_weights']))
    model.compile(Adam(learning_rate=hyper_params['learning_rate']),
                  loss=[
                      MultichannelMultinomialNLL(
                          train_gen._total_signal_tracks),
                      CustomMeanSquaredError()
                  ],
                  loss_weights=model_arch_params['loss_weights'])

    # begin time for training
    t1 = time.time()

    # track training losses, validation losses and start & end
    # times
    custom_history = {
        'learning_rate': {},
        'loss': {},
        'batch_loss': {},
        'profile_predictions_loss': {},
        'logcounts_predictions_loss': {},
        'attribution_prior_loss': {},
        'val_loss': {},
        'val_batch_loss': {},
        'val_profile_predictions_loss': {},
        'val_logcounts_predictions_loss': {},
        'val_attribution_prior_loss': {},
        'start_time': {},
        'end_time': {},
        'elapsed': {}
    }

    # we maintain a separate list to track validation losses to make it
    # easier for early stopping
    val_losses = []

    # track validation losses for learning rate update
    val_losses_lr = []

    # track best loss so we can restore weights
    best_loss = 1e6

    # keep a copy of the best weights
    best_weights = None

    # the epoch with the best validation loss
    best_epoch = 1

    # start training
    logging.debug("Training started ...")
    for epoch in range(hyper_params['epochs']):
        # First, let's train for one epoch
        logging.info("Training Epoch {}".format(epoch + 1))
        train_start_time = time.time()
        custom_history['learning_rate'][str(epoch + 1)] = \
            model.optimizer.learning_rate.numpy()
        custom_history['start_time'][str(epoch + 1)] = \
            time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(train_start_time))
        # training generator function that will be passed to fit
        train_generator = train_gen.gen()
        history = model.fit(train_generator,
                            epochs=1,
                            steps_per_epoch=train_steps)
        train_end_time = time.time()

        # record the training losses
        for key in history.history:
            custom_history[key][str(epoch + 1)] = history.history[key][0]

        # Then, we evaluate on the validation set
        logging.info("Validation Epoch {}".format(epoch + 1))
        val_start_time = time.time()
        # validation generator function that will be passed to evaluate
        val_generator = val_gen.gen()
        val_loss = model.evaluate(val_generator,
                                  steps=val_steps,
                                  return_dict=True)
        val_losses.append(val_loss['loss'])
        val_losses_lr.append(val_loss['loss'])
        val_end_time = time.time()
        custom_history['end_time'][str(epoch + 1)] = \
            time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(val_end_time))
        custom_history['elapsed'][str(epoch + 1)] = \
            val_end_time - train_start_time

        # record the validation losses
        for key in val_loss:
            custom_history['val_' + key][str(epoch + 1)] = \
                val_loss[key]

        # update best weights and loss
        if val_loss['loss'] < best_loss:
            best_weights = model.get_weights()
            best_loss = val_loss['loss']
            best_epoch = epoch + 1

        # check if early stopping criteria are satisfied
        if early_stopping_check(
                val_losses,
                patience=hyper_params['early_stopping_patience'],
                min_delta=hyper_params['early_stopping_min_delta']):

            # restore best weights
            logging.info(
                "Restoring best weights from epoch {}".format(best_epoch))
            model.set_weights(best_weights)
            break

        # lower learning rate if criteria are satisfied
        current_lr = model.optimizer.learning_rate.numpy()
        new_lr = reduce_lr_on_plateau(
            val_losses_lr,
            current_lr,
            factor=hyper_params['lr_reduction_factor'],
            patience=hyper_params['reduce_lr_on_plateau_patience'],
            min_lr=hyper_params['min_learning_rate'])

        # reset the validation losses tracker for learning rate update
        if new_lr != current_lr:
            val_losses_lr = [val_losses_lr[-1]]

        # set the new learning rate
        model.optimizer.learning_rate.assign(new_lr)

        # display current learning rate and training status
        logging.info(
            "Current learning rate - {:5f}, Stop Training - {}".format(
                model.optimizer.learning_rate.numpy(), model.stop_training))

    # end time for training
    t2 = time.time()
    logging.info("Total Elapsed Time: {}".format(t2 - t1))

    # base model filename
    if output_params['automate_filenames']:
        # get random alphanumeric tag for model
        model_tag = getAlphaNumericTag(output_params['tag_length'])
        model_fname = "{}/{}".format(model_dir, model_tag)
    elif output_params['model_output_filename'] is not None:
        model_fname = "{}/{}".format(model_dir,
                                     output_params['model_output_filename'])
    else:
        model_fname = "{}/model".format(model_dir)

    # add suffix tag to model name
    if suffix_tag is not None:
        model_fname += "_{}".format(suffix_tag)

    # extension
    model_fname += ".h5"

    # save HDF5 model file
    model.save(model_fname)
    logging.info("Finished saving model: {}".format(model_fname))

    if adjust_bias_model_logcounts:
        # all peaks and non peaks
        loci = train_gen.get_samples()

        # non-peaks only
        nonpeaks_loci = loci[loci['weight'] == 0.0]

        if len(nonpeaks_loci) == 0:
            logging.info("Non peaks length is 0. Bias model adjustment "
                         "aborted.")
        else:
            # reference file to fetch sequences
            fasta_ref = pyfaidx.Fasta(genome_params['reference_genome'])

            # get all the bigWigs and peaks from the input_data
            bigWigs = []
            for task in tasks:
                if 'signal' in tasks[task].keys():
                    bigWigs.extend(tasks[task]['signal']["source"])

            # open each bigwig and add file pointers to a list
            fbigWigs = []
            for bigWig in bigWigs:
                fbigWigs.append(pyBigWig.open(bigWig))

            # get sequences and logcounts
            logging.info("Fetching non peak sequences and counts ...")
            sequences = []
            logcounts = []
            for _, row in nonpeaks_loci.iterrows():
                # chrom, start and end
                chrom = row['chrom']
                start = row['pos'] - (batch_gen_params['input_seq_len'] // 2)
                end = row['pos'] + (batch_gen_params['input_seq_len'] // 2)

                # get the sequences
                seq = fasta_ref[chrom][start:end].seq.upper()

                # collect all the sequences into a list
                sequences.append(seq)

                # get the total counts
                for i in range(len(fbigWigs)):
                    bw = fbigWigs[i]
                    logcounts.append(
                        np.log(
                            np.sum(np.nan_to_num(bw.values(chrom, start, end)))
                            + 1))

            fasta_ref.close()

            # one hot encode the sequences
            seqs = sequtils.one_hot_encode(sequences,
                                           batch_gen_params['input_seq_len'])

            adjusted_model = adjust_bias_logcounts(model, seqs,
                                                   np.array(logcounts),
                                                   "logcounts_predictions")

            # saving adjusted model
            model_fname = model_fname.replace('.h5', '.adjusted.h5')

            # save HDF5 model file
            adjusted_model.save(model_fname)
            logging.info(
                "Finished saving adjusted model: {}".format(model_fname))

    # save history to json:
    # Step 1. convert the custom history dict to a pandas DataFrame:
    hist_df = pd.DataFrame(custom_history)

    # file name for json file
    hist_json = model_fname.replace('.h5', '.history.json')

    # Step 2. write the dataframe to json
    with open(hist_json, mode='w') as f:
        hist_df.to_json(f)

    logging.info("Finished saving training and validation history: {}".format(
        hist_json))

    # write all the command line arguments to a json file
    # & include the number of epochs the training lasted for, and the
    # train and validation chroms
    config_file = '{}/config'.format(model_dir)
    # add suffix tag to model name
    if suffix_tag is not None:
        config_file += "_{}".format(suffix_tag)
    # extension
    config_file += ".json"

    with open(config_file, 'w') as fp:
        config = {}
        config['input_data'] = input_data
        config['output_params'] = output_params
        config['genome_params'] = genome_params
        config['batch_gen_params'] = batch_gen_params
        config['hyper_params'] = hyper_params
        config['parallelization_params'] = parallelization_params

        # the number of epochs the training lasted
        config['training_epochs'] = epoch + 1

        # the epoch with best validation loss
        config['best_epoch'] = best_epoch

        config['train_chroms'] = train_chroms
        config['val_chroms'] = val_chroms
        config['model_filename'] = model_fname

        json.dump(config, fp)

    return model
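
The manual training loop above delegates the stopping and learning-rate decisions to early_stopping_check and reduce_lr_on_plateau, which are not shown here. The sketches below are one plausible implementation consistent with how they are called (patience, min_delta, factor, min_lr), not the project's actual code:

def early_stopping_check(val_losses, patience=5, min_delta=1e-3):
    # stop when the best loss of the last `patience` epochs has not
    # improved on the best loss seen before them by at least min_delta
    if len(val_losses) <= patience:
        return False
    best_before = min(val_losses[:-patience])
    best_recent = min(val_losses[-patience:])
    return (best_before - best_recent) < min_delta


def reduce_lr_on_plateau(val_losses, current_lr, factor=0.5, patience=3,
                         min_lr=1e-6):
    # scale the learning rate by `factor` when the recent validation
    # losses have stopped improving, never going below min_lr; the
    # caller resets its loss tracker whenever the rate changes
    if len(val_losses) <= patience:
        return current_lr
    if min(val_losses[-patience:]) >= min(val_losses[:-patience]):
        return max(current_lr * factor, min_lr)
    return current_lr
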
Example #5
def outliers_main():

    # parse the command line arguments
    parser = outliers_argsparser()
    args = parser.parse_args()

    # filename to write debug logs
    logfname = "outliers.log"

    # set up the loggers
    logger.init_logger(logfname)

    # check if the input json file exists
    if not os.path.exists(args.input_data):
        raise NoTracebackException("File {} does not exist".format(
            args.input_data))

    # check if the chrom sizes file exists
    if not os.path.exists(args.chrom_sizes):
        raise NoTracebackException("File {} does not exist".format(
            args.chrom_sizes))

    # load the chrom sizes into a dataframe
    chrom_sizes_df = pd.read_csv(args.chrom_sizes,
                                 sep='\t',
                                 header=None,
                                 names=['chrom', 'size'])

    # load the tasks json file
    with open(args.input_data, 'r') as inp_json:
        try:
            tasks = json.loads(inp_json.read())
        except json.decoder.JSONDecodeError:
            raise NoTracebackException(
                "Unable to load json file {}. Valid json expected. "
                "Check the file for syntax errors.".format(args.input_data))

    # get all peaks for a given task in 10 column ENCODE narrowPeak
    # format dataframe
    peaks_df = getPeakPositions(tasks[args.task],
                                args.chroms,
                                chrom_sizes_df,
                                args.sequence_len // 2,
                                drop_duplicates=True)

    # if a global sample weight is specified set it here
    if args.global_sample_weight is not None:
        peaks_df['weight'] = args.global_sample_weight

    # remove peaks that fall within blacklist regions
    if args.blacklist is not None:
        # check if the blacklist file exists
        if not os.path.exists(args.blacklist):
            raise NoTracebackException(
                "Blacklist file {} does not exist".format(args.blacklist))

        blacklist_df = pd.read_csv(args.blacklist,
                                   sep='\t',
                                   names=['chrom', 'st', 'e'])

        logging.info("Filtering blacklist peaks ...")
        logging.info("old size {}".format(len(peaks_df)))
        peaks_df = remove_blacklist_peaks(peaks_df, blacklist_df)
        logging.info("new size {}".format(len(peaks_df)))

    # open all the signal bigWigs for reading
    signal_files = []
    for signal_file in tasks[args.task]['signal']['source']:
        # check if the bigWig file exists
        if not os.path.exists(signal_file):
            raise NoTracebackException(
                "BigWig file {} does not exist".format(signal_file))

        signal_files.append(pyBigWig.open(signal_file))

    # counts dictionary maintains a list of counts for each
    # peak for each of the signal files
    # we will use this list to create a counts column for each of the
    # signal files
    counts = {}
    for signal_file in signal_files:
        counts[signal_file] = []

    # iterate through all peaks and read values from the bigWig files
    logging.info("Computing counts for each peak")
    for _, row in tqdm(peaks_df.iterrows(), desc='peaks', total=len(peaks_df)):
        chrom = row['chrom']
        start = row['start']
        end = row['end']

        for signal_file in signal_files:
            counts[signal_file].append(
                np.sum(np.nan_to_num(signal_file.values(chrom, start, end))))

    # add a new counts column to the peaks dataframe for each
    # signal file
    for signal_file in signal_files:
        peaks_df[signal_file] = counts[signal_file]

    # average the counts across the signal files
    peaks_df['avg_counts'] = peaks_df[signal_files].mean(axis=1)

    # sort the dataframe in ascending order of counts
    peaks_df = peaks_df.sort_values(by=['avg_counts'])

    # compute the quantile value
    counts = peaks_df['avg_counts'].values
    nth_quantile = np.quantile(counts, args.quantile)
    logging.info("{} quantile {}".format(args.quantile, nth_quantile))

    # get index of quantile value
    quantile_idx = abs(counts - nth_quantile).argmin()
    logging.info("quantile idx {}".format(quantile_idx))

    # scale value at quantile index
    scaled_value = counts[quantile_idx] * args.quantile_value_scale_factor
    logging.info("scaled_value {}".format(scaled_value))

    # check if any of the counts are above the scaled_value
    if np.sum(counts > scaled_value) > 0:
        # index of values greater than scaled_value
        max_idx = np.argmax(counts > scaled_value)
        logging.info("max_idx {}".format(max_idx))

        # trimmed data frame with outliers removed
        logging.info("original size {}".format(len(peaks_df)))
        peaks_df = peaks_df[:max_idx]
        logging.info("new size {}".format(len(peaks_df)))
    else:
        logging.info("No outliers found based on criteria. "
                     "Keeping original loci.")

    # save the new dataframe
    logging.info("Saving output bed file ... {}".format(args.output_bed))
    peaks_df = peaks_df[[
        'chrom', 'st', 'e', 'name', 'weight', 'strand', 'signal', 'p', 'q',
        'summit'
    ]]
    peaks_df.to_csv(args.output_bed, header=None, sep='\t', index=False)
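
The outlier threshold in outliers_main reduces to a few lines of NumPy: take the value at args.quantile, scale it by args.quantile_value_scale_factor, and truncate the sorted counts at the first value above that threshold. A toy, self-contained illustration of the same logic on synthetic counts (the numbers below are made up):

import numpy as np

rng = np.random.default_rng(0)
counts = np.sort(rng.poisson(100, size=1000).astype(float))

quantile = 0.99           # stands in for args.quantile
scale_factor = 1.2        # stands in for args.quantile_value_scale_factor

nth_quantile = np.quantile(counts, quantile)
quantile_idx = abs(counts - nth_quantile).argmin()
scaled_value = counts[quantile_idx] * scale_factor

if np.sum(counts > scaled_value) > 0:
    # index of the first count above the scaled threshold;
    # everything from there onwards is treated as an outlier
    max_idx = np.argmax(counts > scaled_value)
    counts = counts[:max_idx]
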
Example #6
def train_and_validate(input_params,
                       output_params,
                       genome_params,
                       batch_gen_params,
                       hyper_params,
                       parallelization_params,
                       network_params,
                       train_chroms,
                       val_chroms,
                       model_dir,
                       suffix_tag=None):
    """
        Train and validate on a single train and validation set
        
        Note: the list & description for each of the required keys
            in all of the json parameter files passed to this 
            function can be found here:
            http://
        
        Args:
            input_params (str): path to json file containing input
                parameters
            
            output_params (dict): dictionary containing output
                parameters
            
            genome_params (dict): dictionary containing genome
                parameters
            
            batch_gen_params (dict): dictionary containing batch
                generation parameters
            
            hyper_params (dict): dictionary containing training &
                validation hyper parameters
            
            parallelization_params (dict): dictionary containing
                parameters for parallelization options
            
            network_params (dict): dictionary containing parameters
                specific to the deep learning architecture
            
            train_chroms (list): list of training chromosomes
            
            val_chroms (list): list of validation chromosomes
            
            model_dir (str): the path to the output directory
            
            suffix_tag (str): optional tag to add as a suffix to files
                (model, log, history & config params files) created in
                the model directory
         
         Returns:
             keras.models.Model
             
    """

    # filename to write debug logs
    if suffix_tag is not None:
        logfname = '{}/trainer_{}.log'.format(model_dir, suffix_tag)
    else:
        logfname = '{}/trainer.log'.format(model_dir)

    # we need to initialize the logger for each process
    logger.init_logger(logfname)

    # parameters that are specific to the training batch generation
    # process
    train_batch_gen_params = batch_gen_params
    train_batch_gen_params['mode'] = 'train'

    # parameters that are specific to the validation batch generation
    # process. For validation we don't use jitter, reverse complement
    # augmentation, or negative sampling
    val_batch_gen_params = copy.deepcopy(batch_gen_params)
    val_batch_gen_params['max_jitter'] = 0
    val_batch_gen_params['rev_comp_aug'] = False
    val_batch_gen_params['negative_sampling_rate'] = 0.0
    val_batch_gen_params['mode'] = 'val'

    # get the corresponding batch generator class for this model
    sequence_generator_class_name = generators.find_generator_by_name(
        batch_gen_params['sequence_generator_name'])
    logging.info("SEQGEN Class Name: {}".format(sequence_generator_class_name))
    BatchGenerator = getattr(generators, sequence_generator_class_name)

    # instantiate the batch generator class for training
    train_gen = BatchGenerator(input_params,
                               train_batch_gen_params,
                               network_params,
                               genome_params['reference_genome'],
                               genome_params['chrom_sizes'],
                               train_chroms,
                               num_threads=parallelization_params['threads'],
                               epochs=hyper_params['epochs'],
                               batch_size=hyper_params['batch_size'])

    # training generator function that will be passed to
    # fit_generator
    train_generator = train_gen.gen()

    # instantiate the batch generator class for validation
    val_gen = BatchGenerator(input_params,
                             val_batch_gen_params,
                             network_params,
                             genome_params['reference_genome'],
                             genome_params['chrom_sizes'],
                             val_chroms,
                             num_threads=parallelization_params['threads'],
                             epochs=hyper_params['epochs'],
                             batch_size=hyper_params['batch_size'])

    # validation generator function that will be passed to
    # fit_generator
    val_generator = val_gen.gen()

    # let's make sure the sizes look reasonable
    logging.info("TRAINING SIZE - {}".format(train_gen._samples.shape))
    logging.info("VALIDATION SIZE - {}".format(val_gen._samples.shape))

    # we need to calculate the number of training steps and
    # validation steps in each epoch, fit_generator requires this
    # to determine the end of an epoch
    train_steps = train_gen.len()
    val_steps = val_gen.len()

    # we may have to reduce the --threads sometimes
    # if the peak file has very few peaks, so we need to
    # check if these numbers will be 0
    logging.info("TRAINING STEPS - {}".format(train_steps))
    logging.info("VALIDATION STEPS - {}".format(val_steps))

    # Here we specify all our callbacks
    # 1. Early stopping if validation loss doesn't decrease
    es = EarlyStopping(monitor='val_loss',
                       mode='min',
                       verbose=1,
                       patience=hyper_params['early_stopping_patience'],
                       min_delta=hyper_params['early_stopping_min_delta'],
                       restore_best_weights=True)

    # 2. Reduce learning rate if validation loss is plateauing
    reduce_lr = ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=hyper_params['reduce_lr_on_plateau_patience'],
        min_lr=hyper_params['min_learning_rate'])

    # 3. Timing hook to record start, end & elapsed time for each
    # epoch
    time_tracker = TimeHistory()

    # 4. Batch controller callbacks to ensure that batches are
    # generated only on a per epoch basis, also ensures graceful
    # termination of the batch generation
    train_batch_controller = BatchController(train_gen)
    val_batch_controller = BatchController(val_gen)

    # get an instance of the model
    logging.debug("New {} model".format(network_params['name']))
    get_model = getattr(model_archs, network_params['name'])
    model = get_model(train_batch_gen_params['input_seq_len'],
                      train_batch_gen_params['output_len'],
                      len(network_params['control_smoothing']) + 1,
                      filters=network_params['filters'],
                      num_tasks=train_gen._num_tasks)

    model.summary()

    # if running in multi gpu mode
    if parallelization_params['gpus'] > 1:
        logging.debug("Multi GPU model")
        model = multi_gpu_model(model, gpus=parallelization_params['gpus'])

    # compile the model
    logging.debug("Compiling model")
    logging.info("counts_loss_weight - {}".format(
        network_params['counts_loss_weight']))
    model.compile(
        Adam(lr=hyper_params['learning_rate']),
        loss=[MultichannelMultinomialNLL(train_gen._num_tasks), 'mse'],
        loss_weights=[1, network_params['counts_loss_weight']])

    # begin time for training
    t1 = time.time()

    # start training
    logging.debug("Training started ...")
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        history = model.fit_generator(train_generator,
                                      validation_data=val_generator,
                                      epochs=hyper_params['epochs'],
                                      steps_per_epoch=train_steps,
                                      validation_steps=val_steps,
                                      callbacks=[
                                          es, reduce_lr, time_tracker,
                                          train_batch_controller,
                                          val_batch_controller
                                      ])

    # end time for training
    t2 = time.time()
    logging.info("Total Elapsed Time: {}".format(t2 - t1))

    # send the stop signal to the generators
    train_gen.set_stop()
    val_gen.set_stop()

    # base model filename
    if output_params['automate_filenames']:
        # get random alphanumeric tag for model
        model_tag = getAlphaNumericTag(output_params['tag_length'])
        model_fname = "{}/{}".format(model_dir, model_tag)

    elif output_params['model_output_filename'] is not None:
        model_fname = "{}/{}".format(model_dir,
                                     output_params['model_output_filename'])
    else:
        model_fname = "{}/model".format(model_dir)
    # add suffix tag to model name
    if suffix_tag is not None:
        model_fname += "_{}".format(suffix_tag)
    # extension
    model_fname += ".h5"

    # save HDF5 model file
    model.save(model_fname)
    logging.info("Finished saving model: {}".format(model_fname))

    # save history to json:
    # Step 1. create a custom history object with a new key for
    # epoch times
    custom_history = copy.deepcopy(history.history)
    custom_history['times'] = time_tracker.times

    # Step 2. convert the custom history dict to a pandas DataFrame:
    hist_df = pd.DataFrame(custom_history)

    # file name for json file
    hist_json = model_fname.replace('.h5', '.history.json')

    # Step 3. write the dataframe to json
    with open(hist_json, mode='w') as f:
        hist_df.to_json(f)

    logging.info("Finished saving training and validation history: {}".format(
        hist_json))

    # write all the command line arguments to a json file
    # & include the number of epochs the training lasted for, and the
    # train and validation chroms
    config_file = '{}/config'.format(model_dir)
    # add suffix tag to model name
    if suffix_tag is not None:
        config_file += "_{}".format(suffix_tag)
    # extension
    config_file += ".json"

    with open(config_file, 'w') as fp:
        config = {}
        config['input_params'] = input_params
        config['output_params'] = output_params
        config['genome_params'] = genome_params
        config['batch_gen_params'] = batch_gen_params
        config['hyper_params'] = hyper_params
        config['parallelization_params'] = parallelization_params
        config['network_params'] = network_params

        # the number of epochs the training lasted
        epochs = len(history.history['val_loss'])
        config['training_epochs'] = epochs

        config['train_chroms'] = train_chroms
        config['val_chroms'] = val_chroms
        config['model_filename'] = model_fname

        json.dump(config, fp)

    return model
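
This older variant records per-epoch timings through a TimeHistory callback and merges them into the saved history json. The callback itself is not shown; below is a minimal sketch of what it presumably does, assuming it only needs to collect elapsed seconds per epoch:

import time

from tensorflow.keras.callbacks import Callback


class TimeHistory(Callback):
    def on_train_begin(self, logs=None):
        # one entry per completed epoch, in seconds
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        self._epoch_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        self.times.append(time.time() - self._epoch_start)
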
Example #7
def train_and_validate_ksplits(input_params, output_params, genome_params,
                               batch_gen_params, hyper_params,
                               parallelization_params, network_params, splits):
    """
        Train and validate on one or more train/val splits

        Note: the list & description for each of the required keys
            in all of the json parameter files passed to this 
            function can be found here:
            http://
        
        Args:
            input_params (str): path to json file containing input
                parameters
            
            output_params (dict): dictionary containing output
                parameters
            
            genome_params (dict): dictionary containing genome
                parameters
            
            batch_gen_params (dict): dictionary containing batch
                generation parameters
            
            hyper_params (dict): dictionary containing training &
                validation hyper parameters
            
            parallelization_params (dict): dictionary containing
                parameters for parallelization options
            
            network_params (dict): dictionary containing parameters
                specific to the deep learning architecture
            
            splits (dict): dictionary containing the train & 
                validation splits, keyed by split index
    """

    # list of chromosomes after removing the excluded chromosomes
    chroms = set(genome_params['chroms']).difference(
        set(genome_params['exclude_chroms']))

    # list of models from all of the splits
    models = []

    # run training for each validation/test split
    num_splits = len(list(splits.keys()))
    for i in range(num_splits):

        if output_params['automate_filenames']:
            # create a new directory using current date/time to store the
            # model, the loss history and logs
            date_time_str = local_datetime_str(output_params['time_zone'])
            model_dir = '{}/{}_split{:03d}'.format(output_params['output_dir'],
                                                   date_time_str, i)
            os.mkdir(model_dir)
            split_tag = None
        elif os.path.isdir(output_params['output_dir']):
            model_dir = output_params['output_dir']
            split_tag = "split{:03d}".format(i)
        else:
            logging.error("Directory does not exist {}.".format(
                output_params['output_dir']))
            return

        # filename to write debug logs
        logfname = '{}/trainer.log'.format(model_dir)
        # set up logger for main process
        logger.init_logger(logfname)

        # train & validation chromosome split
        if 'val' not in splits[str(i)]:
            logging.error("KeyError: 'val' required for split {}".format(i))
            return
        val_chroms = splits[str(i)]['val']
        # if 'train' key is present
        if 'train' in splits[str(i)]:
            train_chroms = splits[str(i)]['train']
        # if 'test' key is present but train is not
        elif 'test' in splits[str(i)]:
            test_chroms = splits[str(i)]['test']
            # take the set difference of the whole list of
            # chroms with the union of val and test
            train_chroms = list(
                chroms.difference(set(val_chroms + test_chroms)))
        else:
            # take the set difference of the whole list of
            # chroms with val
            train_chroms = list(chroms.difference(val_chroms))

        logging.info("Split #{}".format(i))
        logging.info("Train: {}".format(train_chroms))
        logging.info("Val: {}".format(val_chroms))

        # Start training for the split in a separate process.
        # This ensures that all resources are freed when the process
        # terminates and are available for training the next split.
        # Mitigates the problem where training subsequent splits
        # becomes considerably slower.
        logging.debug("Split {}: Creating training process".format(i))
        p = mp.Process(target=train_and_validate,
                       args=[
                           input_params, output_params, genome_params,
                           batch_gen_params, hyper_params,
                           parallelization_params, network_params,
                           train_chroms, val_chroms, model_dir, split_tag
                       ])
        p.start()

        # wait for the process to finish
        p.join()
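
A hedged sketch of how this k-splits driver might be invoked, assuming the parameter dictionaries and the splits definition each live in their own json file and that input_params is a path as its docstring states (all file names below are hypothetical):

import json


def load_json(path):
    with open(path, 'r') as f:
        return json.load(f)


if __name__ == "__main__":
    train_and_validate_ksplits(
        input_params="input_data.json",
        output_params=load_json("output_params.json"),
        genome_params=load_json("genome_params.json"),
        batch_gen_params=load_json("batch_gen_params.json"),
        hyper_params=load_json("hyper_params.json"),
        parallelization_params=load_json("parallelization_params.json"),
        network_params=load_json("network_params.json"),
        splits=load_json("splits.json"))
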