def Logger(q):
    # Consumer process: assigns a GPU to each worker that announces itself on
    # the shared queue, and appends per-fold results to a CSV file.
    import time
    from collections import Counter
    all_auc = []
    registered_gpus = {}
    logger = lib.logger.CSVLogger('results.csv', output_dir, [
        'fold', 'seq_acc', 'gnn_nuc_acc', 'bilstm_nuc_acc', 'auc',
        'original_seq_acc', 'original_gnn_nuc_acc', 'original_bilstm_nuc_acc',
        'original_auc'
    ])
    while True:
        msg = q.get()
        print(msg)
        if isinstance(msg, str) and msg == 'kill':
            logger.close()
            print('%s ROC AUC: %.3f\u00B1%.3f' %
                  (TRAIN_RBP_ID, np.mean(all_auc), np.std(all_auc)))
            break
        elif isinstance(msg, str) and msg.startswith('worker'):
            process_id = int(msg.split('_')[-1])
            if process_id in registered_gpus:
                print(process_id, 'found, returning',
                      registered_gpus[process_id])
                q.put('master_%d_' % (process_id) +
                      registered_gpus[process_id])
            else:
                print(process_id, 'not found')
                all_registered_devices = list(registered_gpus.values())
                # multiset difference: duplicate entries in DEVICES remain
                # available until every copy has been handed out
                free_devices = list(
                    (Counter(DEVICES) - Counter(all_registered_devices)).elements())
                if len(free_devices) > 0:
                    _device = np.random.choice(free_devices)
                    print('free device', _device)
                    q.put('master_%d_' % (process_id) + _device)
                    registered_gpus[process_id] = _device
                else:
                    print('no free device!')
                    print(registered_gpus)
                    q.put('master_%d_/cpu:0' % (process_id))
        elif isinstance(msg, dict):
            logger.update_with_dict(msg)
            all_auc.append(msg['original_auc'])
        else:
            q.put(msg)
        time.sleep(np.random.rand() * 5)
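
# A minimal sketch (assumed, not part of the original source) of how Logger
# and the fold workers from the later examples could be wired together: a
# Manager queue shared by one Logger consumer and a pool of workers, with a
# final 'kill' message shutting the Logger down. NUM_FOLDS and the pool
# sizing are illustrative assumptions.
import multiprocessing as mp

if __name__ == '__main__':
    NUM_FOLDS = 10  # hypothetical fold count
    manager = mp.Manager()
    q = manager.Queue()  # Manager queues are picklable, so Pool can share them
    pool = mp.Pool(NUM_FOLDS + 1)  # one extra slot for the Logger consumer
    logger_job = pool.apply_async(Logger, (q,))
    fold_jobs = [pool.apply_async(run_one_rbp, (i, q)) for i in range(NUM_FOLDS)]
    for job in fold_jobs:
        job.get()   # block until every fold has reported its results
    q.put('kill')   # tell the Logger to write the AUC summary and exit
    logger_job.get()
    pool.close()
    pool.join()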
Example #2
    # posterior decoding without RNA regularity
    print('Test_recon_acc_no_reg',
          ret_dict['recon_acc_noreg'] / ret_dict['total'] * 100)
    print('Test_post_valid_no_reg',
          ret_dict['post_valid_noreg'] / ret_dict['total'] * 100)
    print('Test_post_fe_deviation_no_reg',
          ret_dict['post_fe_deviation_noreg'] / ret_dict['post_valid_noreg'])
    print(
        'Test_post_fe_deviation_len_normed_no_reg',
        ret_dict['post_fe_deviation_noreg_len_normed'] /
        ret_dict['post_valid_noreg'])
    # posterior decoding without RNA regularity, using greedy (deterministic)
    # decoding; the extra factor of 5 presumably rescales the single greedy
    # decode against the 5 stochastic samples counted in ret_dict['total']
    print('Test_recon_acc_no_reg_det',
          ret_dict['recon_acc_noreg_det'] / ret_dict['total'] * 100 * 5)
    print('Test_post_valid_no_reg_det',
          ret_dict['post_valid_noreg_det'] / ret_dict['total'] * 100 * 5)
    print(
        'Test_post_fe_deviation_no_reg_det',
        ret_dict['post_fe_deviation_noreg_det'] /
        ret_dict['post_valid_noreg_det'])
    print(
        'Test_post_fe_deviation_len_normed_no_reg_det',
        ret_dict['post_fe_deviation_noreg_det_len_normed'] /
        ret_dict['post_valid_noreg_det'])

    if mp_pool is not None:
        mp_pool.close()
        mp_pool.join()

    logger.close()
Example #3
def run_one_rbp(fold_idx, q):
    fold_output = os.path.join(output_dir, 'fold%d' % (fold_idx))
    os.makedirs(fold_output)

    # redirect this worker's stdout/stderr to a per-process log file
    outfile = open(os.path.join(fold_output, '%d.out' % os.getpid()), "w")
    sys.stdout = outfile
    sys.stderr = outfile

    import time
    # todo: replace _identity with pid and let logger check if pid still alive
    process_id = mp.current_process()._identity[0]
    # announce this worker to the Logger and wait for a device assignment
    print('sending process id', process_id)
    q.put('worker_%d' % (process_id))
    while True:
        msg = q.get()
        if isinstance(msg, str) and msg.startswith('master'):
            print('worker %d received' % (process_id), msg)
            if int(msg.split('_')[1]) == process_id:
                device = msg.split('_')[-1]
                print('Process', mp.current_process(), 'received', device)
                break
        q.put(msg)
        time.sleep(np.random.rand() * 2)

    print('training fold', fold_idx)
    train_idx, test_idx = dataset['splits'][fold_idx]
    model = JMRT(dataset['VOCAB_VEC'].shape[1], dataset['VOCAB_VEC'], device, **hp)

    train_data = [dataset['seq'][train_idx], dataset['segment_size'][train_idx], dataset['raw_seq'][train_idx]]
    model.fit(train_data, dataset['label'][train_idx], EPOCHS, BATCH_SIZE, fold_output, logging=True)

    test_data = [dataset['seq'][test_idx], dataset['segment_size'][test_idx], dataset['raw_seq'][test_idx]]
    cost, acc, auc = model.evaluate(test_data, dataset['label'][test_idx], BATCH_SIZE, random_crop=False)
    print('Evaluation (with masking) on modified held-out test set, acc: %s, auc: %.3f' % (acc, auc))

    original_test_data = [original_dataset['seq'][test_idx], original_dataset['segment_size'][test_idx],
                          original_dataset['raw_seq'][test_idx]]
    original_cost, original_acc, original_auc = model.evaluate(original_test_data, original_dataset['label'][test_idx],
                                                               BATCH_SIZE, random_crop=False)
    print('Evaluation (with masking) on original held-out test set, acc: %s, auc: %.3f' % (original_acc, original_auc))

    # get predictions
    logger = lib.logger.CSVLogger('predictions.csv', fold_output,
                                  ['id', 'label', 'pred_neg', 'pred_pos'])
    all_pos_preds = []
    all_idx = []
    for idx, (_id, _label, _pred) in enumerate(
            zip(original_dataset['id'][test_idx], original_dataset['label'][test_idx],
                model.predict(original_test_data, BATCH_SIZE))):
        logger.update_with_dict({
            'id': _id,
            'label': np.max(_label),
            'pred_neg': _pred[0],
            'pred_pos': _pred[1],
        })
        if np.max(_label) == 1:
            all_pos_preds.append(_pred[1])
            all_idx.append(idx)
    logger.close()

    # plot some motifs
    graph_dir = os.path.join(fold_output, 'integrated_gradients')
    os.makedirs(graph_dir, exist_ok=True)

    all_pos_preds = np.array(all_pos_preds)
    all_idx = np.array(all_idx)
    # top 10 strongly predicted examples, descending order
    idx = all_idx[np.argsort(all_pos_preds)[::-1][:min(10, len(all_pos_preds))]]

    model.integrated_gradients(model.indexing_iterable(original_test_data, idx),
                               original_dataset['label'][test_idx][idx],
                               original_dataset['id'][test_idx][idx], save_path=graph_dir)

    # integrated gradients for the common set of ids (ig_ids) shared across folds
    idx = []
    for i, _id in enumerate(dataset['id'][test_idx]):
        if _id in ig_ids:
            idx.append(i)

    common_graph_path = os.path.join(output_dir, 'common_integrated_gradients')
    os.makedirs(common_graph_path, exist_ok=True)

    model.integrated_gradients(model.indexing_iterable(original_test_data, idx),
                               original_dataset['label'][test_idx][idx],
                               original_dataset['id'][test_idx][idx], save_path=common_graph_path)

    # free the model and reset module-level plotting/logging state before this
    # worker process is reused for another fold
    model.delete()
    reload(lib.plot)
    reload(lib.logger)
    q.put({
        'fold': fold_idx,
        'seq_acc': acc[0],
        'nuc_acc': acc[1],
        'auc': auc,
        'original_seq_acc': original_acc[0],
        'original_nuc_acc': original_acc[1],
        'original_auc': original_auc
    })
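
# lib.logger.CSVLogger is not included in these snippets. Below is a minimal
# stand-in consistent with how it is called throughout -- an assumption, not
# the project's actual implementation: CSVLogger(filename, output_dir,
# fieldnames), update_with_dict(row), close().
import csv
import os

class CSVLogger:
    def __init__(self, filename, output_dir, fieldnames):
        self._fh = open(os.path.join(output_dir, filename), 'w', newline='')
        self._writer = csv.DictWriter(self._fh, fieldnames=fieldnames)
        self._writer.writeheader()

    def update_with_dict(self, row):
        # write each row immediately so partial runs still leave usable output
        self._writer.writerow(row)
        self._fh.flush()

    def close(self):
        self._fh.close()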
Example #4
    def fit(self, X, y, epochs, batch_size, output_dir, logging=False, epoch_to_start=0):
        checkpoints_dir = os.path.join(output_dir, 'checkpoints/')
        os.makedirs(checkpoints_dir, exist_ok=True)

        # split off a validation set: 10% of positives and 10% of negatives
        row_sum = np.array([np.sum(label) for label in y])
        pos_idx, neg_idx = np.where(row_sum > 0)[0], np.where(row_sum == 0)[0]

        dev_idx = np.array(list(np.random.choice(pos_idx, int(len(pos_idx) * 0.1), False)) + \
                           list(np.random.choice(neg_idx, int(len(neg_idx) * 0.1), False)))
        train_idx = np.delete(np.arange(len(y)), dev_idx)

        dev_data = self.indexing_iterable(X, dev_idx)
        dev_targets = y[dev_idx]
        X = self.indexing_iterable(X, train_idx)
        train_targets = y[train_idx]

        best_dev_cost = np.inf
        save_path = None  # set once the first checkpoint is saved
        lib.plot.set_output_dir(output_dir)
        if logging:
            logger = lib.logger.CSVLogger('run.csv', output_dir,
                                          ['epoch', 'cost', 'graph_cost', 'gnn_cost', 'bilstm_cost',
                                           'seq_acc', 'gnn_acc', 'bilstm_acc', 'auc',
                                           'dev_cost', 'dev_graph_cost', 'dev_gnn_cost', 'dev_bilstm_cost',
                                           'dev_seq_acc', 'dev_gnn_acc', 'dev_bilstm_acc', 'dev_auc'])

        train_generator = BackgroundGenerator(X, train_targets, batch_size, random_crop=False)
        val_generator = BackgroundGenerator(dev_data, dev_targets, batch_size)
        iters_per_epoch = train_generator.iters_per_epoch

        for epoch in range(epoch_to_start, epochs):

            prepro_time = 0.
            training_time = 0.
            for i in range(iters_per_epoch):
                prepro_start = time.time()
                _node_tensor, _mask_offset, all_adj_mat, _labels = train_generator.next()
                feed_dict = {
                    self.node_input_ph: _node_tensor,
                    self.adj_mat_ph: all_adj_mat,
                    self.labels: _labels,
                    self.mask_offset: _mask_offset,
                    self.global_step: i + epoch * iters_per_epoch,
                    self.hf_iters_per_epoch: iters_per_epoch // 2,
                    self.is_training_ph: True,
                }
                prepro_end = time.time()
                prepro_time += (prepro_end - prepro_start)
                self.sess.run(self.train_op, feed_dict)
                training_time += (time.time() - prepro_end)
            print('preprocessing time: %.4f, training time: %.4f' % (prepro_time / (i + 1), training_time / (i + 1)))
            train_cost, train_acc, train_auc = self.evaluate_with_generator(train_generator)
            lib.plot.plot('train_cost', train_cost[0])
            lib.plot.plot('train_graph_cost', train_cost[1])
            lib.plot.plot('train_gnn_cost', train_cost[2])
            lib.plot.plot('train_bilstm_cost', train_cost[3])
            lib.plot.plot('train_seq_acc', train_acc[0])
            lib.plot.plot('train_gnn_acc', train_acc[1])
            lib.plot.plot('train_bilstm_acc', train_acc[2])
            lib.plot.plot('train_auc', train_auc)

            dev_cost, dev_acc, dev_auc = self.evaluate_with_generator(val_generator)
            lib.plot.plot('dev_cost', dev_cost[0])
            lib.plot.plot('dev_graph_cost', dev_cost[1])
            lib.plot.plot('dev_gnn_cost', dev_cost[2])
            lib.plot.plot('dev_bilstm_cost', dev_cost[3])
            lib.plot.plot('dev_seq_acc', dev_acc[0])
            lib.plot.plot('dev_gnn_acc', dev_acc[1])
            lib.plot.plot('dev_bilstm_acc', dev_acc[2])
            lib.plot.plot('dev_auc', dev_auc)

            if logging:
                logger.update_with_dict({
                    'epoch': epoch, 'cost': train_cost[0], 'graph_cost': train_cost[1], 'gnn_cost': train_cost[2],
                    'bilstm_cost': train_cost[3], 'seq_acc': train_acc[0], 'gnn_acc': train_acc[1],
                    'bilstm_acc': train_acc[2], 'auc': train_auc,

                    'dev_cost': dev_cost[0], 'dev_graph_cost': dev_cost[1], 'dev_gnn_cost': dev_cost[2],
                    'dev_bilstm_cost': dev_cost[3], 'dev_seq_acc': dev_acc[0], 'dev_gnn_acc': dev_acc[1],
                    'dev_bilstm_acc': dev_acc[2], 'dev_auc': dev_auc,
                })

            lib.plot.flush()
            lib.plot.tick()

            if dev_cost[0] < best_dev_cost and epoch - epoch_to_start >= 10:  # unstable loss in the beginning
                best_dev_cost = dev_cost[0]
                save_path = self.saver.save(self.sess, checkpoints_dir, global_step=epoch)
                print('Validation sample cost improved. Saved to path %s\n' % (save_path), flush=True)
            else:
                print('\n', flush=True)

        if save_path is not None:
            print('Loading best weights %s' % (save_path), flush=True)
            self.saver.restore(self.sess, save_path)
        if logging:
            logger.close()
        # shut down the background generators: signal the producer threads,
        # drain one item each to unblock any put() on a full queue, then join
        train_generator.kill.set()
        val_generator.kill.set()
        train_generator.next()
        val_generator.next()
        train_generator.join()
        val_generator.join()
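
# BackgroundGenerator is not shown in these snippets either. This is a hedged
# sketch of the producer/consumer pattern the shutdown code above relies on:
# kill.set() plus one extra next() unblocks a producer stuck on a full queue
# so that join() can return. The batch payload in run() is a placeholder,
# not the original preprocessing.
import threading
import queue

class BackgroundGenerator(threading.Thread):
    def __init__(self, data, targets, batch_size, random_crop=False, max_prefetch=10):
        super().__init__(daemon=True)
        self.data, self.targets, self.batch_size = data, targets, batch_size
        self.iters_per_epoch = (len(targets) + batch_size - 1) // batch_size
        self.queue = queue.Queue(maxsize=max_prefetch)
        self.kill = threading.Event()
        self.start()

    def run(self):
        while not self.kill.is_set():
            for i in range(self.iters_per_epoch):
                if self.kill.is_set():
                    return
                sl = slice(i * self.batch_size, (i + 1) * self.batch_size)
                # placeholder payload; the real class yields preprocessed
                # (node_tensor, mask_offset, adj_mat, labels) batches
                self.queue.put(([d[sl] for d in self.data], self.targets[sl]))

    def next(self):
        return self.queue.get()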
Example #5
    def fit(self,
            X,
            y,
            epochs,
            batch_size,
            output_dir,
            logging=False,
            epoch_to_start=0,
            random_crop=False):
        checkpoints_dir = os.path.join(output_dir, 'checkpoints/')
        os.makedirs(checkpoints_dir, exist_ok=True)

        # split off a validation set: 10% of positives and 10% of negatives
        row_sum = np.array([np.sum(label) for label in y])
        pos_idx, neg_idx = np.where(row_sum > 0)[0], np.where(row_sum == 0)[0]
        dev_idx = np.array(list(np.random.choice(pos_idx, int(len(pos_idx) * 0.1), False)) + \
                           list(np.random.choice(neg_idx, int(len(neg_idx) * 0.1), False)))
        train_idx = np.delete(np.arange(len(y)), dev_idx)

        dev_data = self.indexing_iterable(X, dev_idx)
        dev_targets = y[dev_idx]
        X = self.indexing_iterable(X, train_idx)
        train_targets = y[train_idx]

        size_train = train_targets.shape[0]
        # ceiling division: a final partial batch gets its own iteration
        iters_per_epoch = (size_train + batch_size - 1) // batch_size
        best_dev_cost = np.inf
        save_path = None  # set once the first checkpoint is saved
        lib.plot.set_output_dir(output_dir)
        if logging:
            logger = lib.logger.CSVLogger('run.csv', output_dir, [
                'epoch', 'cost', 'graph_cost', 'nuc_cost', 'seq_acc',
                'nuc_acc', 'auc', 'dev_cost', 'dev_graph_cost', 'dev_nuc_cost',
                'dev_seq_acc', 'dev_nuc_acc', 'dev_auc'
            ])

        for epoch in range(epoch_to_start, epochs):

            permute = np.random.permutation(size_train)
            node_tensor, segment_length, raw_seq = self.indexing_iterable(
                X, permute)
            y = train_targets[permute]

            if random_crop:
                # augmentation
                node_tensor, segment_length, y = \
                    self.random_crop(node_tensor, raw_seq, y)

            prepro_time = 0.
            training_time = 0.
            for i in range(iters_per_epoch):
                prepro_start = time.time()
                batch = slice(i * batch_size, (i + 1) * batch_size)
                _node_tensor = node_tensor[batch]
                _segment = segment_length[batch]
                _labels = y[batch]

                # left-pad each label vector with zeros so the whole batch is
                # right-aligned to its longest sequence
                _max_len = max(_segment)
                _labels = np.array([
                    np.pad(label, [_max_len - len(label), 0], mode='constant')
                    for label in _labels
                ])

                feed_dict = {
                    self.node_input_ph: np.concatenate(_node_tensor, axis=0),
                    self.labels: _labels,
                    self.max_len: _max_len,
                    self.segment_length: _segment,
                    self.global_step: i,  # note: restarts at 0 every epoch, unlike the variant above
                    self.hf_iters_per_epoch: iters_per_epoch // 2,
                    self.is_training_ph: True
                }
                prepro_end = time.time()
                prepro_time += (prepro_end - prepro_start)
                self.sess.run(self.train_op, feed_dict)
                training_time += (time.time() - prepro_end)
            print('preprocessing time: %.4f, training time: %.4f' %
                  (prepro_time / (i + 1), training_time / (i + 1)))
            train_cost, train_acc, train_auc = self.evaluate(
                X, train_targets, batch_size)
            lib.plot.plot('train_cost', train_cost[0])
            lib.plot.plot('train_graph_cost', train_cost[1])
            lib.plot.plot('train_nuc_cost', train_cost[2])
            lib.plot.plot('train_seq_acc', train_acc[0])
            lib.plot.plot('train_nuc_acc', train_acc[1])
            lib.plot.plot('train_auc', train_auc)

            dev_cost, dev_acc, dev_auc = self.evaluate(dev_data, dev_targets,
                                                       batch_size)
            lib.plot.plot('dev_cost', dev_cost[0])
            lib.plot.plot('dev_graph_cost', dev_cost[1])
            lib.plot.plot('dev_nuc_cost', dev_cost[2])
            lib.plot.plot('dev_seq_acc', dev_acc[0])
            lib.plot.plot('dev_nuc_acc', dev_acc[1])
            lib.plot.plot('dev_auc', dev_auc)

            if logging:
                logger.update_with_dict({
                    'epoch': epoch,
                    'cost': train_cost[0],
                    'graph_cost': train_cost[1],
                    'nuc_cost': train_cost[2],
                    'seq_acc': train_acc[0],
                    'nuc_acc': train_acc[1],
                    'auc': train_auc,
                    'dev_cost': dev_cost[0],
                    'dev_graph_cost': dev_cost[1],
                    'dev_nuc_cost': dev_cost[2],
                    'dev_seq_acc': dev_acc[0],
                    'dev_nuc_acc': dev_acc[1],
                    'dev_auc': dev_auc,
                })

            lib.plot.flush()
            lib.plot.tick()

            # skip the first 10 epochs, where the loss is still unstable
            if dev_cost[0] < best_dev_cost and epoch - epoch_to_start >= 10:
                best_dev_cost = dev_cost[0]
                save_path = self.saver.save(self.sess,
                                            checkpoints_dir,
                                            global_step=epoch)
                print('Validation sample cost improved. Saved to path %s\n' %
                      (save_path),
                      flush=True)
            else:
                print('\n', flush=True)

        if save_path is not None:
            print('Loading best weights %s' % (save_path), flush=True)
            self.saver.restore(self.sess, save_path)
        if logging:
            logger.close()
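
# A quick, self-contained illustration (toy values, not project data) of the
# label padding used in fit above: each label vector is left-padded with
# zeros to the longest sequence in the batch, so shorter sequences end up
# right-aligned.
import numpy as np

_labels = [np.array([1, 0, 1]), np.array([0, 1])]
_max_len = max(len(label) for label in _labels)
padded = np.array([
    np.pad(label, [_max_len - len(label), 0], mode='constant')
    for label in _labels
])
print(padded)
# [[1 0 1]
#  [0 0 1]]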