def predict_main():
    # parse the command line arguments
    parser = argparsers.predict_argsparser()
    args = parser.parse_args()

    # check if the output directory exists
    if not os.path.exists(args.output_dir):
        logging.error("Directory {} does not exist".format(args.output_dir))
        return

    if args.automate_filenames:
        # create a new directory using current date/time to store the
        # predictions and logs
        date_time_str = local_datetime_str(args.time_zone)
        pred_dir = '{}/{}'.format(args.output_dir, date_time_str)
        os.mkdir(pred_dir)
    elif os.path.isdir(args.output_dir):
        pred_dir = args.output_dir
    else:
        logging.error("Directory {} does not exist.".format(args.output_dir))
        return

    # filename to write debug logs
    logfname = "{}/predict.log".format(pred_dir)

    # set up the loggers
    logger.init_logger(logfname)

    # make sure the input_data json file exists
    if not os.path.isfile(args.input_data):
        raise quietexception.QuietException(
            "File not found: {} OR you may have accidentally "
            "specified a directory path.".format(args.input_data))

    # load the json file
    with open(args.input_data, 'r') as inp_json:
        try:
            #: dictionary of tasks for training
            input_data = json.loads(inp_json.read())
        except json.decoder.JSONDecodeError:
            raise quietexception.QuietException(
                "Unable to load json file {}. Valid json expected. "
                "Check the file for syntax errors.".format(args.input_data))

    logging.info("INPUT DATA -\n{}".format(input_data))

    # predict
    logging.info("Loading {}".format(args.model))
    with CustomObjectScope(
            {'MultichannelMultinomialNLL': MultichannelMultinomialNLL}):
        predict(args, input_data, pred_dir)

def predict_main():
    # parse the command line arguments
    parser = argparsers.fastpredict_argsparser()
    args = parser.parse_args()

    # check if the output directory exists
    if not os.path.exists(args.output_dir):
        logging.error("Directory {} does not exist".format(args.output_dir))
        return

    if args.automate_filenames:
        # create a new directory using current date/time to store the
        # predictions and logs
        date_time_str = local_datetime_str(args.time_zone)
        pred_dir = '{}/{}'.format(args.output_dir, date_time_str)
        os.mkdir(pred_dir)
    elif os.path.isdir(args.output_dir):
        pred_dir = args.output_dir
    else:
        logging.error("Directory {} does not exist.".format(args.output_dir))
        return

    # filename to write debug logs
    logfname = "{}/predict.log".format(pred_dir)

    # set up the loggers
    logger.init_logger(logfname)

    # predict
    logging.info("Loading {}".format(args.model))
    with CustomObjectScope({
            'MultichannelMultinomialNLL': MultichannelMultinomialNLL,
            'tf': tf,
            'CustomMeanSquaredError': CustomMeanSquaredError,
            'AttributionPriorModel': AttributionPriorModel,
            'CustomModel': CustomModel}):
        predict(args, pred_dir)

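# The CustomObjectScope above matters because Keras can only deserialize a
# saved .h5 model if every custom class it references is resolvable by name.
# A minimal sketch of that mechanism follows; the placeholder loss class and
# the "model.h5" path are illustrative assumptions, not this package's code.
import tensorflow as tf
from tensorflow.keras.utils import CustomObjectScope

class MultichannelMultinomialNLL(tf.keras.losses.Loss):
    """Placeholder standing in for the real custom loss class."""
    def call(self, y_true, y_pred):
        return tf.reduce_mean(tf.square(y_true - y_pred))

with CustomObjectScope(
        {'MultichannelMultinomialNLL': MultichannelMultinomialNLL}):
    # the custom loss is resolved by name during deserialization
    model = tf.keras.models.load_model("model.h5")
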
def train_and_validate_ksplits(
    input_data, model_arch_name, model_arch_params_json, output_params,
    genome_params, batch_gen_params, hyper_params, parallelization_params,
    splits, bias_input_data=None, bias_model_arch_params_json=None,
    adjust_bias_model_logcounts=False, is_background_model=False,
    mnll_loss_sample_weight=1.0, mnll_loss_background_sample_weight=0.0):
    """
        Train and validate on one or more train/val splits

        Args:
            input_data (str): path to the tasks json file
            model_arch_name (str): name of the model definition function
                in the model_archs module
            model_arch_params_json (str): path to json file containing
                model architecture params
            output_params (dict): dictionary containing output parameters
            genome_params (dict): dictionary containing genome parameters
            batch_gen_params (dict): dictionary containing batch
                generation parameters
            hyper_params (dict): dictionary containing training &
                validation hyper parameters
            parallelization_params (dict): dictionary containing
                parameters for parallelization options
            splits (dict): dictionary containing the train & validation
                splits, keyed by the split index as a string
            bias_input_data (str): path to the bias tasks json file
            bias_model_arch_params_json (str): path to json file
                containing bias model architecture params
            adjust_bias_model_logcounts (boolean): True if the weights of
                the final Dense layer that predicts the logcounts should
                be adjusted when training a bias model for chromatin
                accessibility
            is_background_model (boolean): True if a background model is
                to be trained using 'background_loci' samples from the
                input json
            mnll_loss_sample_weight (float): weight for each (foreground)
                training sample for computing mnll loss
            mnll_loss_background_sample_weight (float): weight for each
                background sample for computing mnll loss
    """

    # list of chromosomes after removing the excluded chromosomes
    chroms = set(genome_params['chroms']).difference(
        set(genome_params['exclude_chroms']))

    # list of models from all of the splits
    models = []

    # run training for each validation/test split
    num_splits = len(list(splits.keys()))
    for i in range(num_splits):
        if output_params['automate_filenames']:
            # create a new directory using current date/time to store
            # the model, the loss history and logs
            date_time_str = local_datetime_str(output_params['time_zone'])
            model_dir = '{}/{}_split{:03d}'.format(
                output_params['output_dir'], date_time_str, i)
            os.mkdir(model_dir)
            split_tag = None
        elif os.path.isdir(output_params['output_dir']):
            model_dir = output_params['output_dir']
            split_tag = "split{:03d}".format(i)
        else:
            logging.error("Directory {} does not exist.".format(
                output_params['output_dir']))
            return

        # filename to write debug logs
        logfname = '{}/trainer.log'.format(model_dir)

        # set up logger for the main process
        logger.init_logger(logfname)

        # train & validation chromosome split
        if 'val' not in splits[str(i)]:
            logging.error("KeyError: 'val' required for split {}".format(i))
            return
        val_chroms = splits[str(i)]['val']

        # if 'train' key is present
        if 'train' in splits[str(i)]:
            train_chroms = splits[str(i)]['train']
        # if 'test' key is present but 'train' is not
        elif 'test' in splits[str(i)]:
            test_chroms = splits[str(i)]['test']
            # take the set difference of the whole list of chroms with
            # the union of val and test
            train_chroms = list(
                chroms.difference(set(val_chroms + test_chroms)))
        else:
            # take the set difference of the whole list of chroms
            # with val
            train_chroms = list(chroms.difference(val_chroms))

        logging.info("Split #{}".format(i))
        logging.info("Train: {}".format(train_chroms))
        logging.info("Val: {}".format(val_chroms))

        # Start training for the split in a separate process. This
        # ensures that all resources are freed when the process
        # terminates, & available for training the next split. It
        # mitigates the problem where training subsequent splits is
        # considerably slower.
        logging.debug("Split {}: Creating training process".format(i))
        p = mp.Process(
            target=train_and_validate,
            args=[input_data, model_arch_name, model_arch_params_json,
                  output_params, genome_params, batch_gen_params,
                  hyper_params, parallelization_params, train_chroms,
                  val_chroms, model_dir, bias_input_data,
                  bias_model_arch_params_json,
                  adjust_bias_model_logcounts, is_background_model,
                  mnll_loss_sample_weight,
                  mnll_loss_background_sample_weight, split_tag])
        p.start()

        # wait for the process to finish
        p.join()

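# For reference, an illustrative `splits` dictionary matching what
# train_and_validate_ksplits expects: keys are split indices as strings,
# 'val' is required, and 'train'/'test' are optional as handled above.
# The chromosome choices are assumptions.
example_splits = {
    "0": {"val": ["chr10", "chr18"], "test": ["chr1"]},
    "1": {"val": ["chr2", "chr3"],
          "train": ["chr4", "chr5", "chr6", "chr7"]}
}
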
def train_and_validate(
    input_data, model_arch_name, model_arch_params_json, output_params,
    genome_params, batch_gen_params, hyper_params, parallelization_params,
    train_chroms, val_chroms, model_dir, bias_input_data=None,
    bias_model_arch_params_json=None, adjust_bias_model_logcounts=False,
    is_background_model=False, mnll_loss_sample_weight=1.0,
    mnll_loss_background_sample_weight=0.0, suffix_tag=None):
    """
        Train and validate on a single train and validation set

        Note: the list & description for each of the required keys in
        all of the json parameter files passed to this function can be
        found here: http://

        Args:
            input_data (str): path to the tasks json file
            model_arch_name (str): name of the model definition function
                in the model_archs module
            model_arch_params_json (str): path to json file containing
                model architecture params
            output_params (dict): dictionary containing output parameters
            genome_params (dict): dictionary containing genome parameters
            batch_gen_params (dict): dictionary containing batch
                generation parameters
            hyper_params (dict): dictionary containing training &
                validation hyper parameters
            parallelization_params (dict): dictionary containing
                parameters for parallelization options
            train_chroms (list): list of training chromosomes
            val_chroms (list): list of validation chromosomes
            model_dir (str): the path to the output directory
            bias_input_data (str): path to the bias tasks json file
            bias_model_arch_params_json (str): path to json file
                containing bias model architecture params
            adjust_bias_model_logcounts (boolean): True if you need to
                adjust the weights of the final Dense layer that predicts
                the logcounts when training a bias model for chromatin
                accessibility
            is_background_model (boolean): True if a background model is
                to be trained using 'background_loci' samples from the
                input json
            mnll_loss_sample_weight (float): weight for each (foreground)
                training sample for computing mnll loss
            mnll_loss_background_sample_weight (float): weight for each
                background sample for computing mnll loss
            suffix_tag (str): optional tag to add as a suffix to files
                (model, log, history & config params files) created in
                the model directory

        Returns:
            keras.models.Model
    """

    # make sure the input_data json file exists
    if not os.path.isfile(input_data):
        raise NoTracebackException("File not found: {}".format(input_data))

    # load the json file
    with open(input_data, 'r') as inp_json:
        try:
            tasks = json.loads(inp_json.read())
            # since the json has keys as strings, we convert the
            # top level keys to int so we can use them later for
            # indexing
            #: dictionary of tasks for training
            tasks = {int(k): v for k, v in tasks.items()}
        except json.decoder.JSONDecodeError:
            raise NoTracebackException(
                "Unable to load json file {}. Valid json expected. "
                "Check the file for syntax errors.".format(input_data))

    # make sure the params json file exists
    if not os.path.isfile(model_arch_params_json):
        raise NoTracebackException(
            "File not found: {}".format(model_arch_params_json))

    # load the params json file
    with open(model_arch_params_json, 'r') as inp_json:
        try:
            model_arch_params = json.loads(inp_json.read())
        except json.decoder.JSONDecodeError:
            raise NoTracebackException(
                "Unable to load json file {}. Valid json expected. "
                "Check the file for syntax errors.".format(
                    model_arch_params_json))

    if bias_input_data is not None:
        # load the bias json file
        with open(bias_input_data, 'r') as inp_json:
            try:
                bias_tasks = json.loads(inp_json.read())
                # since the json has keys as strings, we convert the
                # top level keys to int so we can use them later for
                # indexing
                #: dictionary of tasks for training
                bias_tasks = {int(k): v for k, v in bias_tasks.items()}
            except json.decoder.JSONDecodeError:
                raise NoTracebackException(
                    "Unable to load json file {}. Valid json expected. "
                    "Check the file for syntax errors.".format(
                        bias_input_data))

    if bias_model_arch_params_json is not None:
        # make sure the bias params json file exists
        if not os.path.isfile(bias_model_arch_params_json):
            raise NoTracebackException(
                "File not found: {}".format(bias_model_arch_params_json))

        # load the bias params json file
        with open(bias_model_arch_params_json, 'r') as inp_json:
            try:
                bias_model_arch_params = json.loads(inp_json.read())
            except json.decoder.JSONDecodeError:
                raise NoTracebackException(
                    "Unable to load json file {}. Valid json expected. "
                    "Check the file for syntax errors.".format(
                        bias_model_arch_params_json))

    # filename to write debug logs
    if suffix_tag is not None:
        logfname = '{}/trainer_{}.log'.format(model_dir, suffix_tag)
    else:
        logfname = '{}/trainer.log'.format(model_dir)

    # we need to initialize the logger for each process
    logger.init_logger(logfname)

    # parameters that are specific to the training batch generation
    # process
    train_batch_gen_params = batch_gen_params
    train_batch_gen_params['mode'] = 'train'

    # parameters that are specific to the validation batch generation
    # process. For validation we don't use jitter, reverse complement
    # augmentation or negative sampling
    val_batch_gen_params = copy.deepcopy(batch_gen_params)
    val_batch_gen_params['max_jitter'] = 0
    val_batch_gen_params['rev_comp_aug'] = False
    val_batch_gen_params['negative_sampling_rate'] = 0.0
    val_batch_gen_params['mode'] = 'val'

    # get the corresponding batch generator class for this model
    sequence_generator_class_name = generators.find_generator_by_name(
        batch_gen_params['sequence_generator_name'])
    logging.info("SEQGEN Class Name: {}".format(
        sequence_generator_class_name))
    BatchGenerator = getattr(generators, sequence_generator_class_name)

    # instantiate the batch generator class for training
    train_gen = BatchGenerator(
        input_data, train_batch_gen_params,
        genome_params['reference_genome'], genome_params['chrom_sizes'],
        train_chroms, num_threads=parallelization_params['threads'],
        batch_size=hyper_params['batch_size'],
        background_only=is_background_model,
        foreground_weight=mnll_loss_sample_weight,
        background_weight=mnll_loss_background_sample_weight)

    # instantiate the batch generator class for validation
    val_gen = BatchGenerator(
        input_data, val_batch_gen_params,
        genome_params['reference_genome'], genome_params['chrom_sizes'],
        val_chroms, num_threads=parallelization_params['threads'],
        batch_size=hyper_params['batch_size'],
        background_only=is_background_model,
        foreground_weight=mnll_loss_sample_weight,
        background_weight=mnll_loss_background_sample_weight)

    # we need to calculate the number of training steps and
    # validation steps in each epoch; fit/evaluate requires this
    # to determine the end of an epoch
    train_steps = train_gen.len()
    val_steps = val_gen.len()

    # we may have to reduce --threads sometimes, if the peak file has
    # very few peaks, so we need to check if these numbers will be 0
    logging.info("TRAINING STEPS - {}".format(train_steps))
    logging.info("VALIDATION STEPS - {}".format(val_steps))

    # get an instance of the model
    logging.debug("New {} model".format(model_arch_name))
    get_model = getattr(archs, model_arch_name)
    if model_arch_name == "BPNet_ATAC_DNase":
        model = get_model(tasks, bias_tasks, model_arch_params,
                          bias_model_arch_params, name_prefix="main")
    else:
        model = get_model(tasks, model_arch_params, name_prefix="main")

    # print out the model summary
    model.summary()

    # compile the model
    logging.debug("Compiling model")
    logging.info("loss weights - {}".format(
        model_arch_params['loss_weights']))
    model.compile(
        Adam(learning_rate=hyper_params['learning_rate']),
        loss=[MultichannelMultinomialNLL(train_gen._total_signal_tracks),
              CustomMeanSquaredError()],
        loss_weights=model_arch_params['loss_weights'])

    # begin time for training
    t1 = time.time()

    # track training losses, validation losses and start & end times
    custom_history = {
        'learning_rate': {},
        'loss': {},
        'batch_loss': {},
        'profile_predictions_loss': {},
        'logcounts_predictions_loss': {},
        'attribution_prior_loss': {},
        'val_loss': {},
        'val_batch_loss': {},
        'val_profile_predictions_loss': {},
        'val_logcounts_predictions_loss': {},
        'val_attribution_prior_loss': {},
        'start_time': {},
        'end_time': {},
        'elapsed': {}
    }

    # we maintain a separate list to track validation losses to make
    # early stopping easier
    val_losses = []

    # track validation losses for learning rate updates
    val_losses_lr = []

    # track the best loss so we can restore weights
    best_loss = 1e6

    # keep a copy of the best weights
    best_weights = None

    # the epoch with the best validation loss
    best_epoch = 1

    # start training
    logging.debug("Training started ...")
    for epoch in range(hyper_params['epochs']):
        # First, let's train for one epoch
        logging.info("Training Epoch {}".format(epoch + 1))
        train_start_time = time.time()
        custom_history['learning_rate'][str(epoch + 1)] = \
            model.optimizer.learning_rate.numpy()
        custom_history['start_time'][str(epoch + 1)] = \
            time.strftime("%Y-%m-%d %H:%M:%S",
                          time.gmtime(train_start_time))

        # training generator function that will be passed to fit
        train_generator = train_gen.gen()
        history = model.fit(
            train_generator, epochs=1, steps_per_epoch=train_steps)
        train_end_time = time.time()

        # record the training losses
        for key in history.history:
            custom_history[key][str(epoch + 1)] = history.history[key][0]

        # Then, we evaluate on the validation set
        logging.info("Validation Epoch {}".format(epoch + 1))
        val_start_time = time.time()

        # validation generator function that will be passed to evaluate
        val_generator = val_gen.gen()
        val_loss = model.evaluate(
            val_generator, steps=val_steps, return_dict=True)
        val_losses.append(val_loss['loss'])
        val_losses_lr.append(val_loss['loss'])
        val_end_time = time.time()
        custom_history['end_time'][str(epoch + 1)] = \
            time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(val_end_time))
        custom_history['elapsed'][str(epoch + 1)] = \
            val_end_time - train_start_time

        # record the validation losses
        for key in val_loss:
            custom_history['val_' + key][str(epoch + 1)] = val_loss[key]

        # update best weights and loss
        if val_loss['loss'] < best_loss:
            best_weights = model.get_weights()
            best_loss = val_loss['loss']
            best_epoch = epoch + 1

        # check if early stopping criteria are satisfied
        if early_stopping_check(
                val_losses,
                patience=hyper_params['early_stopping_patience'],
                min_delta=hyper_params['early_stopping_min_delta']):
            # restore best weights
            logging.info(
                "Restoring best weights from epoch {}".format(best_epoch))
            model.set_weights(best_weights)
            break

        # lower learning rate if criteria are satisfied
        current_lr = model.optimizer.learning_rate.numpy()
        new_lr = reduce_lr_on_plateau(
            val_losses_lr, current_lr,
            factor=hyper_params['lr_reduction_factor'],
            patience=hyper_params['reduce_lr_on_plateau_patience'],
            min_lr=hyper_params['min_learning_rate'])

        # reset the validation losses tracker for learning rate updates
        if new_lr != current_lr:
            val_losses_lr = [val_losses_lr[-1]]

        # set the new learning rate
        model.optimizer.learning_rate.assign(new_lr)

        # display current learning rate and training status
        logging.info(
            "Current learning rate - {:5f}, Stop Training - {}".format(
                model.optimizer.learning_rate.numpy(),
                model.stop_training))

    # end time for training
    t2 = time.time()
    logging.info("Total Elapsed Time: {}".format(t2 - t1))

    # base model filename
    if output_params['automate_filenames']:
        # get random alphanumeric tag for model
        model_tag = getAlphaNumericTag(output_params['tag_length'])
        model_fname = "{}/{}".format(model_dir, model_tag)
    elif output_params['model_output_filename'] is not None:
        model_fname = "{}/{}".format(
            model_dir, output_params['model_output_filename'])
    else:
        model_fname = "{}/model".format(model_dir)

    # add suffix tag to model name
    if suffix_tag is not None:
        model_fname += "_{}".format(suffix_tag)

    # extension
    model_fname += ".h5"

    # save HDF5 model file
    model.save(model_fname)
    logging.info("Finished saving model: {}".format(model_fname))

    if adjust_bias_model_logcounts:
        # all peaks and non peaks
        loci = train_gen.get_samples()

        # non-peaks only
        nonpeaks_loci = loci[loci['weight'] == 0.0]

        if len(nonpeaks_loci) == 0:
            logging.info("Non peaks length is 0. Bias model adjustment "
                         "aborted.")
        else:
            # reference file to fetch sequences
            fasta_ref = pyfaidx.Fasta(genome_params['reference_genome'])

            # get all the signal bigWigs from the input_data
            bigWigs = []
            for task in tasks:
                if 'signal' in tasks[task].keys():
                    bigWigs.extend(tasks[task]['signal']["source"])

            # open each bigWig and add file pointers to a list
            fbigWigs = []
            for bigWig in bigWigs:
                fbigWigs.append(pyBigWig.open(bigWig))

            # get sequences and logcounts
            logging.info("Fetching non peak sequences and counts ...")
            sequences = []
            logcounts = []
            for _, row in nonpeaks_loci.iterrows():
                # chrom, start and end
                chrom = row['chrom']
                start = row['pos'] - (batch_gen_params['input_seq_len'] // 2)
                end = row['pos'] + (batch_gen_params['input_seq_len'] // 2)

                # get the sequence and collect it into a list
                seq = fasta_ref[chrom][start:end].seq.upper()
                sequences.append(seq)

                # get the total counts
                for i in range(len(fbigWigs)):
                    bw = fbigWigs[i]
                    logcounts.append(np.log(
                        np.sum(np.nan_to_num(
                            bw.values(chrom, start, end))) + 1))
            fasta_ref.close()

            # one hot encode the sequences
            seqs = sequtils.one_hot_encode(
                sequences, batch_gen_params['input_seq_len'])

            adjusted_model = adjust_bias_logcounts(
                model, seqs, np.array(logcounts), "logcounts_predictions")

            # saving adjusted model
            model_fname = model_fname.replace('.h5', '.adjusted.h5')

            # save HDF5 model file
            adjusted_model.save(model_fname)
            logging.info(
                "Finished saving adjusted model: {}".format(model_fname))

    # save history to json:
    # Step 1. convert the custom history dict to a pandas DataFrame
    hist_df = pd.DataFrame(custom_history)

    # file name for json file
    hist_json = model_fname.replace('.h5', '.history.json')

    # Step 2. write the dataframe to json
    with open(hist_json, mode='w') as f:
        hist_df.to_json(f)
    logging.info(
        "Finished saving training and validation history: {}".format(
            hist_json))

    # write all the command line arguments to a json file, including
    # the number of epochs the training lasted for, and the train &
    # validation chroms
    config_file = '{}/config'.format(model_dir)
    # add suffix tag to the config file name
    if suffix_tag is not None:
        config_file += "_{}".format(suffix_tag)
    # extension
    config_file += ".json"

    with open(config_file, 'w') as fp:
        config = {}
        config['input_data'] = input_data
        config['output_params'] = output_params
        config['genome_params'] = genome_params
        config['batch_gen_params'] = batch_gen_params
        config['hyper_params'] = hyper_params
        config['parallelization_params'] = parallelization_params
        # the number of epochs the training lasted
        config['training_epochs'] = epoch + 1
        # the epoch with best validation loss
        config['best_epoch'] = best_epoch
        config['train_chroms'] = train_chroms
        config['val_chroms'] = val_chroms
        config['model_filename'] = model_fname
        json.dump(config, fp)

    return model

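# The custom training loop above delegates early stopping and learning rate
# reduction to early_stopping_check() and reduce_lr_on_plateau(), which are
# defined elsewhere in the package. The sketches below are assumptions about
# the logic they plausibly implement, not the package's actual definitions.

def early_stopping_check(val_losses, patience=5, min_delta=1e-3):
    """True if val loss hasn't improved by min_delta in `patience` epochs."""
    if len(val_losses) <= patience:
        return False
    best_before = min(val_losses[:-patience])
    return min(val_losses[-patience:]) > best_before - min_delta

def reduce_lr_on_plateau(val_losses, current_lr, factor=0.5, patience=3,
                         min_lr=1e-6):
    """Return a reduced learning rate when the val loss plateaus."""
    if len(val_losses) <= patience:
        return current_lr
    if min(val_losses[-patience:]) >= min(val_losses[:-patience]):
        return max(current_lr * factor, min_lr)
    return current_lr
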
def outliers_main():
    # parse the command line arguments
    parser = outliers_argsparser()
    args = parser.parse_args()

    # filename to write debug logs
    logfname = "outliers.log"

    # set up the loggers
    logger.init_logger(logfname)

    # check if the input json file exists
    if not os.path.exists(args.input_data):
        raise NoTracebackException(
            "File {} does not exist".format(args.input_data))

    # check if the chrom sizes file exists
    if not os.path.exists(args.chrom_sizes):
        raise NoTracebackException(
            "File {} does not exist".format(args.chrom_sizes))

    # load the chrom sizes into a dataframe
    chrom_sizes_df = pd.read_csv(
        args.chrom_sizes, sep='\t', header=None, names=['chrom', 'size'])

    # load the tasks json file
    with open(args.input_data, 'r') as inp_json:
        try:
            tasks = json.loads(inp_json.read())
        except json.decoder.JSONDecodeError:
            raise NoTracebackException(
                "Unable to load json file {}. Valid json expected. "
                "Check the file for syntax errors.".format(args.input_data))

    # get all peaks for the given task in a 10 column ENCODE
    # narrowPeak format dataframe
    peaks_df = getPeakPositions(
        tasks[args.task], args.chroms, chrom_sizes_df,
        args.sequence_len // 2, drop_duplicates=True)

    # if a global sample weight is specified, set it here
    if args.global_sample_weight is not None:
        peaks_df['weight'] = args.global_sample_weight

    # remove peaks that fall within blacklist regions
    if args.blacklist is not None:
        # check if the blacklist file exists
        if not os.path.exists(args.blacklist):
            raise NoTracebackException(
                "Blacklist file {} does not exist".format(args.blacklist))

        blacklist_df = pd.read_csv(
            args.blacklist, sep='\t', names=['chrom', 'st', 'e'])
        logging.info("Filtering blacklist peaks ...")
        logging.info("old size {}".format(len(peaks_df)))
        peaks_df = remove_blacklist_peaks(peaks_df, blacklist_df)
        logging.info("new size {}".format(len(peaks_df)))

    # open all the signal bigWigs for reading
    signal_files = []
    for signal_file in tasks[args.task]['signal']['source']:
        # check if the bigWig file exists
        if not os.path.exists(signal_file):
            raise NoTracebackException(
                "BigWig file {} does not exist".format(signal_file))
        signal_files.append(pyBigWig.open(signal_file))

    # the counts dictionary maintains a list of counts for each peak
    # for each of the signal files; we will use these lists to create
    # a counts column for each of the signal files
    counts = {}
    for signal_file in signal_files:
        counts[signal_file] = []

    # iterate through all peaks and read values from the bigWig files
    logging.info("Computing counts for each peak")
    for _, row in tqdm(
            peaks_df.iterrows(), desc='peaks', total=len(peaks_df)):
        chrom = row['chrom']
        start = row['start']
        end = row['end']

        for signal_file in signal_files:
            counts[signal_file].append(
                np.sum(np.nan_to_num(
                    signal_file.values(chrom, start, end))))

    # add a new counts column to the peaks dataframe for each
    # signal file
    for signal_file in signal_files:
        peaks_df[signal_file] = counts[signal_file]

    # average the counts across the signal files
    peaks_df['avg_counts'] = peaks_df[signal_files].mean(axis=1)

    # sort the dataframe in ascending order of counts
    peaks_df = peaks_df.sort_values(by=['avg_counts'])

    # compute the quantile value
    counts = peaks_df['avg_counts'].values
    nth_quantile = np.quantile(counts, args.quantile)
    logging.info("{} quantile {}".format(args.quantile, nth_quantile))

    # get index of the quantile value
    quantile_idx = abs(counts - nth_quantile).argmin()
    logging.info("quantile idx {}".format(quantile_idx))

    # scale the value at the quantile index
    scaled_value = counts[quantile_idx] * args.quantile_value_scale_factor
    logging.info("scaled_value {}".format(scaled_value))

    # check if any of the counts are above the scaled_value
    if np.sum(counts > scaled_value) > 0:
        # index of the first value greater than scaled_value
        max_idx = np.argmax(counts > scaled_value)
        logging.info("max_idx {}".format(max_idx))

        # trimmed dataframe with outliers removed
        logging.info("original size {}".format(len(peaks_df)))
        peaks_df = peaks_df[:max_idx]
        logging.info("new size {}".format(len(peaks_df)))
    else:
        logging.info("No outliers found based on criteria. "
                     "Keeping original loci.")

    # save the new dataframe
    logging.info("Saving output bed file ... {}".format(args.output_bed))
    peaks_df = peaks_df[['chrom', 'st', 'e', 'name', 'weight', 'strand',
                         'signal', 'p', 'q', 'summit']]
    peaks_df.to_csv(args.output_bed, header=None, sep='\t', index=False)

def train_and_validate(
    input_params, output_params, genome_params, batch_gen_params,
    hyper_params, parallelization_params, network_params, train_chroms,
    val_chroms, model_dir, suffix_tag=None):
    """
        Train and validate on a single train and validation set

        Note: the list & description for each of the required keys in
        all of the json parameter files passed to this function can be
        found here: http://

        Args:
            input_params (str): path to json file containing input
                parameters
            output_params (dict): dictionary containing output parameters
            genome_params (dict): dictionary containing genome parameters
            batch_gen_params (dict): dictionary containing batch
                generation parameters
            hyper_params (dict): dictionary containing training &
                validation hyper parameters
            parallelization_params (dict): dictionary containing
                parameters for parallelization options
            network_params (dict): dictionary containing parameters
                specific to the deep learning architecture
            train_chroms (list): list of training chromosomes
            val_chroms (list): list of validation chromosomes
            model_dir (str): the path to the output directory
            suffix_tag (str): optional tag to add as a suffix to files
                (model, log, history & config params files) created in
                the model directory

        Returns:
            keras.models.Model
    """

    # filename to write debug logs
    if suffix_tag is not None:
        logfname = '{}/trainer_{}.log'.format(model_dir, suffix_tag)
    else:
        logfname = '{}/trainer.log'.format(model_dir)

    # we need to initialize the logger for each process
    logger.init_logger(logfname)

    # parameters that are specific to the training batch generation
    # process
    train_batch_gen_params = batch_gen_params
    train_batch_gen_params['mode'] = 'train'

    # parameters that are specific to the validation batch generation
    # process. For validation we don't use jitter, reverse complement
    # augmentation or negative sampling
    val_batch_gen_params = copy.deepcopy(batch_gen_params)
    val_batch_gen_params['max_jitter'] = 0
    val_batch_gen_params['rev_comp_aug'] = False
    val_batch_gen_params['negative_sampling_rate'] = 0.0
    val_batch_gen_params['mode'] = 'val'

    # get the corresponding batch generator class for this model
    sequence_generator_class_name = generators.find_generator_by_name(
        batch_gen_params['sequence_generator_name'])
    logging.info("SEQGEN Class Name: {}".format(
        sequence_generator_class_name))
    BatchGenerator = getattr(generators, sequence_generator_class_name)

    # instantiate the batch generator class for training
    train_gen = BatchGenerator(
        input_params, train_batch_gen_params, network_params,
        genome_params['reference_genome'], genome_params['chrom_sizes'],
        train_chroms, num_threads=parallelization_params['threads'],
        epochs=hyper_params['epochs'],
        batch_size=hyper_params['batch_size'])

    # training generator function that will be passed to fit_generator
    train_generator = train_gen.gen()

    # instantiate the batch generator class for validation
    val_gen = BatchGenerator(
        input_params, val_batch_gen_params, network_params,
        genome_params['reference_genome'], genome_params['chrom_sizes'],
        val_chroms, num_threads=parallelization_params['threads'],
        epochs=hyper_params['epochs'],
        batch_size=hyper_params['batch_size'])

    # validation generator function that will be passed to
    # fit_generator
    val_generator = val_gen.gen()

    # let's make sure the sizes look reasonable
    logging.info("TRAINING SIZE - {}".format(train_gen._samples.shape))
    logging.info("VALIDATION SIZE - {}".format(val_gen._samples.shape))

    # we need to calculate the number of training steps and
    # validation steps in each epoch; fit_generator requires this
    # to determine the end of an epoch
    train_steps = train_gen.len()
    val_steps = val_gen.len()

    # we may have to reduce --threads sometimes, if the peak file has
    # very few peaks, so we need to check if these numbers will be 0
    logging.info("TRAINING STEPS - {}".format(train_steps))
    logging.info("VALIDATION STEPS - {}".format(val_steps))

    # Here we specify all our callbacks
    # 1. Early stopping if validation loss doesn't decrease
    es = EarlyStopping(
        monitor='val_loss', mode='min', verbose=1,
        patience=hyper_params['early_stopping_patience'],
        min_delta=hyper_params['early_stopping_min_delta'],
        restore_best_weights=True)

    # 2. Reduce learning rate if validation loss is plateauing
    reduce_lr = ReduceLROnPlateau(
        monitor='val_loss', factor=0.5,
        patience=hyper_params['reduce_lr_on_plateau_patience'],
        min_lr=hyper_params['min_learning_rate'])

    # 3. Timing hook to record start, end & elapsed time for each epoch
    time_tracker = TimeHistory()

    # 4. Batch controller callbacks to ensure that batches are
    # generated only on a per epoch basis; also ensures graceful
    # termination of the batch generation
    train_batch_controller = BatchController(train_gen)
    val_batch_controller = BatchController(val_gen)

    # get an instance of the model
    logging.debug("New {} model".format(network_params['name']))
    get_model = getattr(model_archs, network_params['name'])
    model = get_model(train_batch_gen_params['input_seq_len'],
                      train_batch_gen_params['output_len'],
                      len(network_params['control_smoothing']) + 1,
                      filters=network_params['filters'],
                      num_tasks=train_gen._num_tasks)
    model.summary()

    # if running in multi gpu mode
    if parallelization_params['gpus'] > 1:
        logging.debug("Multi GPU model")
        model = multi_gpu_model(model, gpus=parallelization_params['gpus'])

    # compile the model
    logging.debug("Compiling model")
    logging.info("counts_loss_weight - {}".format(
        network_params['counts_loss_weight']))
    model.compile(
        Adam(lr=hyper_params['learning_rate']),
        loss=[MultichannelMultinomialNLL(train_gen._num_tasks), 'mse'],
        loss_weights=[1, network_params['counts_loss_weight']])

    # begin time for training
    t1 = time.time()

    # start training
    logging.debug("Training started ...")
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        history = model.fit_generator(
            train_generator, validation_data=val_generator,
            epochs=hyper_params['epochs'], steps_per_epoch=train_steps,
            validation_steps=val_steps,
            callbacks=[es, reduce_lr, time_tracker,
                       train_batch_controller, val_batch_controller])

    # end time for training
    t2 = time.time()
    logging.info("Total Elapsed Time: {}".format(t2 - t1))

    # send the stop signal to the generators
    train_gen.set_stop()
    val_gen.set_stop()

    # base model filename
    if output_params['automate_filenames']:
        # get random alphanumeric tag for model
        model_tag = getAlphaNumericTag(output_params['tag_length'])
        model_fname = "{}/{}".format(model_dir, model_tag)
    elif output_params['model_output_filename'] is not None:
        model_fname = "{}/{}".format(
            model_dir, output_params['model_output_filename'])
    else:
        model_fname = "{}/model".format(model_dir)

    # add suffix tag to model name
    if suffix_tag is not None:
        model_fname += "_{}".format(suffix_tag)

    # extension
    model_fname += ".h5"

    # save HDF5 model file
    model.save(model_fname)
    logging.info("Finished saving model: {}".format(model_fname))

    # save history to json:
    # Step 1. create a custom history object with a new key for
    # epoch times
    custom_history = copy.deepcopy(history.history)
    custom_history['times'] = time_tracker.times

    # Step 2. convert the custom history dict to a pandas DataFrame
    hist_df = pd.DataFrame(custom_history)

    # file name for json file
    hist_json = model_fname.replace('.h5', '.history.json')

    # Step 3. write the dataframe to json
    with open(hist_json, mode='w') as f:
        hist_df.to_json(f)
    logging.info(
        "Finished saving training and validation history: {}".format(
            hist_json))

    # write all the command line arguments to a json file, including
    # the number of epochs the training lasted for, and the train &
    # validation chroms
    config_file = '{}/config'.format(model_dir)
    # add suffix tag to the config file name
    if suffix_tag is not None:
        config_file += "_{}".format(suffix_tag)
    # extension
    config_file += ".json"

    with open(config_file, 'w') as fp:
        config = {}
        config['input_params'] = input_params
        config['output_params'] = output_params
        config['genome_params'] = genome_params
        config['batch_gen_params'] = batch_gen_params
        config['hyper_params'] = hyper_params
        config['parallelization_params'] = parallelization_params
        config['network_params'] = network_params
        # the number of epochs the training lasted
        epochs = len(history.history['val_loss'])
        config['training_epochs'] = epochs
        config['train_chroms'] = train_chroms
        config['val_chroms'] = val_chroms
        config['model_filename'] = model_fname
        json.dump(config, fp)

    return model

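# TimeHistory, used above, is a custom callback defined elsewhere in the
# package; the sketch below is an assumption about what it provides
# (per-epoch wall-clock times exposed via a `times` attribute).
import time
from tensorflow.keras.callbacks import Callback

class TimeHistory(Callback):
    """Record elapsed wall-clock time for each training epoch."""

    def on_train_begin(self, logs=None):
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        self._epoch_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        self.times.append(time.time() - self._epoch_start)
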
def train_and_validate_ksplits(
    input_params, output_params, genome_params, batch_gen_params,
    hyper_params, parallelization_params, network_params, splits):
    """
        Train and validate on one or more train/val splits

        Note: the list & description for each of the required keys in
        all of the json parameter files passed to this function can be
        found here: http://

        Args:
            input_params (str): path to json file containing input
                parameters
            output_params (dict): dictionary containing output parameters
            genome_params (dict): dictionary containing genome parameters
            batch_gen_params (dict): dictionary containing batch
                generation parameters
            hyper_params (dict): dictionary containing training &
                validation hyper parameters
            parallelization_params (dict): dictionary containing
                parameters for parallelization options
            network_params (dict): dictionary containing parameters
                specific to the deep learning architecture
            splits (dict): dictionary containing the train & validation
                splits, keyed by the split index as a string
    """

    # list of chromosomes after removing the excluded chromosomes
    chroms = set(genome_params['chroms']).difference(
        set(genome_params['exclude_chroms']))

    # list of models from all of the splits
    models = []

    # run training for each validation/test split
    num_splits = len(list(splits.keys()))
    for i in range(num_splits):
        if output_params['automate_filenames']:
            # create a new directory using current date/time to store
            # the model, the loss history and logs
            date_time_str = local_datetime_str(output_params['time_zone'])
            model_dir = '{}/{}_split{:03d}'.format(
                output_params['output_dir'], date_time_str, i)
            os.mkdir(model_dir)
            split_tag = None
        elif os.path.isdir(output_params['output_dir']):
            model_dir = output_params['output_dir']
            split_tag = "split{:03d}".format(i)
        else:
            logging.error("Directory {} does not exist.".format(
                output_params['output_dir']))
            return

        # filename to write debug logs
        logfname = '{}/trainer.log'.format(model_dir)

        # set up logger for the main process
        logger.init_logger(logfname)

        # train & validation chromosome split
        if 'val' not in splits[str(i)]:
            logging.error("KeyError: 'val' required for split {}".format(i))
            return
        val_chroms = splits[str(i)]['val']

        # if 'train' key is present
        if 'train' in splits[str(i)]:
            train_chroms = splits[str(i)]['train']
        # if 'test' key is present but 'train' is not
        elif 'test' in splits[str(i)]:
            test_chroms = splits[str(i)]['test']
            # take the set difference of the whole list of chroms with
            # the union of val and test
            train_chroms = list(
                chroms.difference(set(val_chroms + test_chroms)))
        else:
            # take the set difference of the whole list of chroms
            # with val
            train_chroms = list(chroms.difference(val_chroms))

        logging.info("Split #{}".format(i))
        logging.info("Train: {}".format(train_chroms))
        logging.info("Val: {}".format(val_chroms))

        # Start training for the split in a separate process. This
        # ensures that all resources are freed when the process
        # terminates, & available for training the next split. It
        # mitigates the problem where training subsequent splits is
        # considerably slower.
        logging.debug("Split {}: Creating training process".format(i))
        p = mp.Process(
            target=train_and_validate,
            args=[input_params, output_params, genome_params,
                  batch_gen_params, hyper_params, parallelization_params,
                  network_params, train_chroms, val_chroms, model_dir,
                  split_tag])
        p.start()

        # wait for the process to finish
        p.join()

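# Illustrative driver for the legacy entry point above; every file path
# below is an assumption for demonstration, and each params json file is
# expected to hold the corresponding dictionary described in the docstring.
import json

def _load_json(path):
    with open(path) as f:
        return json.load(f)

if __name__ == '__main__':
    train_and_validate_ksplits(
        "input_params.json",  # assumed path to the input params json
        _load_json("output_params.json"),
        _load_json("genome_params.json"),
        _load_json("batch_gen_params.json"),
        _load_json("hyper_params.json"),
        _load_json("parallelization_params.json"),
        _load_json("network_params.json"),
        _load_json("splits.json"))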