def batch_dataset_generator(gen, args, is_testing=False):
    """Wrap a Python generator in a batched, prefetching tf.data pipeline.

    Returns a pair (dataset, next_element) where next_element yields
    batches of (id_string, feature_tensor, label) of size args.batch_size.
    """
    grid_size = subgrid_gen.grid_size(args.grid_config)
    channel_size = __channel_size(args)

    ds = tf.data.Dataset.from_generator(
        gen,
        output_types=(tf.string, tf.float32, tf.float32),
        output_shapes=(
            (),
            (2, grid_size, grid_size, grid_size, channel_size),
            (1,)))

    # Shuffle dataset
    if not is_testing:
        # NOTE(review): when args.shuffle is set, only repeat() is applied
        # here — presumably the generator itself already shuffles in that
        # case, and the pipeline-level shuffle covers the other case.
        # Confirm this inversion is intentional.
        if args.shuffle:
            ds = ds.repeat(count=None)
        else:
            ds = ds.apply(
                tf.contrib.data.shuffle_and_repeat(buffer_size=1000))

    ds = ds.batch(args.batch_size)
    ds = ds.prefetch(8)
    next_element = ds.make_one_shot_iterator().get_next()
    return ds, next_element
def train_model(sess, args):
    """Build the mutation-ensemble classification graph, train/validate it,
    and evaluate on the test split.

    Args:
        sess: an open tf.Session the graph is built into and run with.
        args: parsed configuration namespace (grid config, sharded datasets,
            hyper-parameters, output/model directories, mode flags).

    Side effects: writes checkpoints, run_info.json and per-epoch result
    pickles into args.output_dir.
    """
    # tf Graph input
    # Subgrid maps for each residue in a protein
    logging.debug('Create input placeholder...')
    grid_size = subgrid_gen.grid_size(args.grid_config)
    channel_size = __channel_size(args)
    feature_placeholder = tf.placeholder(
        tf.float32,
        [None, 2, grid_size, grid_size, grid_size, channel_size],
        name='main_input')
    label_placeholder = tf.placeholder(tf.int8, [None, 1], 'label')

    # Placeholder for model parameters
    training_placeholder = tf.placeholder(tf.bool, shape=[], name='is_training')
    conv_drop_rate_placeholder = tf.placeholder(tf.float32, name='conv_drop_rate')
    fc_drop_rate_placeholder = tf.placeholder(tf.float32, name='fc_drop_rate')
    top_nn_drop_rate_placeholder = tf.placeholder(tf.float32, name='top_nn_drop_rate')

    # Define loss and optimizer
    logging.debug('Define loss and optimizer...')
    logits_op, predict_op, loss_op, accuracy_op = conv_model(
        feature_placeholder, label_placeholder, training_placeholder,
        conv_drop_rate_placeholder, fc_drop_rate_placeholder,
        top_nn_drop_rate_placeholder, args)

    logging.debug('Generate training ops...')
    train_op = model.training(loss_op, args.learning_rate)

    # Initialize the variables (i.e. assign their default value)
    logging.debug('Initializing global variables...')
    init = tf.global_variables_initializer()

    # Create saver and summaries.
    logging.debug('Initializing saver...')
    saver = tf.train.Saver(max_to_keep=100000)
    logging.debug('Finished initializing saver...')

    def __loop(generator, mode, num_iters):
        """Run one pass over `generator`; mode is 'train', 'val' or 'test'."""
        tf_dataset, next_element = batch_dataset_generator(
            generator, args, is_testing=(mode == 'test'))
        ensembles, losses, logits, preds, labels = [], [], [], [], []
        epoch_loss = 0
        epoch_acc = 0
        progress_format = mode + ' loss: {:6.6f}' + '; acc: {:6.4f}'
        # Loop over all batches (one batch is all feature for 1 protein)
        num_batches = int(math.ceil(float(num_iters) / args.batch_size))
        with tqdm.tqdm(total=num_batches,
                       desc=progress_format.format(0, 0)) as t:
            for i in range(num_batches):
                try:
                    ensemble_, feature_, label_ = sess.run(next_element)
                    # BUGFIX: only run the optimizer in training mode.
                    # Previously train_op was fetched unconditionally, so
                    # validation/testing also updated the weights.
                    fetches = [logits_op, predict_op, loss_op, accuracy_op]
                    if mode == 'train':
                        fetches.append(train_op)
                    logit, pred, loss, accuracy = sess.run(
                        fetches,
                        feed_dict={
                            feature_placeholder: feature_,
                            label_placeholder: label_,
                            training_placeholder: (mode == 'train'),
                            conv_drop_rate_placeholder:
                                args.conv_drop_rate if mode == 'train' else 0.0,
                            fc_drop_rate_placeholder:
                                args.fc_drop_rate if mode == 'train' else 0.0,
                            top_nn_drop_rate_placeholder:
                                args.top_nn_drop_rate if mode == 'train' else 0.0,
                        })[:4]
                    # Incremental running means over the batches seen so far.
                    epoch_loss += (np.mean(loss) - epoch_loss) / (i + 1)
                    epoch_acc += (np.mean(accuracy) - epoch_acc) / (i + 1)
                    ensembles.extend(ensemble_.astype(str))
                    losses.append(loss)
                    # BUGFIX: np.float was removed in NumPy 1.24; the builtin
                    # float is the same dtype (float64).
                    logits.extend(logit.astype(float))
                    preds.extend(pred.astype(np.int8))
                    labels.extend(label_.astype(np.int8))
                    t.set_description(
                        progress_format.format(epoch_loss, epoch_acc))
                    t.update(1)
                except (tf.errors.OutOfRangeError, StopIteration):
                    logging.info("\nEnd of {:} dataset at iteration {:}".format(mode, i))
                    break

        def __concatenate(array):
            # Best-effort concatenation: empty or ragged inputs are returned
            # unchanged instead of crashing (was a bare `except:`).
            try:
                return np.concatenate(array)
            except (ValueError, TypeError):
                return array

        ensembles = __concatenate(ensembles)
        logits = __concatenate(logits)
        preds = __concatenate(preds)
        labels = __concatenate(labels)
        losses = __concatenate(losses)
        return ensembles, logits, preds, labels, losses, epoch_loss

    # Run the initializer
    logging.debug('Running initializer...')
    sess.run(init)
    logging.debug('Finished running initializer...')

    ##### Training + validation
    if not args.test_only:
        prev_val_loss, best_val_loss = float("inf"), float("inf")
        # Expected number of examples per epoch, derived from shard layout.
        multiplier = args.grid_config.max_pos_per_shard * int(
            1 + args.grid_config.neg_to_pos_ratio)
        train_num_ensembles = args.train_sharded.get_num_shards() * multiplier
        train_num_ensembles *= args.repeat_gen
        val_num_ensembles = args.val_sharded.get_num_shards() * multiplier
        val_num_ensembles *= args.repeat_gen
        logging.info(
            "Start training with {:} ensembles for train and {:} ensembles "
            "for val per epoch".format(train_num_ensembles, val_num_ensembles))

        def _save():
            # Checkpoint files are suffixed with the current epoch number.
            ckpt = saver.save(
                sess, os.path.join(args.output_dir, 'model-ckpt'),
                global_step=epoch)
            return ckpt

        run_info_filename = os.path.join(args.output_dir, 'run_info.json')
        run_info = {}

        def __update_and_write_run_info(key, val):
            # Persist run metadata on every update so a crash loses nothing.
            run_info[key] = val
            with open(run_info_filename, 'w') as f:
                json.dump(run_info, f, indent=4)

        per_epoch_val_losses = []
        for epoch in range(1, args.num_epochs + 1):
            random_seed = args.random_seed  # random.randint(1, 10e6)
            logging.info('Epoch {:} - random_seed: {:}'.format(
                epoch, args.random_seed))
            logging.debug('Creating train generator...')
            train_generator_callable = functools.partial(
                feature_mut.dataset_generator,
                args.train_sharded, args.grid_config,
                shuffle=args.shuffle, repeat=args.repeat_gen,
                add_flag=args.add_flag, center_at_mut=args.center_at_mut,
                testing=False, random_seed=random_seed)
            logging.debug('Creating val generator...')
            val_generator_callable = functools.partial(
                feature_mut.dataset_generator,
                args.val_sharded, args.grid_config,
                shuffle=args.shuffle, repeat=args.repeat_gen,
                add_flag=args.add_flag, center_at_mut=args.center_at_mut,
                testing=False, random_seed=random_seed)

            # Training
            train_ensembles, train_logits, train_preds, train_labels, _, curr_train_loss = __loop(
                train_generator_callable, 'train', num_iters=train_num_ensembles)
            # Validation
            val_ensembles, val_logits, val_preds, val_labels, _, curr_val_loss = __loop(
                val_generator_callable, 'val', num_iters=val_num_ensembles)

            per_epoch_val_losses.append(curr_val_loss)
            __update_and_write_run_info('val_losses', per_epoch_val_losses)

            if args.use_best or args.early_stopping:
                if curr_val_loss < best_val_loss:
                    # Found new best epoch.
                    best_val_loss = curr_val_loss
                    ckpt = _save()
                    __update_and_write_run_info('val_best_loss', best_val_loss)
                    __update_and_write_run_info('best_ckpt', ckpt)
                    logging.info("New best {:}".format(ckpt))
            # BUGFIX: epochs run 1..num_epochs, so the final epoch is
            # num_epochs (was `num_epochs - 1`, saving one epoch too early
            # and never at the true end).
            if (epoch == args.num_epochs and not args.use_best):
                # At end and just using final checkpoint.
                ckpt = _save()
                __update_and_write_run_info('best_ckpt', ckpt)
                logging.info("Last checkpoint {:}".format(ckpt))
            if args.save_all_ckpts:
                # Save at every checkpoint
                ckpt = _save()
                logging.info("Saving checkpoint {:}".format(ckpt))

            ## Save train and val results
            logging.info("Saving train and val results")
            train_df = pd.DataFrame(
                np.array([train_ensembles, train_labels, train_preds,
                          train_logits]).T,
                columns=['ensembles', 'true', 'pred', 'logits'],
            )
            train_df.to_pickle(os.path.join(
                args.output_dir, 'train_result-{:}.pkl'.format(epoch)))
            val_df = pd.DataFrame(
                np.array([val_ensembles, val_labels, val_preds,
                          val_logits]).T,
                columns=['ensembles', 'true', 'pred', 'logits'],
            )
            val_df.to_pickle(os.path.join(
                args.output_dir, 'val_result-{:}.pkl'.format(epoch)))
            __stats('Train Epoch {:}'.format(epoch), train_df)
            __stats('Val Epoch {:}'.format(epoch), val_df)

            if args.early_stopping and curr_val_loss >= prev_val_loss:
                logging.info("Validation loss stopped decreasing, stopping...")
                break
            else:
                prev_val_loss = curr_val_loss
        logging.info("Finished training")

    ##### Testing
    logging.debug("Run testing")
    if not args.test_only:
        # Use the checkpoint produced by the training pass above.
        to_use = run_info['best_ckpt'] if args.use_best else ckpt
    else:
        if args.use_ckpt_num is None:
            with open(os.path.join(args.model_dir, 'run_info.json')) as f:
                run_info = json.load(f)
            to_use = run_info['best_ckpt']
        else:
            to_use = os.path.join(
                args.model_dir, 'model-ckpt-{:}'.format(args.use_ckpt_num))
        saver = tf.train.import_meta_graph(to_use + '.meta')
    logging.info("Using {:} for testing".format(to_use))
    saver.restore(sess, to_use)

    test_generator_callable = functools.partial(
        feature_mut.dataset_generator,
        args.test_sharded, args.grid_config,
        shuffle=args.shuffle, repeat=args.repeat_gen,
        add_flag=args.add_flag, center_at_mut=args.center_at_mut,
        testing=True, random_seed=args.random_seed)
    test_num_ensembles = args.test_sharded.get_num_keyed()
    test_num_ensembles *= args.repeat_gen
    logging.info("Start testing with {:} ensembles".format(test_num_ensembles))
    test_ensembles, test_logits, test_preds, test_labels, _, test_loss = __loop(
        test_generator_callable, 'test', num_iters=test_num_ensembles)
    logging.info("Finished testing")
    test_df = pd.DataFrame(
        np.array([test_ensembles, test_labels, test_preds, test_logits]).T,
        columns=['ensembles', 'true', 'pred', 'logits'],
    )
    test_df.to_pickle(os.path.join(args.output_dir, 'test_result.pkl'))
    __stats('Test', test_df)
