示例#1
0
def run_tumor_normal_classification(cancertype,
                                    how_many_training_steps=2000,
                                    dropout_keep_prob=0.8,
                                    label_names=None,
                                    optimizer='adam',
                                    is_weighted=0,
                                    nClass=2,
                                    treat_validation_as_test=True,
                                    do_not_train=True,
                                    avoid_gpu_for_testing=True,
                                    train_test_percentage=None):
    """Fetch one cancer type's TFRecords, run the Inception CNN and upload results.

    Copies the TFRecord shards and the tile annotation file for ``cancertype``
    from ``gs://<input_bucket>/`` into ``/sdata/``, calls
    ``run_classification.run_multilabel_classification_with_inception_CNN``,
    pickles its outputs together with the annotations, and copies the saved
    model, TensorBoard logs and pickle back to the input bucket.

    Relies on the module-level names ``annotations_path``,
    ``pancancer_tfrecords_path``, ``results_path``, ``input_bucket`` and
    ``payer_project_id``.

    Args:
        cancertype: Sub-directory name used under the annotation, TFRecord and
            result paths.
        how_many_training_steps: Forwarded to the CNN runner.
        dropout_keep_prob: Forwarded to the CNN runner.
        label_names: Annotation columns used as labels; defaults to
            ``['is_tumor']``. The first one determines ``class_probs``.
        optimizer: Forwarded to the CNN runner.
        is_weighted: If truthy, ``pos_weight`` is ``1 / mean(label) - 1`` per
            label (computed over every row of the annotation file);
            otherwise ``pos_weight`` is 1.
        nClass: Number of classes, forwarded to the CNN runner.
        treat_validation_as_test: Merge the 'validation' group (both the
            annotation rows and the TFRecord shards) into 'testing'.
        do_not_train: Forwarded to the CNN runner.
        avoid_gpu_for_testing: Forwarded to the CNN runner.
        train_test_percentage: Two ints used only to name the pickle
            directory; defaults to ``[70, 30]``.

    Raises:
        FileNotFoundError: If no TFRecord file exists locally after the copy.
    """
    # None sentinels avoid sharing one mutable list default across calls.
    if label_names is None:
        label_names = ['is_tumor']
    if train_test_percentage is None:
        train_test_percentage = [70, 30]

    image_file_metadata_filename = '{:s}/{:s}/caches_basic_annotations.txt'.format(
        annotations_path, cancertype)
    tfrecords_path = os.path.join(pancancer_tfrecords_path, cancertype, '')
    print('copying files from GCS')
    input_bucket_path = 'gs://' + input_bucket + '/'
    util.gsutil_cp(os.path.join(input_bucket_path, tfrecords_path, 'tfrecord*'),
                   '/sdata/' + tfrecords_path,
                   make_dir=True,
                   payer_project_id=payer_project_id)
    util.gsutil_cp(os.path.join(input_bucket_path, image_file_metadata_filename),
                   '/sdata/' + image_file_metadata_filename,
                   make_dir=False,
                   payer_project_id=payer_project_id)

    # Output paths (relative to both /sdata and the bucket root).
    trecords_prefix = '/sdata/' + tfrecords_path + 'tfrecord'
    saved_model_path = os.path.join(results_path,
                                    'saved_models/{:s}'.format(cancertype))
    tensorboard_path = os.path.join(results_path,
                                    'tensorboard_logs/{:s}'.format(cancertype))
    pickle_path = os.path.join(
        results_path,
        'pickles/pickles_train{:d}_test{:d}/run_cnn_output_{:s}.pkl'.format(
            *train_test_percentage, cancertype))

    tfrecordfiles = glob.glob('{:s}*'.format(trecords_prefix))
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not tfrecordfiles:
        raise FileNotFoundError(
            'No TFRecord files match {:s}*'.format(trecords_prefix))
    # Assumes each shard index n has exactly one file per crossval group
    # (training/testing/validation), hence the division by 3.
    num_tfrecords = len(tfrecordfiles) // 3

    tfrecordfiles_dict = {
        group: ['{:s}{:d}.{:s}'.format(trecords_prefix, n, group)
                for n in range(num_tfrecords)]
        for group in ['training', 'testing', 'validation']
    }
    image_files_metadata = pd.read_csv('/sdata/' + image_file_metadata_filename,
                                       index_col=0)

    if treat_validation_as_test:
        # Assign the result back: a chained `df[col].replace(..., inplace=True)`
        # does not modify `df` under pandas Copy-on-Write.
        image_files_metadata['crossval_group'] = (
            image_files_metadata['crossval_group'].replace('validation',
                                                           'testing'))
        tfrecordfiles_dict['testing'] = (tfrecordfiles_dict['testing'] +
                                         tfrecordfiles_dict.pop('validation'))

    # Test batch size equals the number of test rows (presumably so the whole
    # test split is evaluated in one batch).
    test_batch_size = (
        image_files_metadata['crossval_group'] == 'testing').sum()

    if is_weighted:
        label_ratio = image_files_metadata[label_names].mean()
        pos_weight = (1 / label_ratio - 1).tolist()
    else:
        pos_weight = 1

    # Relative class frequencies of the first label in the training split,
    # ordered by label value.
    class_probs = image_files_metadata.loc[
        image_files_metadata['crossval_group'] == 'training',
        label_names[0]].value_counts(normalize=True,
                                     sort=False).sort_index().values

    (test_accuracies_list, predictions_list, confusion_matrices_list,
     imagefilenames, final_softmax_outputs_list) = (
        run_classification.run_multilabel_classification_with_inception_CNN(
            label_names,
            tfrecordfiles_dict,
            test_batch_size=test_batch_size,
            nClass=nClass,
            train_batch_size=512,
            how_many_training_steps=how_many_training_steps,
            avoid_gpu_for_testing=avoid_gpu_for_testing,
            do_not_train=do_not_train,
            pos_weight=pos_weight,
            dropout_keep_prob=dropout_keep_prob,
            saved_model_path=os.path.join('/sdata', saved_model_path,
                                          'mychckpt'),
            summaries_dir='/sdata/' + tensorboard_path,
            optimizer=optimizer,
            class_probs=class_probs))

    util.mkdir_if_not_exist(os.path.dirname('/sdata/' + pickle_path))

    # Close (and flush) the pickle deterministically before uploading it.
    with open('/sdata/' + pickle_path, 'wb') as pickle_file:
        pickle.dump([
            image_files_metadata, test_accuracies_list, predictions_list,
            confusion_matrices_list, imagefilenames, final_softmax_outputs_list
        ], pickle_file)

    # Upload the saved model, TensorBoard logs and pickle, in that order.
    for rel_path in (saved_model_path, tensorboard_path, pickle_path):
        util.gsutil_cp(os.path.join('/sdata', rel_path),
                       os.path.join(input_bucket_path, rel_path),
                       payer_project_id=payer_project_id)
