Beispiel #1
0
def batch_dataset_generator(gen, args, is_testing=False):
    """Wrap a python generator into a batched, prefetched tf.data pipeline.

    Args:
        gen: zero-arg callable returning a generator that yields
            (id_string, feature_grid, label) triples matching the
            output_shapes declared below.
        args: run configuration; reads grid_config, shuffle and batch_size.
        is_testing: when True, the dataset is consumed once (no repeat,
            no shuffling).

    Returns:
        (dataset, next_element): the batched tf.data.Dataset and the
        get_next() op of a one-shot iterator over it.
    """
    grid_size = subgrid_gen.grid_size(args.grid_config)
    channel_size = __channel_size(args)
    dataset = tf.data.Dataset.from_generator(
        gen,
        output_types=(tf.string, tf.float32, tf.float32),
        output_shapes=((), (2, grid_size, grid_size, grid_size, channel_size), (1,))
        )

    # Shuffle dataset
    if not is_testing:
        if args.shuffle:
            # NOTE(review): when args.shuffle is set, the python generator is
            # presumably already shuffling (it receives shuffle=args.shuffle
            # from the callers), so only an infinite repeat is applied here;
            # otherwise the tf-level shuffle_and_repeat does the shuffling.
            # If that is not the intent, these two branches look swapped --
            # confirm against the generator's behavior.
            dataset = dataset.repeat(count=None)
        else:
            dataset = dataset.apply(
                tf.contrib.data.shuffle_and_repeat(buffer_size=1000))

    dataset = dataset.batch(args.batch_size)
    dataset = dataset.prefetch(8)

    # TF1-style one-shot iterator; next_element is fed straight to sess.run.
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    return dataset, next_element
Beispiel #2
0
def train_model(sess, args):
    """Build the mutation-classification graph, train/validate it, then test.

    Args:
        sess: active tf.Session used to run the graph.
        args: run configuration (grid_config, batch_size, learning_rate,
            dropout rates, train/val/test_sharded datasets, output_dir,
            model_dir, and flags such as test_only, use_best, early_stopping,
            save_all_ckpts, use_ckpt_num, repeat_gen, shuffle).

    Side effects:
        Saves checkpoints and run_info.json under args.output_dir, pickles
        per-epoch train/val result DataFrames plus the final test results,
        and logs progress throughout.
    """
    # tf Graph input
    # Subgrid maps for each residue in a protein
    logging.debug('Create input placeholder...')
    grid_size = subgrid_gen.grid_size(args.grid_config)
    channel_size = __channel_size(args)
    feature_placeholder = tf.placeholder(
        tf.float32,
        [None, 2, grid_size, grid_size, grid_size, channel_size],
        name='main_input')
    label_placeholder = tf.placeholder(tf.int8, [None, 1], 'label')

    # Placeholders for per-batch model hyper-parameters.
    training_placeholder = tf.placeholder(tf.bool, shape=[], name='is_training')
    conv_drop_rate_placeholder = tf.placeholder(tf.float32, name='conv_drop_rate')
    fc_drop_rate_placeholder = tf.placeholder(tf.float32, name='fc_drop_rate')
    top_nn_drop_rate_placeholder = tf.placeholder(tf.float32, name='top_nn_drop_rate')

    # Define loss and optimizer
    logging.debug('Define loss and optimizer...')
    logits_op, predict_op, loss_op, accuracy_op = conv_model(
        feature_placeholder, label_placeholder, training_placeholder,
        conv_drop_rate_placeholder, fc_drop_rate_placeholder,
        top_nn_drop_rate_placeholder, args)
    logging.debug('Generate training ops...')
    train_op = model.training(loss_op, args.learning_rate)

    # Initialize the variables (i.e. assign their default value)
    logging.debug('Initializing global variables...')
    init = tf.global_variables_initializer()

    # Create saver and summaries.
    logging.debug('Initializing saver...')
    saver = tf.train.Saver(max_to_keep=100000)
    logging.debug('Finished initializing saver...')

    def __loop(generator, mode, num_iters):
        # Run one pass ('train' / 'val' / 'test') over the generator's data
        # and collect per-example outputs plus a running-mean epoch loss/acc.
        tf_dataset, next_element = batch_dataset_generator(
            generator, args, is_testing=(mode=='test'))

        ensembles, losses, logits, preds, labels = [], [], [], [], []
        epoch_loss = 0
        epoch_acc = 0
        progress_format = mode + ' loss: {:6.6f}' + '; acc: {:6.4f}'

        # Loop over all batches (one batch is all feature for 1 protein)
        num_batches = int(math.ceil(float(num_iters)/args.batch_size))
        with tqdm.tqdm(total=num_batches, desc=progress_format.format(0, 0)) as t:
            for i in range(num_batches):
                try:
                    ensemble_, feature_, label_ = sess.run(next_element)
                    # NOTE(review): train_op is run in every mode; outside
                    # 'train' the dropout rates are zeroed and is_training is
                    # False, but gradients may still be applied -- confirm.
                    _, logit, pred, loss, accuracy = sess.run(
                        [train_op, logits_op, predict_op, loss_op, accuracy_op],
                        feed_dict={feature_placeholder: feature_,
                                   label_placeholder: label_,
                                   training_placeholder: (mode == 'train'),
                                   conv_drop_rate_placeholder:
                                       args.conv_drop_rate if mode == 'train' else 0.0,
                                   fc_drop_rate_placeholder:
                                       args.fc_drop_rate if mode == 'train' else 0.0,
                                   top_nn_drop_rate_placeholder:
                                       args.top_nn_drop_rate if mode == 'train' else 0.0})
                    # Incremental running means over the batches seen so far.
                    epoch_loss += (np.mean(loss) - epoch_loss) / (i + 1)
                    epoch_acc += (np.mean(accuracy) - epoch_acc) / (i + 1)
                    ensembles.extend(ensemble_.astype(str))
                    losses.append(loss)
                    # np.float was removed in NumPy 1.24; builtin float is the
                    # same dtype (float64).
                    logits.extend(logit.astype(float))
                    preds.extend(pred.astype(np.int8))
                    labels.extend(label_.astype(np.int8))

                    t.set_description(progress_format.format(epoch_loss, epoch_acc))
                    t.update(1)
                except (tf.errors.OutOfRangeError, StopIteration):
                    logging.info("\nEnd of {:} dataset at iteration {:}".format(mode, i))
                    break

        def __concatenate(array):
            # Best-effort flattening; return the raw list unchanged when
            # concatenation is impossible (e.g. empty input or ragged pieces).
            try:
                array = np.concatenate(array)
                return array
            except ValueError:
                return array

        ensembles = __concatenate(ensembles)
        logits = __concatenate(logits)
        preds = __concatenate(preds)
        labels = __concatenate(labels)
        losses = __concatenate(losses)
        return ensembles, logits, preds, labels, losses, epoch_loss

    # Run the initializer
    logging.debug('Running initializer...')
    sess.run(init)
    logging.debug('Finished running initializer...')

    ##### Training + validation
    if not args.test_only:
        prev_val_loss, best_val_loss = float("inf"), float("inf")

        # Expected number of examples produced per shard.
        multiplier = args.grid_config.max_pos_per_shard * int(
            1 + args.grid_config.neg_to_pos_ratio)

        train_num_ensembles = args.train_sharded.get_num_shards()*multiplier
        train_num_ensembles *= args.repeat_gen

        val_num_ensembles = args.val_sharded.get_num_shards()*multiplier
        val_num_ensembles *= args.repeat_gen

        logging.info("Start training with {:} ensembles for train and {:} ensembles for val per epoch".format(
            train_num_ensembles, val_num_ensembles))


        def _save():
            # Checkpoint under output_dir, tagged with the current epoch.
            ckpt = saver.save(sess, os.path.join(args.output_dir, 'model-ckpt'),
                              global_step=epoch)
            return ckpt

        run_info_filename = os.path.join(args.output_dir, 'run_info.json')
        run_info = {}
        def __update_and_write_run_info(key, val):
            # Re-write run metadata after every update so a crash loses nothing.
            run_info[key] = val
            with open(run_info_filename, 'w') as f:
                json.dump(run_info, f, indent=4)

        per_epoch_val_losses = []
        for epoch in range(1, args.num_epochs+1):
            random_seed = args.random_seed #random.randint(1, 10e6)
            logging.info('Epoch {:} - random_seed: {:}'.format(epoch, args.random_seed))

            logging.debug('Creating train generator...')
            train_generator_callable = functools.partial(
                feature_mut.dataset_generator,
                args.train_sharded,
                args.grid_config,
                shuffle=args.shuffle,
                repeat=args.repeat_gen,
                add_flag=args.add_flag,
                center_at_mut=args.center_at_mut,
                testing=False,
                random_seed=random_seed)

            logging.debug('Creating val generator...')
            val_generator_callable = functools.partial(
                feature_mut.dataset_generator,
                args.val_sharded,
                args.grid_config,
                shuffle=args.shuffle,
                repeat=args.repeat_gen,
                add_flag=args.add_flag,
                center_at_mut=args.center_at_mut,
                testing=False,
                random_seed=random_seed)

            # Training
            train_ensembles, train_logits, train_preds, train_labels, _, curr_train_loss = __loop(
                train_generator_callable, 'train', num_iters=train_num_ensembles)
            # Validation
            val_ensembles, val_logits, val_preds, val_labels, _, curr_val_loss = __loop(
                val_generator_callable, 'val', num_iters=val_num_ensembles)

            per_epoch_val_losses.append(curr_val_loss)
            __update_and_write_run_info('val_losses', per_epoch_val_losses)

            if args.use_best or args.early_stopping:
                if curr_val_loss < best_val_loss:
                    # Found new best epoch.
                    best_val_loss = curr_val_loss
                    ckpt = _save()
                    __update_and_write_run_info('val_best_loss', best_val_loss)
                    __update_and_write_run_info('best_ckpt', ckpt)
                    logging.info("New best {:}".format(ckpt))

            # BUGFIX: epoch runs 1..num_epochs inclusive, so the final epoch
            # is num_epochs (the original compared against num_epochs - 1,
            # which saved one epoch early and never checkpointed the last).
            if (epoch == args.num_epochs and not args.use_best):
                # At end and just using final checkpoint.
                ckpt = _save()
                __update_and_write_run_info('best_ckpt', ckpt)
                logging.info("Last checkpoint {:}".format(ckpt))

            if args.save_all_ckpts:
                # Save at every checkpoint
                ckpt = _save()
                logging.info("Saving checkpoint {:}".format(ckpt))

            ## Save train and val results
            logging.info("Saving train and val results")
            train_df = pd.DataFrame(
                np.array([train_ensembles, train_labels, train_preds, train_logits]).T,
                columns=['ensembles', 'true', 'pred', 'logits'],
                )
            train_df.to_pickle(os.path.join(args.output_dir, 'train_result-{:}.pkl'.format(epoch)))

            val_df = pd.DataFrame(
                np.array([val_ensembles, val_labels, val_preds, val_logits]).T,
                columns=['ensembles', 'true', 'pred', 'logits'],
                )
            val_df.to_pickle(os.path.join(args.output_dir, 'val_result-{:}.pkl'.format(epoch)))

            __stats('Train Epoch {:}'.format(epoch), train_df)
            __stats('Val Epoch {:}'.format(epoch), val_df)

            if args.early_stopping and curr_val_loss >= prev_val_loss:
                logging.info("Validation loss stopped decreasing, stopping...")
                break
            else:
                prev_val_loss = curr_val_loss

        logging.info("Finished training")

    ##### Testing
    logging.debug("Run testing")
    if not args.test_only:
        # Just trained in this process: reuse the recorded best (or last) ckpt.
        to_use = run_info['best_ckpt'] if args.use_best else ckpt
    else:
        if args.use_ckpt_num is None:
            with open(os.path.join(args.model_dir, 'run_info.json')) as f:
                run_info = json.load(f)
            to_use = run_info['best_ckpt']
        else:
            to_use = os.path.join(
                args.model_dir, 'model-ckpt-{:}'.format(args.use_ckpt_num))
        saver = tf.train.import_meta_graph(to_use + '.meta')

    logging.info("Using {:} for testing".format(to_use))
    saver.restore(sess, to_use)

    test_generator_callable = functools.partial(
        feature_mut.dataset_generator,
        args.test_sharded,
        args.grid_config,
        shuffle=args.shuffle,
        repeat=args.repeat_gen,
        add_flag=args.add_flag,
        center_at_mut=args.center_at_mut,
        testing=True,
        random_seed=args.random_seed)

    test_num_ensembles = args.test_sharded.get_num_keyed()
    test_num_ensembles *= args.repeat_gen
    logging.info("Start testing with {:} ensembles".format(test_num_ensembles))

    test_ensembles, test_logits, test_preds, test_labels, _, test_loss = __loop(
        test_generator_callable, 'test', num_iters=test_num_ensembles)
    logging.info("Finished testing")

    test_df = pd.DataFrame(
        np.array([test_ensembles, test_labels, test_preds, test_logits]).T,
        columns=['ensembles', 'true', 'pred', 'logits'],
        )
    test_df.to_pickle(os.path.join(args.output_dir, 'test_result.pkl'))
    __stats('Test', test_df)
