def test_one_hot_encode():
    """ test one hot encoding of dna sequences """

    # list of same length sequences
    sequences = ['ACGN', 'AAGG', 'CTCT', 'NNNN', 'CCCC']

    # the expected one hot encoding
    expected_res = [[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 0]],
                    [[1, 0, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0]],
                    [[0, 1, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0], [0, 0, 0, 1]],
                    [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
                    [[0, 1, 0, 0], [0, 1, 0, 0], [0, 1, 0, 0], [0, 1, 0, 0]]]

    res = sequtils.one_hot_encode(sequences, 4)
    np.testing.assert_array_equal(res, np.array(expected_res))

    # list of unequal length sequences
    sequences = ['ACGN', 'AAGG', 'CTCTF', 'NNNN', 'CCCC']

    # this will truncate the 3rd sequence
    res = sequtils.one_hot_encode(sequences, 4)
    np.testing.assert_array_equal(res, expected_res)
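
# sequtils.one_hot_encode itself is defined elsewhere in the package and is not
# shown in this excerpt. The helper below is only a minimal sketch, assuming the
# behavior exercised by the test above: A/C/G/T map to one-hot rows, any other
# base (e.g. 'N') maps to an all-zero row, and sequences are truncated to
# `seq_len`. The name `_one_hot_encode_sketch` is hypothetical; numpy is assumed
# to be imported as `np`, as in the rest of the file.
def _one_hot_encode_sketch(sequences, seq_len):
    base_to_index = {'A': 0, 'C': 1, 'G': 2, 'T': 3}

    # (num_sequences, seq_len, 4) array of zeros; unknown bases stay all-zero
    encoded = np.zeros((len(sequences), seq_len, 4), dtype=np.float32)
    for i, seq in enumerate(sequences):
        for j, base in enumerate(seq[:seq_len]):
            if base in base_to_index:
                encoded[i, j, base_to_index[base]] = 1.0
    return encoded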
def train_and_validate(input_data, model_arch_name, model_arch_params_json,
                       output_params, genome_params, batch_gen_params,
                       hyper_params, parallelization_params, train_chroms,
                       val_chroms, model_dir, bias_input_data=None,
                       bias_model_arch_params_json=None,
                       adjust_bias_model_logcounts=False,
                       is_background_model=False,
                       mnll_loss_sample_weight=1.0,
                       mnll_loss_background_sample_weight=0.0,
                       suffix_tag=None):
    """
        Train and validate on a single train and validation set

        Note: the list & description of each of the required keys in all
        of the json parameter files passed to this function can be
        found here: http://

        Args:
            input_data (str): path to the tasks json file
            model_arch_name (str): name of the model definition function
                in the model_archs module
            model_arch_params_json (str): path to json file containing
                model architecture params
            output_params (dict): dictionary containing output parameters
            genome_params (dict): dictionary containing genome parameters
            batch_gen_params (dict): dictionary containing batch
                generation parameters
            hyper_params (dict): dictionary containing training &
                validation hyper parameters
            parallelization_params (dict): dictionary containing
                parameters for parallelization options
            train_chroms (list): list of training chromosomes
            val_chroms (list): list of validation chromosomes
            model_dir (str): the path to the output directory
            bias_input_data (str): path to the bias tasks json file
            bias_model_arch_params_json (str): path to json file
                containing bias model architecture params
            adjust_bias_model_logcounts (boolean): True if you need to
                adjust the weights of the final Dense layer that predicts
                the logcounts when training a bias model for chromatin
                accessibility
            is_background_model (boolean): True if a background model is
                to be trained using 'background_loci' samples from the
                input json
            mnll_loss_sample_weight (float): weight for each (foreground)
                training sample for computing mnll loss
            mnll_loss_background_sample_weight (float): weight for each
                background sample for computing mnll loss
            suffix_tag (str): optional tag to add as a suffix to files
                (model, log, history & config params files) created in
                the model directory

        Returns:
            keras.models.Model
    """

    # make sure the input_data json file exists
    if not os.path.isfile(input_data):
        raise NoTracebackException("File not found: {} ".format(input_data))

    # load the json file
    with open(input_data, 'r') as inp_json:
        try:
            tasks = json.loads(inp_json.read())
            # since the json has keys as strings, we convert the
            # top level keys to int so we can use them later for
            # indexing
            #: dictionary of tasks for training
            tasks = {int(k): v for k, v in tasks.items()}
        except json.decoder.JSONDecodeError:
            raise NoTracebackException(
                "Unable to load json file {}. Valid json expected. "
                "Check the file for syntax errors.".format(input_data))

    # make sure the params json file exists
    if not os.path.isfile(model_arch_params_json):
        raise NoTracebackException(
            "File not found: {} ".format(model_arch_params_json))

    # load the params json file
    with open(model_arch_params_json, 'r') as inp_json:
        try:
            model_arch_params = json.loads(inp_json.read())
        except json.decoder.JSONDecodeError:
            raise NoTracebackException(
                "Unable to load json file {}. Valid json expected. "
                "Check the file for syntax errors.".format(
                    model_arch_params_json))

    if bias_input_data is not None:
        # load the bias json file
        with open(bias_input_data, 'r') as inp_json:
            try:
                bias_tasks = json.loads(inp_json.read())
                # since the json has keys as strings, we convert the
                # top level keys to int so we can use them later for
                # indexing
                #: dictionary of tasks for training
                bias_tasks = {int(k): v for k, v in bias_tasks.items()}
            except json.decoder.JSONDecodeError:
                raise NoTracebackException(
                    "Unable to load json file {}. Valid json expected. "
                    "Check the file for syntax errors.".format(
                        bias_input_data))

    if bias_model_arch_params_json is not None:
        # make sure the bias params json file exists
        if not os.path.isfile(bias_model_arch_params_json):
            raise NoTracebackException(
                "File not found: {} ".format(bias_model_arch_params_json))

        # load the bias params json file
        with open(bias_model_arch_params_json, 'r') as inp_json:
            try:
                bias_model_arch_params = json.loads(inp_json.read())
            except json.decoder.JSONDecodeError:
                raise NoTracebackException(
                    "Unable to load json file {}. Valid json expected. "
                    "Check the file for syntax errors.".format(
                        bias_model_arch_params_json))

    # filename to write debug logs
    if suffix_tag is not None:
        logfname = '{}/trainer_{}.log'.format(model_dir, suffix_tag)
    else:
        logfname = '{}/trainer.log'.format(model_dir)

    # we need to initialize the logger for each process
    logger.init_logger(logfname)

    # parameters that are specific to the training batch generation
    # process
    train_batch_gen_params = batch_gen_params
    train_batch_gen_params['mode'] = 'train'

    # parameters that are specific to the validation batch generation
    # process. For validation we don't use jitter, reverse complement
    # augmentation or negative sampling
    val_batch_gen_params = copy.deepcopy(batch_gen_params)
    val_batch_gen_params['max_jitter'] = 0
    val_batch_gen_params['rev_comp_aug'] = False
    val_batch_gen_params['negative_sampling_rate'] = 0.0
    val_batch_gen_params['mode'] = 'val'

    # get the corresponding batch generator class for this model
    sequence_generator_class_name = generators.find_generator_by_name(
        batch_gen_params['sequence_generator_name'])
    logging.info("SEQGEN Class Name: {}".format(
        sequence_generator_class_name))
    BatchGenerator = getattr(generators, sequence_generator_class_name)

    # instantiate the batch generator class for training
    train_gen = BatchGenerator(
        input_data, train_batch_gen_params,
        genome_params['reference_genome'], genome_params['chrom_sizes'],
        train_chroms,
        num_threads=parallelization_params['threads'],
        batch_size=hyper_params['batch_size'],
        background_only=is_background_model,
        foreground_weight=mnll_loss_sample_weight,
        background_weight=mnll_loss_background_sample_weight)

    # instantiate the batch generator class for validation
    val_gen = BatchGenerator(
        input_data, val_batch_gen_params,
        genome_params['reference_genome'], genome_params['chrom_sizes'],
        val_chroms,
        num_threads=parallelization_params['threads'],
        batch_size=hyper_params['batch_size'],
        background_only=is_background_model,
        foreground_weight=mnll_loss_sample_weight,
        background_weight=mnll_loss_background_sample_weight)

    # we need to calculate the number of training steps and
    # validation steps in each epoch, fit/evaluate requires this
    # to determine the end of an epoch
    train_steps = train_gen.len()
    val_steps = val_gen.len()

    # we may have to reduce the --threads sometimes
    # if the peak file has very few peaks, so we need to
    # check if these numbers will be 0
    logging.info("TRAINING STEPS - {}".format(train_steps))
    logging.info("VALIDATION STEPS - {}".format(val_steps))

    # get an instance of the model
    logging.debug("New {} model".format(model_arch_name))
    get_model = getattr(archs, model_arch_name)
    if model_arch_name == "BPNet_ATAC_DNase":
        model = get_model(tasks, bias_tasks, model_arch_params,
                          bias_model_arch_params, name_prefix="main")
    else:
        model = get_model(tasks, model_arch_params, name_prefix="main")

    # print out the model summary
    model.summary()

    # compile the model
    logging.debug("Compiling model")
    logging.info("loss weights - {}".format(
        model_arch_params['loss_weights']))
    model.compile(Adam(learning_rate=hyper_params['learning_rate']),
                  loss=[MultichannelMultinomialNLL(
                            train_gen._total_signal_tracks),
                        CustomMeanSquaredError()],
                  loss_weights=model_arch_params['loss_weights'])

    # begin time for training
    t1 = time.time()

    # track training losses, validation losses and start & end
    # times
    custom_history = {
        'learning_rate': {},
        'loss': {},
        'batch_loss': {},
        'profile_predictions_loss': {},
        'logcounts_predictions_loss': {},
        'attribution_prior_loss': {},
        'val_loss': {},
        'val_batch_loss': {},
        'val_profile_predictions_loss': {},
        'val_logcounts_predictions_loss': {},
        'val_attribution_prior_loss': {},
        'start_time': {},
        'end_time': {},
        'elapsed': {}
    }

    # we maintain a separate list to track validation losses to make it
    # easier for early stopping
    val_losses = []

    # track validation losses for learning rate update
    val_losses_lr = []

    # track best loss so we can restore weights
    best_loss = 1e6

    # keep a copy of the best weights
    best_weights = None

    # the epoch with the best validation loss
    best_epoch = 1

    # start training
    logging.debug("Training started ...")
    for epoch in range(hyper_params['epochs']):
        # First, let's train for one epoch
        logging.info("Training Epoch {}".format(epoch + 1))
        train_start_time = time.time()
        custom_history['learning_rate'][str(epoch + 1)] = \
            model.optimizer.learning_rate.numpy()
        custom_history['start_time'][str(epoch + 1)] = \
            time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(train_start_time))

        # training generator function that will be passed to fit
        train_generator = train_gen.gen()
        history = model.fit(train_generator, epochs=1,
                            steps_per_epoch=train_steps)
        train_end_time = time.time()

        # record the training losses
        for key in history.history:
            custom_history[key][str(epoch + 1)] = history.history[key][0]

        # Then, we evaluate on the validation set
        logging.info("Validation Epoch {}".format(epoch + 1))
        val_start_time = time.time()

        # validation generator function that will be passed to evaluate
        val_generator = val_gen.gen()
        val_loss = model.evaluate(val_generator, steps=val_steps,
                                  return_dict=True)
        val_losses.append(val_loss['loss'])
        val_losses_lr.append(val_loss['loss'])
        val_end_time = time.time()
        custom_history['end_time'][str(epoch + 1)] = \
            time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(val_end_time))
        custom_history['elapsed'][str(epoch + 1)] = \
            val_end_time - train_start_time

        # record the validation losses
        for key in val_loss:
            custom_history['val_' + key][str(epoch + 1)] = val_loss[key]

        # update best weights and loss
        if val_loss['loss'] < best_loss:
            best_weights = model.get_weights()
            best_loss = val_loss['loss']
            best_epoch = epoch + 1

        # check if early stopping criteria are satisfied
        if early_stopping_check(
                val_losses,
                patience=hyper_params['early_stopping_patience'],
                min_delta=hyper_params['early_stopping_min_delta']):
            # restore best weights
            logging.info(
                "Restoring best weights from epoch {}".format(best_epoch))
            model.set_weights(best_weights)
            break

        # lower learning rate if criteria are satisfied
        current_lr = model.optimizer.learning_rate.numpy()
        new_lr = reduce_lr_on_plateau(
            val_losses_lr, current_lr,
            factor=hyper_params['lr_reduction_factor'],
            patience=hyper_params['reduce_lr_on_plateau_patience'],
            min_lr=hyper_params['min_learning_rate'])

        # reset the validation losses tracker for learning rate update
        if new_lr != current_lr:
            val_losses_lr = [val_losses_lr[-1]]

        # set the new learning rate
        model.optimizer.learning_rate.assign(new_lr)

        # display current learning rate and training status
        logging.info(
            "Current learning rate - {:5f}, Stop Training - {}".format(
                model.optimizer.learning_rate.numpy(), model.stop_training))

    # end time for training
    t2 = time.time()
    logging.info("Total Elapsed Time: {}".format(t2 - t1))

    # base model filename
    if output_params['automate_filenames']:
        # get random alphanumeric tag for model
        model_tag = getAlphaNumericTag(output_params['tag_length'])
        model_fname = "{}/{}".format(model_dir, model_tag)
    elif output_params['model_output_filename'] is not None:
        model_fname = "{}/{}".format(
            model_dir, output_params['model_output_filename'])
    else:
        model_fname = "{}/model".format(model_dir)

    # add suffix tag to model name
    if suffix_tag is not None:
        model_fname += "_{}".format(suffix_tag)

    # extension
    model_fname += ".h5"

    # save HDF5 model file
    model.save(model_fname)
    logging.info("Finished saving model: {}".format(model_fname))

    if adjust_bias_model_logcounts:
        # all peaks and non peaks
        loci = train_gen.get_samples()

        # non-peaks only
        nonpeaks_loci = loci[loci['weight'] == 0.0]

        if len(nonpeaks_loci) == 0:
            logging.info("Non peaks length is 0. Bias model adjustment "
                         "aborted.")
        else:
            # reference file to fetch sequences
            fasta_ref = pyfaidx.Fasta(genome_params['reference_genome'])

            # get all the bigWigs and peaks from the input_data
            bigWigs = []
            for task in tasks:
                if 'signal' in tasks[task].keys():
                    bigWigs.extend(tasks[task]['signal']["source"])

            # open each bigwig and add file pointers to a list
            fbigWigs = []
            for bigWig in bigWigs:
                fbigWigs.append(pyBigWig.open(bigWig))

            # get sequences and logcounts
            logging.info("Fetching non peak sequences and counts ...")
            sequences = []
            logcounts = []
            for _, row in nonpeaks_loci.iterrows():
                # chrom, start and end
                chrom = row['chrom']
                start = row['pos'] - (batch_gen_params['input_seq_len'] // 2)
                end = row['pos'] + (batch_gen_params['input_seq_len'] // 2)

                # get the sequences
                seq = fasta_ref[chrom][start:end].seq.upper()

                # collect all the sequences into a list
                sequences.append(seq)

                # get the total counts
                for i in range(len(fbigWigs)):
                    bw = fbigWigs[i]
                    logcounts.append(
                        np.log(
                            np.sum(np.nan_to_num(
                                bw.values(chrom, start, end))) + 1))

            fasta_ref.close()

            # one hot encode the sequences
            seqs = sequtils.one_hot_encode(
                sequences, batch_gen_params['input_seq_len'])

            adjusted_model = adjust_bias_logcounts(
                model, seqs, np.array(logcounts), "logcounts_predictions")

            # saving adjusted model
            model_fname = model_fname.replace('.h5', '.adjusted.h5')

            # save HDF5 model file
            adjusted_model.save(model_fname)
            logging.info(
                "Finished saving adjusted model: {}".format(model_fname))

    # save history to json:
    # Step 1. convert the custom history dict to a pandas DataFrame:
    hist_df = pd.DataFrame(custom_history)

    # file name for json file
    hist_json = model_fname.replace('.h5', '.history.json')

    # Step 2. write the dataframe to json
    with open(hist_json, mode='w') as f:
        hist_df.to_json(f)

    logging.info("Finished saving training and validation history: {}".format(
        hist_json))

    # write all the command line arguments to a json file
    # & include the number of epochs the training lasted for, and the
    # validation and test chroms
    config_file = '{}/config'.format(model_dir)

    # add suffix tag to model name
    if suffix_tag is not None:
        config_file += "_{}".format(suffix_tag)

    # extension
    config_file += ".json"

    with open(config_file, 'w') as fp:
        config = {}
        config['input_data'] = input_data
        config['output_params'] = output_params
        config['genome_params'] = genome_params
        config['batch_gen_params'] = batch_gen_params
        config['hyper_params'] = hyper_params
        config['parallelization_params'] = parallelization_params

        # the number of epochs the training lasted
        config['training_epochs'] = epoch + 1

        # the epoch with best validation loss
        config['best_epoch'] = best_epoch

        config['train_chroms'] = train_chroms
        config['val_chroms'] = val_chroms
        config['model_filename'] = model_fname

        json.dump(config, fp)

    return model
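
# early_stopping_check and reduce_lr_on_plateau are imported from elsewhere in
# the package and are not shown in this excerpt. The two helpers below are only
# a sketch, assuming a conventional implementation of the criteria the training
# loop above relies on; the `_sketch` names are hypothetical.
def _early_stopping_check_sketch(val_losses, patience, min_delta):
    # stop when the best validation loss of the last `patience` epochs has not
    # improved on the best loss seen before that window by at least `min_delta`
    if len(val_losses) <= patience:
        return False
    return min(val_losses[:-patience]) - min(val_losses[-patience:]) < min_delta


def _reduce_lr_on_plateau_sketch(val_losses, current_lr, factor, patience,
                                 min_lr):
    # scale the learning rate down by `factor` (never below `min_lr`) if the
    # validation loss has not reached a new minimum in the last `patience`
    # epochs; otherwise keep the current learning rate
    if len(val_losses) <= patience:
        return current_lr
    if min(val_losses[-patience:]) >= min(val_losses[:-patience]):
        return max(current_lr * factor, min_lr)
    return current_lr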
def _generate_batch(self, coords):
    """Generate one batch of inputs and outputs for training BPNet

        For all coordinates in "coords" fetch sequences & one hot
        encode the sequences. Fetch corresponding signal values
        (e.g. from a bigwig file). Package the one hot encoded
        sequences and the output values as a tuple.

        Args:
            coords (pandas.DataFrame): dataframe with 'chrom', 'pos' &
                'status' columns specifying the chromosome, the
                coordinate and whether the locus is a positive (1) or
                negative (-1) sample

        Returns:
            tuple: When 'mode' is 'train' or 'val', a batch tuple with
                one hot encoded sequences and corresponding outputs;
                when 'mode' is 'test', a tuple of coordinates & the
                inputs
    """

    # reference file to fetch sequences
    fasta_ref = pyfaidx.Fasta(self._reference)

    # Initialization
    # (batch_size, output_len, 1 + #smoothing_window_sizes)
    control_profile = np.zeros((coords.shape[0], self._output_flank * 2,
                                1 + len(self._control_smoothing)),
                               dtype=np.float32)

    # (batch_size)
    control_profile_counts = np.zeros((coords.shape[0]), dtype=np.float32)

    # in 'test' mode we pass the true profile as part of the
    # returned tuple from the batch generator
    if self._mode == "train" or self._mode == "val" or \
            self._mode == "test":
        # (batch_size, output_len, #tasks)
        profile = np.zeros((coords.shape[0], self._output_flank * 2,
                            self._num_tasks), dtype=np.float32)

        # (batch_size, #tasks)
        profile_counts = np.zeros((coords.shape[0], self._num_tasks),
                                  dtype=np.float32)

    # if reverse complement augmentation is enabled then double the sizes
    if self._mode == "train" and self._rev_comp_aug:
        control_profile = control_profile.repeat(2, axis=0)
        control_profile_counts = control_profile_counts.repeat(2, axis=0)
        profile = profile.repeat(2, axis=0)
        profile_counts = profile_counts.repeat(2, axis=0)

    # list of sequences in the batch, these will be one hot
    # encoded together as a single sequence after iterating
    # over the batch
    sequences = []

    # list of chromosome start/end coordinates
    # useful for tracking test batches
    coordinates = []

    # open all the control bigwig files and store the file
    # objects in a dictionary
    control_files = {}
    for task in self._tasks:
        # the control is not necessary
        if 'control' in self._tasks[task]:
            control_files[task] = pyBigWig.open(
                self._tasks[task]['control'])

    # in 'test' mode we pass the true profile as part of the
    # returned tuple from the batch generator
    if self._mode == "train" or self._mode == "val" or \
            self._mode == "test":
        # open all the required bigwig files and store the file
        # objects in a dictionary
        signal_files = {}
        for task in self._tasks:
            signal_files[task] = pyBigWig.open(self._tasks[task]['signal'])

    # iterate over the batch
    rowCnt = 0
    for _, row in coords.iterrows():
        # randomly set a jitter value to move the peak summit
        # slightly away from the exact center
        jitter = 0
        if self._mode == "train" and self._max_jitter:
            jitter = random.randint(-self._max_jitter, self._max_jitter)

        # Step 1. get the sequence
        chrom = row['chrom']
        # we use self._input_flank here and not self._output_flank
        # because input_seq_len is different from output_len
        start = row['pos'] - self._input_flank + jitter
        end = row['pos'] + self._input_flank + jitter
        seq = fasta_ref[chrom][start:end].seq.upper()

        # collect all the sequences into a list
        sequences.append(seq)

        start = row['pos'] - self._output_flank + jitter
        end = row['pos'] + self._output_flank + jitter

        # collect all the start/end coordinates into a list
        # we'll send this off along with 'test' batches
        coordinates.append((chrom, start, end))

        # iterate over each task
        for task in self._tasks:
            # identifies the +/- strand pair
            task_id = self._tasks[task]['task_id']

            # the strand id: 0-positive, 1-negative
            # easy to index with those values
            strand = self._tasks[task]['strand']

            # Step 2. get the control values
            if task in control_files:
                control_values = control_files[task].values(
                    chrom, start, end)

                # replace nans with zeros
                if np.any(np.isnan(control_values)):
                    control_values = np.nan_to_num(control_values)

                # update row in batch with the control values
                # the values are summed across all tasks
                # the axis = 1 dimension accumulates the sum
                # there are 'n' copies of the sum along axis = 2,
                # n = #smoothing_windows
                control_profile[rowCnt, :, :] += np.expand_dims(
                    control_values, axis=1)

            # in 'test' mode we pass the true profile as part of the
            # returned tuple from the batch generator
            if self._mode == "train" or self._mode == "val" or \
                    self._mode == "test":
                # Step 3. get the signal values
                # fetch values using the pyBigWig file objects
                values = signal_files[task].values(chrom, start, end)

                # replace nans with zeros
                if np.any(np.isnan(values)):
                    values = np.nan_to_num(values)

                # update row in batch with the signal values
                if self._stranded:
                    profile[rowCnt, :, task_id * 2 + strand] = values
                else:
                    profile[rowCnt, :, task_id] = values

        rowCnt += 1

    # Step 4. reverse complement augmentation
    if self._mode == "train" and self._rev_comp_aug:
        # Step 4.1 get list of reverse complement sequences
        rev_comp_sequences = \
            sequtils.reverse_complement_of_sequences(sequences)

        # append the rev comp sequences to the original list
        sequences.extend(rev_comp_sequences)

        # Step 4.2 reverse complement of the control profile
        control_profile[rowCnt:, :, :] = \
            sequtils.reverse_complement_of_profiles(
                control_profile[:rowCnt, :, :], self._stranded)

        # Step 4.3 reverse complement of the signal profile
        profile[rowCnt:, :, :] = \
            sequtils.reverse_complement_of_profiles(
                profile[:rowCnt, :, :], self._stranded)

    # Step 5. one hot encode all the sequences in the batch
    if len(sequences) == profile.shape[0]:
        X = sequtils.one_hot_encode(sequences, self._input_flank * 2)
    else:
        raise NoTracebackException(
            "Unable to generate enough sequences for the batch")

    # we can perform smoothing on the entire batch of control values
    for i in range(len(self._control_smoothing)):
        sigma = self._control_smoothing[i][0]
        window_size = self._control_smoothing[i][1]

        # it's i+1 because at index 0 we have the original
        # control
        control_profile[:, :, i + 1] = utils.gaussian1D_smoothing(
            control_profile[:, :, i + 1], sigma, window_size)

    # log of sum of control profile without smoothing (idx = 0)
    control_profile_counts = np.log(
        np.sum(control_profile[:, :, 0], axis=-1) + 1)

    # in 'train' and 'val' mode we need input and output
    # dictionaries
    if self._mode == "train" or self._mode == 'val':
        # we can now sum the profiles for the entire batch
        profile_counts = np.log(np.sum(profile, axis=1) + 1)

        # return a tuple of input and output dictionaries
        # 'coordinates' and 'status' are not inputs to the model,
        # so you will see a warning about unused inputs while
        # training. It's safe to ignore the warning.
        # We pass 'coordinates' so we can track the exact
        # coordinates of the inputs (because jitter is random)
        # 'status' refers to whether the data sample is a +ve (1)
        # or -ve (-1) example and is used by the attribution
        # prior loss function
        return ({'coordinates': coordinates,
                 'status': coords['status'].values,
                 'sequence': X,
                 'control_profile': control_profile,
                 'control_logcount': control_profile_counts},
                {'profile_predictions': profile,
                 'logcount_predictions': profile_counts})

    # in 'test' mode return a tuple of coordinates, true profiles
    # & the input dictionary
    return (coordinates, profile,
            {'sequence': X,
             'control_profile': control_profile,
             'control_logcount': control_profile_counts})
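
# utils.gaussian1D_smoothing, used above to smooth the control profiles, is
# defined elsewhere in the package. The helper below is only a minimal sketch,
# assuming it applies a 1-D Gaussian filter with the given sigma, truncated to
# roughly `window_size` samples, along the profile axis of each batch row; the
# `_sketch` name is hypothetical.
def _gaussian1d_smoothing_sketch(batch, sigma, window_size):
    from scipy.ndimage import gaussian_filter1d

    # gaussian_filter1d truncates the kernel at `truncate` standard deviations
    # on each side, so a window of `window_size` samples corresponds to a
    # half-width of (window_size - 1) / 2 samples
    truncate = ((window_size - 1) / 2) / sigma
    return gaussian_filter1d(batch, sigma, axis=-1, truncate=truncate)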