示例#2
0
def worker(msg):
    """Build TFRecords for one shard of a cancer type and acknowledge the message.

    ``msg.message.data`` holds a datastore task id. The task entity provides
    ``cancertype``, ``category`` (one of 'training', 'testing', 'validation'
    or 'all'), ``shard_length``, ``shard_index`` and ``gcs_output_path``.
    The shard's rows of the annotation file are loaded, their cache values are
    downloaded, TFRecords are written to ``tfrecords_<cancertype>/`` and
    copied to ``gcs_output_path``, local files are removed, the task is marked
    done and the Pub/Sub message is acknowledged.

    Relies on the module-level names ``project_id``, ``task_kind``,
    ``gcs_ann_path``, ``tiles_input_bucket``, ``payer_project_id`` and
    ``subscription_name``.

    Raises:
        Exception: If ``category`` is not a known cross-validation group.
    """
    import shutil
    import subprocess

    start_time = time.time()
    print(msg.message.data)

    task_id = int(msg.message.data)
    client = datastore.Client(project_id)
    key = client.key(task_kind, task_id)
    params = client.get(key)

    # Setting the status to 'InProgress'
    mark_in_progress(client, task_id)

    cancertype = params['cancertype']
    category = params['category']
    shard_length = int(params['shard_length'])
    shard_index = int(params['shard_index'])
    gcs_output_path = params['gcs_output_path']

    # Validate the category before the expensive metadata/cache downloads.
    crossval_groups = ['training', 'testing', 'validation']
    if category not in crossval_groups + ['all']:
        raise Exception('Unknown cross validation category.')
    if category != 'all':  # keyword 'all' loops through all three categories
        crossval_groups = [category]

    print('Loading metadata...')
    image_file_metadata_filename = 'data/caches_basic_annotations.txt'
    util.gsutil_cp('{}/{}/caches_basic_annotations.txt'.format(gcs_ann_path, cancertype),
                   'data/', make_dir=True, payer_project_id=payer_project_id)
    # Keep the header (row 0), skip the data rows of earlier shards and read
    # at most shard_length rows.
    image_files_metadata = pd.read_csv(
        image_file_metadata_filename,
        skiprows=range(1, shard_index * shard_length + 1),
        nrows=shard_length)

    shard_length_tiles = len(image_files_metadata.index)

    label_names = ['cnv']

    print('Downloading cache files...')
    image_files_metadata['cache_values'] = choose_input_list.load_cache_values(
        image_files_metadata,
        bucket_name=tiles_input_bucket,
        notebook=False,
        user_project=payer_project_id)

    tfrecords_folder = 'tfrecords_{}'.format(cancertype)
    util.mkdir_if_not_exist(tfrecords_folder)

    # Create tfrecords for each requested category.
    for group in crossval_groups:
        print('Creating TFRecord for {:s}...'.format(group))
        handle_tfrecords.create_tfrecords_per_category_for_caches(
            image_files_metadata, label_names, group,
            tfrecord_file_name_prefix=tfrecords_folder + '/tfrecord{:d}'.format(shard_index))

    tfrecords_bucket = re.search('gs://(.+?)/', gcs_output_path).group(1)
    prefix = 'gs://' + tfrecords_bucket + '/'
    # Remove only the first "gs://<bucket>/" (where re.search found it), not
    # every occurrence of it in the path.
    gcs_directory = gcs_output_path.replace(prefix, '', 1)

    bucket = handle_google_cloud_apis.gcsbucket(project_name=project_id, bucket_name=tfrecords_bucket)
    bucket.copy_files_to_gcs(tfrecords_folder, gcs_directory, verbose=True)

    # `du -s` prints "<size>\t<path>" with size in du's block units (1024
    # bytes by default for GNU du); /1000 gives the reported MB figure.
    # Argument list (no shell) because the folder name comes from task data.
    du_output = subprocess.run(['du', '-s', tfrecords_folder + '/'],
                               stdout=subprocess.PIPE,
                               universal_newlines=True,
                               check=True).stdout
    tfrecord_size_MBi = round(int(du_output.split()[0]) / 1000, 1)  # in MB

    # Remove local files synchronously; ignore_errors mirrors `rm -rf`.
    shutil.rmtree(tfrecords_folder, ignore_errors=True)
    shutil.rmtree('tcga_tiles', ignore_errors=True)

    elapsed_time_s = round((time.time() - start_time), 1)  # in seconds

    completed_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

    # We now can confirm the job
    client = datastore.Client(project_id)
    mark_done(client=client, task_id=task_id, completed_time=completed_time,
              elapsed_time_s=elapsed_time_s, shard_length_tiles=shard_length_tiles,
              tfrecord_size_MBi=tfrecord_size_MBi)

    print('Finish Timestamp: {} - Time elapsed: {} seconds.'.format(completed_time, elapsed_time_s))

    subscriber = pubsub.SubscriberClient()
    subscription_path = subscriber.subscription_path(project_id, subscription_name)

    # Acknowledging the message
    subscriber.acknowledge(subscription_path, [msg.ack_id])
    print("{}: Acknowledged {}".format(
        time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), msg.message.data))