Beispiel #3
0
def train_model(sess, args):
    """Build the PDBBind regression graph, train/validate it, then test.

    Args:
        sess: active tf.Session used to run the graph.
        args: run configuration (grid_config, batch_size, learning_rate,
            dropout rates, data/split/labels filenames, max_pdbs_* caps,
            output_dir, model_dir, and flags such as test_only, use_best,
            early_stopping, save_all_ckpts, use_ckpt_num, repeat_gen,
            shuffle).

    Side effects:
        Saves checkpoints and run_info.json under args.output_dir, pickles
        the last train/val result DataFrames plus the final test results,
        and logs progress and final test statistics.
    """
    # tf Graph input
    # Subgrid maps for each residue in a protein
    logging.debug('Create input placeholder...')
    grid_size = subgrid_gen.grid_size(args.grid_config)
    channel_size = subgrid_gen.num_channels(args.grid_config)
    feature_placeholder = tf.placeholder(
        tf.float32,
        [None, grid_size, grid_size, grid_size, channel_size],
        name='main_input')
    label_placeholder = tf.placeholder(tf.float32, [None, 1], 'label')

    # Placeholders for per-batch model hyper-parameters.
    training_placeholder = tf.placeholder(tf.bool, shape=[], name='is_training')
    conv_drop_rate_placeholder = tf.placeholder(tf.float32, name='conv_drop_rate')
    fc_drop_rate_placeholder = tf.placeholder(tf.float32, name='fc_drop_rate')
    top_nn_drop_rate_placeholder = tf.placeholder(tf.float32, name='top_nn_drop_rate')

    # Define loss and optimizer
    logging.debug('Define loss and optimizer...')
    predict_op, loss_op = conv_model(
        feature_placeholder, label_placeholder, training_placeholder,
        conv_drop_rate_placeholder, fc_drop_rate_placeholder,
        top_nn_drop_rate_placeholder, args)
    logging.debug('Generate training ops...')
    train_op = model.training(loss_op, args.learning_rate)

    # Initialize the variables (i.e. assign their default value)
    logging.debug('Initializing global variables...')
    init = tf.global_variables_initializer()

    # Create saver and summaries.
    logging.debug('Initializing saver...')
    saver = tf.train.Saver(max_to_keep=100000)
    logging.debug('Finished initializing saver...')

    def __loop(generator, mode, num_iters):
        # Run one pass ('train' / 'val' / 'test') over the generator's data
        # and collect per-example outputs plus a running-mean epoch loss.
        tf_dataset, next_element = batch_dataset_generator(
            generator, args, is_testing=(mode=='test'))

        structs, losses, preds, labels = [], [], [], []
        epoch_loss = 0
        progress_format = mode + ' loss: {:6.6f}'

        # Loop over all batches (one batch is all feature for 1 protein)
        num_batches = int(math.ceil(float(num_iters)/args.batch_size))
        with tqdm.tqdm(total=num_batches, desc=progress_format.format(0)) as t:
            for i in range(num_batches):
                try:
                    struct_, feature_, label_ = sess.run(next_element)
                    # NOTE(review): train_op is run in every mode; outside
                    # 'train' the dropout rates are zeroed and is_training is
                    # False, but gradients may still be applied -- confirm.
                    _, pred, loss = sess.run(
                        [train_op, predict_op, loss_op],
                        feed_dict={feature_placeholder: feature_,
                                   label_placeholder: label_,
                                   training_placeholder: (mode == 'train'),
                                   conv_drop_rate_placeholder:
                                       args.conv_drop_rate if mode == 'train' else 0.0,
                                   fc_drop_rate_placeholder:
                                       args.fc_drop_rate if mode == 'train' else 0.0,
                                   top_nn_drop_rate_placeholder:
                                       args.top_nn_drop_rate if mode == 'train' else 0.0})
                    # Incremental running mean over the batches seen so far.
                    epoch_loss += (np.mean(loss) - epoch_loss) / (i + 1)
                    structs.extend(struct_)
                    losses.append(loss)
                    preds.extend(pred)
                    labels.extend(label_)

                    t.set_description(progress_format.format(epoch_loss))
                    t.update(1)
                except (tf.errors.OutOfRangeError, StopIteration):
                    # BUGFIX: a TF1 one-shot iterator signals exhaustion by
                    # raising tf.errors.OutOfRangeError from sess.run, not
                    # StopIteration; the original caught only the latter and
                    # would crash when the dataset ran out early.
                    logging.info("\nEnd of dataset at iteration {:}".format(i))
                    break

        def __concatenate(array):
            # Best-effort flattening; return the raw list unchanged when
            # concatenation is impossible (e.g. empty input or ragged pieces).
            try:
                array = np.concatenate(array)
                return array
            except ValueError:
                return array

        structs = __concatenate(structs)
        preds = __concatenate(preds)
        labels = __concatenate(labels)
        losses = __concatenate(losses)
        return structs, preds, labels, losses, epoch_loss

    # Run the initializer
    logging.debug('Running initializer...')
    sess.run(init)
    logging.debug('Finished running initializer...')

    ##### Training + validation
    if not args.test_only:
        prev_val_loss, best_val_loss = float("inf"), float("inf")

        # Number of structures per epoch: full split unless capped by args.
        if (args.max_pdbs_train == None):
            pdbcodes = feature_pdbbind.read_split(args.train_split_filename)
            train_num_structs = len(pdbcodes)
        else:
            train_num_structs = args.max_pdbs_train

        if (args.max_pdbs_val == None):
            pdbcodes = feature_pdbbind.read_split(args.val_split_filename)
            val_num_structs = len(pdbcodes)
        else:
            val_num_structs = args.max_pdbs_val

        train_num_structs *= args.repeat_gen
        val_num_structs *= args.repeat_gen

        logging.info("Start training with {:} structs for train and {:} structs for val per epoch".format(
            train_num_structs, val_num_structs))


        def _save():
            # Checkpoint under output_dir, tagged with the current epoch.
            ckpt = saver.save(sess, os.path.join(args.output_dir, 'model-ckpt'),
                              global_step=epoch)
            return ckpt

        run_info_filename = os.path.join(args.output_dir, 'run_info.json')
        run_info = {}
        def __update_and_write_run_info(key, val):
            # Re-write run metadata after every update so a crash loses nothing.
            run_info[key] = val
            with open(run_info_filename, 'w') as f:
                json.dump(run_info, f, indent=4)

        per_epoch_val_losses = []
        for epoch in range(1, args.num_epochs+1):
            random_seed = args.random_seed #random.randint(1, 10e6)
            logging.info('Epoch {:} - random_seed: {:}'.format(epoch, args.random_seed))

            logging.debug('Creating train generator...')
            train_generator_callable = functools.partial(
                feature_pdbbind.dataset_generator,
                args.data_filename,
                args.train_split_filename,
                args.labels_filename,
                args.grid_config,
                shuffle=args.shuffle,
                repeat=args.repeat_gen,
                max_pdbs=args.max_pdbs_train,
                random_seed=random_seed)

            logging.debug('Creating val generator...')
            val_generator_callable = functools.partial(
                feature_pdbbind.dataset_generator,
                args.data_filename,
                args.val_split_filename,
                args.labels_filename,
                args.grid_config,
                shuffle=args.shuffle,
                repeat=args.repeat_gen,
                max_pdbs=args.max_pdbs_val,
                random_seed=random_seed)

            # Training
            train_structs, train_preds, train_labels, _, curr_train_loss = __loop(
                train_generator_callable, 'train', num_iters=train_num_structs)
            # Validation
            val_structs, val_preds, val_labels, _, curr_val_loss = __loop(
                val_generator_callable, 'val', num_iters=val_num_structs)

            per_epoch_val_losses.append(curr_val_loss)
            __update_and_write_run_info('val_losses', per_epoch_val_losses)

            if args.use_best or args.early_stopping:
                if curr_val_loss < best_val_loss:
                    # Found new best epoch.
                    best_val_loss = curr_val_loss
                    ckpt = _save()
                    __update_and_write_run_info('val_best_loss', best_val_loss)
                    __update_and_write_run_info('best_ckpt', ckpt)
                    logging.info("New best {:}".format(ckpt))

            # BUGFIX: epoch runs 1..num_epochs inclusive, so the final epoch
            # is num_epochs (the original compared against num_epochs - 1,
            # which saved one epoch early and never checkpointed the last).
            if (epoch == args.num_epochs and not args.use_best):
                # At end and just using final checkpoint.
                ckpt = _save()
                __update_and_write_run_info('best_ckpt', ckpt)
                logging.info("Last checkpoint {:}".format(ckpt))

            if args.save_all_ckpts:
                # Save at every checkpoint
                ckpt = _save()
                logging.info("Saving checkpoint {:}".format(ckpt))

            if args.early_stopping and curr_val_loss >= prev_val_loss:
                logging.info("Validation loss stopped decreasing, stopping...")
                break
            else:
                prev_val_loss = curr_val_loss

        logging.info("Finished training")

        ## Save last train and val results
        logging.info("Saving train and val results")
        train_df = pd.DataFrame(
            np.array([train_structs, train_labels, train_preds]).T,
            columns=['structure', 'true', 'pred'],
            )
        train_df.to_pickle(os.path.join(args.output_dir, 'train_result.pkl'))

        val_df = pd.DataFrame(
            np.array([val_structs, val_labels, val_preds]).T,
            columns=['structure', 'true', 'pred'],
            )
        val_df.to_pickle(os.path.join(args.output_dir, 'val_result.pkl'))


    ##### Testing
    logging.debug("Run testing")
    if not args.test_only:
        # Just trained in this process: reuse the recorded best (or last) ckpt.
        to_use = run_info['best_ckpt'] if args.use_best else ckpt
    else:
        if args.use_ckpt_num is None:
            with open(os.path.join(args.model_dir, 'run_info.json')) as f:
                run_info = json.load(f)
            to_use = run_info['best_ckpt']
        else:
            to_use = os.path.join(
                args.model_dir, 'model-ckpt-{:}'.format(args.use_ckpt_num))
        saver = tf.train.import_meta_graph(to_use + '.meta')

    test_generator_callable = functools.partial(
        feature_pdbbind.dataset_generator,
        args.data_filename,
        args.test_split_filename,
        args.labels_filename,
        args.grid_config,
        shuffle=args.shuffle,
        repeat=1,
        max_pdbs=args.max_pdbs_test,
        random_seed=args.random_seed)

    if (args.max_pdbs_test == None):
        pdbcodes = feature_pdbbind.read_split(args.test_split_filename)
        test_num_structs = len(pdbcodes)
    else:
        test_num_structs = args.max_pdbs_test

    logging.info("Start testing with {:} structs".format(test_num_structs))


    test_structs, test_preds, test_labels, _, test_loss = __loop(
        test_generator_callable, 'test', num_iters=test_num_structs)
    logging.info("Finished testing")

    test_df = pd.DataFrame(
        np.array([test_structs, test_labels, test_preds]).T,
        columns=['structure', 'true', 'pred'],
        )
    test_df.to_pickle(os.path.join(args.output_dir, 'test_result.pkl'))

    # Compute global correlations
    res = compute_stats(test_df)
    logging.info(
        '\nStats\n'
        '    RMSE: {:.3f}\n'
        '    Pearson: {:.3f}\n'
        '    Spearman: {:.3f}'.format(
        float(res["rmse"]),
        float(res["all_pearson"]),
        float(res["all_spearman"])))