def train_model(sess, args):
    """Build the PDBBind regression graph, train/validate it, and test.

    NOTE(review): this redefines `train_model` from earlier in the file and
    shadows it at import time — the two variants should likely live in
    separate modules or carry distinct names.

    Args:
        sess: an open tf.Session the graph is built into and run with.
        args: parsed configuration namespace (data/split/label filenames,
            grid config, hyper-parameters, output/model dirs, mode flags).

    Side effects: writes checkpoints, run_info.json and result pickles into
    args.output_dir.
    """
    # tf Graph input
    # Subgrid maps for each residue in a protein
    logging.debug('Create input placeholder...')
    grid_size = subgrid_gen.grid_size(args.grid_config)
    channel_size = subgrid_gen.num_channels(args.grid_config)
    feature_placeholder = tf.placeholder(
        tf.float32,
        [None, grid_size, grid_size, grid_size, channel_size],
        name='main_input')
    label_placeholder = tf.placeholder(tf.float32, [None, 1], 'label')

    # Placeholder for model parameters
    training_placeholder = tf.placeholder(tf.bool, shape=[], name='is_training')
    conv_drop_rate_placeholder = tf.placeholder(tf.float32, name='conv_drop_rate')
    fc_drop_rate_placeholder = tf.placeholder(tf.float32, name='fc_drop_rate')
    top_nn_drop_rate_placeholder = tf.placeholder(tf.float32, name='top_nn_drop_rate')

    # Define loss and optimizer
    logging.debug('Define loss and optimizer...')
    predict_op, loss_op = conv_model(
        feature_placeholder, label_placeholder, training_placeholder,
        conv_drop_rate_placeholder, fc_drop_rate_placeholder,
        top_nn_drop_rate_placeholder, args)

    logging.debug('Generate training ops...')
    train_op = model.training(loss_op, args.learning_rate)

    # Initialize the variables (i.e. assign their default value)
    logging.debug('Initializing global variables...')
    init = tf.global_variables_initializer()

    # Create saver and summaries.
    logging.debug('Initializing saver...')
    saver = tf.train.Saver(max_to_keep=100000)
    logging.debug('Finished initializing saver...')

    def __loop(generator, mode, num_iters):
        """Run one pass over `generator`; mode is 'train', 'val' or 'test'."""
        tf_dataset, next_element = batch_dataset_generator(
            generator, args, is_testing=(mode == 'test'))
        structs, losses, preds, labels = [], [], [], []
        epoch_loss = 0
        progress_format = mode + ' loss: {:6.6f}'
        # Loop over all batches (one batch is all feature for 1 protein)
        num_batches = int(math.ceil(float(num_iters) / args.batch_size))
        with tqdm.tqdm(total=num_batches, desc=progress_format.format(0)) as t:
            for i in range(num_batches):
                try:
                    struct_, feature_, label_ = sess.run(next_element)
                    # BUGFIX: only run the optimizer in training mode.
                    # Previously train_op was fetched unconditionally, so
                    # validation/testing also updated the weights.
                    fetches = [predict_op, loss_op]
                    if mode == 'train':
                        fetches.append(train_op)
                    pred, loss = sess.run(
                        fetches,
                        feed_dict={
                            feature_placeholder: feature_,
                            label_placeholder: label_,
                            training_placeholder: (mode == 'train'),
                            conv_drop_rate_placeholder:
                                args.conv_drop_rate if mode == 'train' else 0.0,
                            fc_drop_rate_placeholder:
                                args.fc_drop_rate if mode == 'train' else 0.0,
                            top_nn_drop_rate_placeholder:
                                args.top_nn_drop_rate if mode == 'train' else 0.0,
                        })[:2]
                    # Incremental running mean over the batches seen so far.
                    epoch_loss += (np.mean(loss) - epoch_loss) / (i + 1)
                    structs.extend(struct_)
                    losses.append(loss)
                    preds.extend(pred)
                    labels.extend(label_)
                    t.set_description(progress_format.format(epoch_loss))
                    t.update(1)
                # BUGFIX: in graph mode an exhausted tf.data iterator raises
                # tf.errors.OutOfRangeError from sess.run, not StopIteration;
                # catching only StopIteration crashed at end of dataset
                # (the classification variant already catches both).
                except (tf.errors.OutOfRangeError, StopIteration):
                    logging.info("\nEnd of dataset at iteration {:}".format(i))
                    break

        def __concatenate(array):
            # Best-effort concatenation: empty or ragged inputs are returned
            # unchanged instead of crashing (was a bare `except:`).
            try:
                return np.concatenate(array)
            except (ValueError, TypeError):
                return array

        structs = __concatenate(structs)
        preds = __concatenate(preds)
        labels = __concatenate(labels)
        losses = __concatenate(losses)
        return structs, preds, labels, losses, epoch_loss

    # Run the initializer
    logging.debug('Running initializer...')
    sess.run(init)
    logging.debug('Finished running initializer...')

    ##### Training + validation
    if not args.test_only:
        prev_val_loss, best_val_loss = float("inf"), float("inf")
        # Number of structures per epoch: the whole split unless capped.
        if args.max_pdbs_train is None:
            pdbcodes = feature_pdbbind.read_split(args.train_split_filename)
            train_num_structs = len(pdbcodes)
        else:
            train_num_structs = args.max_pdbs_train
        if args.max_pdbs_val is None:
            pdbcodes = feature_pdbbind.read_split(args.val_split_filename)
            val_num_structs = len(pdbcodes)
        else:
            val_num_structs = args.max_pdbs_val
        train_num_structs *= args.repeat_gen
        val_num_structs *= args.repeat_gen
        logging.info(
            "Start training with {:} structs for train and {:} structs "
            "for val per epoch".format(train_num_structs, val_num_structs))

        def _save():
            # Checkpoint files are suffixed with the current epoch number.
            ckpt = saver.save(
                sess, os.path.join(args.output_dir, 'model-ckpt'),
                global_step=epoch)
            return ckpt

        run_info_filename = os.path.join(args.output_dir, 'run_info.json')
        run_info = {}

        def __update_and_write_run_info(key, val):
            # Persist run metadata on every update so a crash loses nothing.
            run_info[key] = val
            with open(run_info_filename, 'w') as f:
                json.dump(run_info, f, indent=4)

        per_epoch_val_losses = []
        for epoch in range(1, args.num_epochs + 1):
            random_seed = args.random_seed  # random.randint(1, 10e6)
            logging.info('Epoch {:} - random_seed: {:}'.format(
                epoch, args.random_seed))
            logging.debug('Creating train generator...')
            train_generator_callable = functools.partial(
                feature_pdbbind.dataset_generator,
                args.data_filename, args.train_split_filename,
                args.labels_filename, args.grid_config,
                shuffle=args.shuffle, repeat=args.repeat_gen,
                max_pdbs=args.max_pdbs_train, random_seed=random_seed)
            logging.debug('Creating val generator...')
            val_generator_callable = functools.partial(
                feature_pdbbind.dataset_generator,
                args.data_filename, args.val_split_filename,
                args.labels_filename, args.grid_config,
                shuffle=args.shuffle, repeat=args.repeat_gen,
                max_pdbs=args.max_pdbs_val, random_seed=random_seed)

            # Training
            train_structs, train_preds, train_labels, _, curr_train_loss = __loop(
                train_generator_callable, 'train', num_iters=train_num_structs)
            # Validation
            val_structs, val_preds, val_labels, _, curr_val_loss = __loop(
                val_generator_callable, 'val', num_iters=val_num_structs)

            per_epoch_val_losses.append(curr_val_loss)
            __update_and_write_run_info('val_losses', per_epoch_val_losses)
            if args.use_best or args.early_stopping:
                if curr_val_loss < best_val_loss:
                    # Found new best epoch.
                    best_val_loss = curr_val_loss
                    ckpt = _save()
                    __update_and_write_run_info('val_best_loss', best_val_loss)
                    __update_and_write_run_info('best_ckpt', ckpt)
                    logging.info("New best {:}".format(ckpt))
            # BUGFIX: epochs run 1..num_epochs, so the final epoch is
            # num_epochs (was `num_epochs - 1`, saving one epoch too early
            # and never at the true end).
            if (epoch == args.num_epochs and not args.use_best):
                # At end and just using final checkpoint.
                ckpt = _save()
                __update_and_write_run_info('best_ckpt', ckpt)
                logging.info("Last checkpoint {:}".format(ckpt))
            if args.save_all_ckpts:
                # Save at every checkpoint
                ckpt = _save()
                logging.info("Saving checkpoint {:}".format(ckpt))
            if args.early_stopping and curr_val_loss >= prev_val_loss:
                logging.info("Validation loss stopped decreasing, stopping...")
                break
            else:
                prev_val_loss = curr_val_loss
        logging.info("Finished training")

        ## Save last train and val results
        logging.info("Saving train and val results")
        train_df = pd.DataFrame(
            np.array([train_structs, train_labels, train_preds]).T,
            columns=['structure', 'true', 'pred'],
        )
        train_df.to_pickle(os.path.join(args.output_dir, 'train_result.pkl'))
        val_df = pd.DataFrame(
            np.array([val_structs, val_labels, val_preds]).T,
            columns=['structure', 'true', 'pred'],
        )
        val_df.to_pickle(os.path.join(args.output_dir, 'val_result.pkl'))

    ##### Testing
    logging.debug("Run testing")
    if not args.test_only:
        # Use the checkpoint produced by the training pass above.
        to_use = run_info['best_ckpt'] if args.use_best else ckpt
    else:
        if args.use_ckpt_num is None:
            with open(os.path.join(args.model_dir, 'run_info.json')) as f:
                run_info = json.load(f)
            to_use = run_info['best_ckpt']
        else:
            to_use = os.path.join(
                args.model_dir, 'model-ckpt-{:}'.format(args.use_ckpt_num))
        saver = tf.train.import_meta_graph(to_use + '.meta')
    # BUGFIX: restore the selected checkpoint before testing — previously
    # the meta graph was imported but the variables were never restored,
    # so testing ran on whatever weights happened to be in the session
    # (the classification variant does restore here).
    saver.restore(sess, to_use)

    test_generator_callable = functools.partial(
        feature_pdbbind.dataset_generator,
        args.data_filename, args.test_split_filename,
        args.labels_filename, args.grid_config,
        shuffle=args.shuffle, repeat=1,
        max_pdbs=args.max_pdbs_test, random_seed=args.random_seed)
    if args.max_pdbs_test is None:
        pdbcodes = feature_pdbbind.read_split(args.test_split_filename)
        test_num_structs = len(pdbcodes)
    else:
        test_num_structs = args.max_pdbs_test
    logging.info("Start testing with {:} structs".format(test_num_structs))
    test_structs, test_preds, test_labels, _, test_loss = __loop(
        test_generator_callable, 'test', num_iters=test_num_structs)
    logging.info("Finished testing")
    test_df = pd.DataFrame(
        np.array([test_structs, test_labels, test_preds]).T,
        columns=['structure', 'true', 'pred'],
    )
    test_df.to_pickle(os.path.join(args.output_dir, 'test_result.pkl'))

    # Compute global correlations
    res = compute_stats(test_df)
    logging.info(
        '\nStats\n'
        ' RMSE: {:.3f}\n'
        ' Pearson: {:.3f}\n'
        ' Spearman: {:.3f}'.format(
            float(res["rmse"]), float(res["all_pearson"]),
            float(res["all_spearman"])))