示例#3
0
def worker(msg):
    """Cache the tiles of one slide directory and acknowledge the Pub/Sub message.

    ``msg.message.data`` holds a datastore task id whose entity provides
    ``svs_path`` (a ``gs://<bucket>/...`` directory). Tiles matching
    ``<svs_path>/tiles/tile_*.jpg`` are copied to the mirrored ``/sdata/...``
    location, cached with ``run_classification.cache_the_uncacheded``, and the
    cache files are copied back to the bucket. The task is then marked done
    and the message acknowledged.

    Relies on the module-level names ``project_id``, ``task_kind`` and
    ``subscription_name``.
    """
    start_time = time.time()
    print(msg.message.data)

    task_id = int(msg.message.data)
    client = datastore.Client(project_id)
    key = client.key(task_kind, task_id)
    params = client.get(key)

    # Setting the status to 'InProgress'
    mark_in_progress(client, task_id)
    svs_path = params['svs_path']

    # Strip the scheme as a *prefix*. `str.lstrip('gs://')` strips any leading
    # 'g', 's', ':' or '/' characters and mangles buckets such as "sdata-x".
    scheme = 'gs://'
    if svs_path.startswith(scheme):
        path_without_scheme = svs_path[len(scheme):]
    else:
        path_without_scheme = svs_path
    bucket_name = path_without_scheme.split('/')[0]
    bucket_path = 'gs://' + bucket_name
    input_tiles_path = os.path.join(svs_path, 'tiles/tile_*.jpg')
    # Mirror the bucket layout under /sdata. Plain first-occurrence string
    # replacement instead of re.sub, which treats '.' in bucket names as a
    # regex wildcard and replaces every match.
    local_tiles_path = os.path.join(svs_path.replace(bucket_path, '/sdata', 1),
                                    'tiles/')
    local_tiles_glob_path = os.path.join(local_tiles_path, 'tile_*.jpg')
    # Cache directory: drop the trailing 'tiles' component and suffix the
    # grandparent directory with '_cache',
    # e.g. /sdata/A/B/tiles/ -> /sdata/A_cache/B.
    parts = local_tiles_path.rstrip('/').split('/')
    parts.pop(-1)
    parts[-2] += '_cache'
    local_cache_path = '/'.join(parts)
    output_cache_path = local_cache_path.replace('/sdata', bucket_path, 1)

    print('copying files from GCS')
    util.gsutil_cp(input_tiles_path, local_tiles_path, make_dir=True)

    caches_metadata = pd.DataFrame(glob.glob(local_tiles_glob_path),
                                   columns=['image_filename'])

    def convert_to_cache_path(x, local_cache_path):
        # <local_cache_path>/<tile file name>_cached.txt
        return os.path.join(local_cache_path, x.split('/')[-1] + '_cached.txt')

    caches_metadata['rel_path'] = caches_metadata['image_filename'].map(
        lambda x: convert_to_cache_path(x, local_cache_path))

    print('cache the unchached...')
    graphs_folder = util.DATA_PATH + 'graphs/'
    run_classification.cache_the_uncacheded(caches_metadata,
                                            model_dir=graphs_folder,
                                            use_tqdm_notebook_widget=False)
    print('Finished caching {:s}'.format(svs_path))

    print('Copying files from disk to gcs...')
    util.gsutil_cp(os.path.join(local_cache_path, '*'), output_cache_path)

    # Calculate elapsed time
    elapsed_time_s = round((time.time() - start_time), 1)  # in seconds
    completed_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    client = datastore.Client(project_id)
    mark_done(client, task_id, completed_time, elapsed_time_s)

    print('Completed caching of SVS file: {}'.format(svs_path))
    print('Finish Timestamp: {} - Time elapsed: {} seconds.'.format(
        completed_time, elapsed_time_s))

    subscriber = pubsub.SubscriberClient()
    subscription_path = subscriber.subscription_path(project_id,
                                                     subscription_name)

    # Acknowledging the message
    subscriber.acknowledge(subscription_path, [msg.ack_id])
    print("{}: Acknowledged {}".format(
        time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
        msg.message.data))
示例#4
0
def cross_classify(cancertype1,
                   cancertype2,
                   include_training_set=True,
                   train_test_percentage=[70, 30],
                   label_terms=['normal', 'tumor'],
                   labal_names=['cnv'],
                   nClass=2):
    """Evaluate the model saved for ``cancertype1`` on the tiles of ``cancertype2``.

    Downloads cancertype1's saved model and cancertype2's testing/validation
    (and optionally training) TFRecords plus annotation file into ``/sdata``,
    runs ``run_classification.test_multilabel_classification_with_inception_CNN_fast``,
    computes per-slide and per-tile ROC AUC, writes them to
    ``roc_auc_<t1>_<t2>.json`` (uploaded to ``gcs_output_path``) and pickles
    the raw outputs under ``results_path`` (uploaded to the input bucket).

    Relies on the module-level names ``results_path``,
    ``pancancer_tfrecords_path``, ``annotations_path``, ``input_bucket``,
    ``payer_project_id`` and ``gcs_output_path``. The list defaults are only
    read, never mutated.

    Args:
        cancertype1: Cancer type whose saved model is loaded.
        cancertype2: Cancer type whose TFRecords/annotations are evaluated.
        include_training_set: Also evaluate cancertype2's training shards.
            Forced to False when both cancer types are equal (presumably
            because those tiles were used to train the model).
        train_test_percentage: Two ints used only to name the pickle directory.
        label_terms: Names indexed by the integer 'cnv' value, stored in the
            'label_name' column.
        labal_names: Label names forwarded to the test runner (the spelling is
            kept for backward compatibility with keyword callers).
        nClass: Number of classes, forwarded to the test runner.
    """
    if cancertype1 == cancertype2:
        include_training_set = False
    print('cross classify {:s} and {:s} ...'.format(cancertype1, cancertype2))

    saved_model_path = os.path.join('/sdata', results_path, 'saved_models',
                                    cancertype1, '')
    tfrecpath = os.path.join('/sdata', pancancer_tfrecords_path, cancertype2,
                             '')
    image_file_metadata_filename = os.path.join(
        '/sdata', annotations_path, cancertype2,
        'caches_basic_annotations.txt')

    print('copying files from GCS')
    # Each local /sdata/<rel> path mirrors gs://<input_bucket>/<rel>.
    gcs_root = 'gs://' + input_bucket + '/'
    sdata_prefix_len = len('/sdata/')
    downloads = [(saved_model_path, '*'),
                 (tfrecpath, '*.testing'),
                 (tfrecpath, '*.validation')]
    if include_training_set:
        downloads.append((tfrecpath, '*.training'))
    for local_path, pattern in downloads:
        util.gsutil_cp(gcs_root + local_path[sdata_prefix_len:] + pattern,
                       local_path,
                       make_dir=True,
                       payer_project_id=payer_project_id)
    util.gsutil_cp(gcs_root + image_file_metadata_filename[sdata_prefix_len:],
                   image_file_metadata_filename,
                   make_dir=False,
                   payer_project_id=payer_project_id)

    image_files_metadata = pd.read_csv(image_file_metadata_filename, sep=',')
    image_files_metadata.rename(columns={'cnv': 'label'}, inplace=True)
    image_files_metadata['label_name'] = image_files_metadata['label'].map(
        lambda x: label_terms[x])

    tfrecordfileslist = (glob.glob(tfrecpath + '*.testing') +
                         glob.glob(tfrecpath + '*.validation'))
    if include_training_set:
        tfrecordfileslist += glob.glob(tfrecpath + '*.training')
    else:
        # Keep the metadata consistent with the TFRecords actually evaluated.
        image_files_metadata = image_files_metadata[
            image_files_metadata['crossval_group'] != 'training']

    # One test batch covering every evaluated tile.
    test_batch_size = len(image_files_metadata)

    print('running the CNN')

    (test_accuracies_list, predictions_list, confusion_matrices_list,
     imagefilenames, final_softmax_outputs_list) = (
        run_classification.test_multilabel_classification_with_inception_CNN_fast(
            image_files_metadata,
            labal_names,
            tfrecordfileslist=tfrecordfileslist,
            saved_model_path=saved_model_path,
            nClass=nClass,
            test_batch_size=test_batch_size))

    print('caclulating the AUC')
    votes, predictions_df = plotting_cnn.get_per_slide_average_predictions(
        image_files_metadata, imagefilenames, predictions_list, ['label'])

    roc_auc = {}
    roc_auc['perslide'] = plotting_cnn.plot_perslide_roc(predictions_df,
                                                         plot_results=False)

    roc_auc['pertile'] = plotting_cnn.plot_pertile_roc(
        imagefilenames,
        predictions_list,
        final_softmax_outputs_list,
        image_files_metadata,
        plot_results=False)

    print('storing the output')
    jsonfile = 'roc_auc_{:s}_{:s}.json'.format(cancertype1, cancertype2)
    # Close (and flush) the JSON file before uploading it.
    with open(jsonfile, 'w') as json_file:
        json.dump(roc_auc, json_file)

    util.gsutil_cp(jsonfile,
                   gcs_output_path,
                   payer_project_id=payer_project_id)

    # save output to pickle file
    pickle_dir = os.path.join(
        results_path, 'pickles/pickles_train{:d}_test{:d}_pickled/'.format(
            *train_test_percentage))
    pickle_path = os.path.join(
        pickle_dir,
        'run_cnn_output_{:s}_{:s}.pkl'.format(cancertype1, cancertype2))

    util.mkdir_if_not_exist('/sdata/' + pickle_dir)
    # Close (and flush) the pickle before uploading it.
    with open('/sdata/' + pickle_path, 'wb') as pickle_file:
        pickle.dump([
            image_files_metadata, test_accuracies_list, predictions_list,
            confusion_matrices_list, imagefilenames, final_softmax_outputs_list
        ], pickle_file)

    util.gsutil_cp(os.path.join('/sdata', pickle_path),
                   os.path.join('gs://' + input_bucket, pickle_path),
                   payer_project_id=payer_project_id)