Пример #1
0
def load_human_data(dataset_id):
  '''
    Load the 50-nt target sequences, observations and target names
    for one dataset.

    Names and observations come from the gzipped pickle
    '<OUT_PLACE>combin_data_Y_imputewt/<dataset_id>.pkl.gz'; sequences
    come from the library design table(s) matching the dataset.

    Returns (X, Y, NAMES), three parallel lists:
      X     -- 50-character sequence window for each name
      Y     -- values of the pickled mapping, in its iteration order
      NAMES -- keys of the pickled mapping, in the same order

    Raises KeyError if a pickled name is missing from the library
    design, and AssertionError if a window is not exactly 50 long.
  '''
  is_csnvl = 'CSNVL' in dataset_id

  # Name -> designed sequence lookup.
  if is_csnvl:
    # Use any conds to load 12kChar, CtoT, and AtoG libs
    nms, seqs = [], []
    for did in ['190418_mES_12kChar_AID', '190329_HEK293T_AtoG_ABE', '190307_HEK_CtoT_BE4']:
      lib_design, seq_col = _data.get_lib_design(did)
      nms.extend(lib_design['Name (unique)'])
      seqs.extend(lib_design[seq_col])
  else:
    # Result is not used below; the call is kept as-is.
    lib_nm = _data.get_lib_nm(dataset_id)
    lib_design, seq_col = _data.get_lib_design(dataset_id)
    nms = lib_design['Name (unique)']
    seqs = lib_design[seq_col]
  nm_to_seq = dict(zip(nms, seqs))

  # Load Y
  y_path = _config.OUT_PLACE + 'combin_data_Y_imputewt/' + '%s.pkl.gz' % (dataset_id)
  with gzip.open(y_path, 'rb') as f:
    y_by_name = pickle.load(f)
  NAMES = list(y_by_name.keys())
  Y = list(y_by_name.values())

  # Index of position 0 within each library sequence.
  if is_csnvl:
    # 'satmut' names use the 12kChar layout (21);
    # everything else uses CtoT = AtoG = 10.
    zero_idxs = [
      _data.zero_pos['12kChar'] if 'satmut' in nm else _data.zero_pos['CtoT']
      for nm in NAMES
    ]
  else:
    zero_idxs = [_data.pos_to_idx(0, dataset_id)] * len(NAMES)

  # Window is positions -19 .. +30 around position 0 (50 nt):
  # the 30-nt core (-9 .. +20) plus 10 nt of flank on each side.
  upstream = 9 + 10
  downstream = 20 + 10

  # Never updated in this function; constructed as before.
  timer = _util.Timer(total = len(NAMES))

  X = []
  for nm, zero_idx in zip(NAMES, zero_idxs):
    seq = nm_to_seq[nm]
    if zero_idx < upstream:
      # CtoT, AtoG libs: too little upstream sequence, so prepend
      # a fixed 19-nt prefix and shift the zero index accordingly.
      # (12kChar sequences already have enough and are used as-is.)
      prefix = 'GATGGGTGCGACGCGTCAT'
      seq = prefix + seq
      zero_idx += len(prefix)

    seq_50nt = seq[zero_idx - upstream : zero_idx + downstream + 1]
    assert len(seq_50nt) == 50
    X.append(seq_50nt)

  return X, Y, NAMES
Пример #2
0
  def to_autoregress_masked_tensors(self, Y):
    '''
      Convert per-target edit tables into masked-input and target
      tensors.

      Y is a sequence of DataFrames, one per target. The columns read
      for target i are self.all_nt_cols[i]; each column name is a
      reference nucleotide followed by an integer position (col[0] and
      int(col[1:]) are parsed from it).

      For every target a wild-type row (each column set to its own
      reference nucleotide) is appended after the observed rows, then
      each row is encoded with self.form_masked_edit_vectors and
      self.form_target_vectors.

      Returns (tensors_Y, y_mask_dim, tensors_target, editable_index_info):
        tensors_Y           -- list of tensors, one per target, with
                               shape (num. unique edits + 1,
                               num. editable bases, y_mask_dim)
        y_mask_dim          -- last dimension of tensors_Y[0]
        tensors_target      -- list of tensors built from
                               form_target_vectors, one per target
        editable_index_info -- per target, a dict with 'pos' and
                               'ref_nt', each mapping the rank of an
                               editable position (sorted ascending) to
                               that position / its reference nucleotide

      Raises IndexError if Y is empty.
    '''
    tensors_Y = []
    tensors_target = []
    editable_index_info = []
    print('Transforming Y into tensors...')
    timer = _util.Timer(total = len(Y))
    for idx, y in enumerate(Y):
      nt_cols = self.all_nt_cols[idx]
      pos_to_col = {int(col[1:]): col for col in nt_cols}
      editable_pos = sorted(pos_to_col)

      editable_index_info.append({
        'pos': dict(enumerate(editable_pos)),
        'ref_nt': {
          rank: pos_to_col[pos][0] for rank, pos in enumerate(editable_pos)
        },
      })

      # Append wild-type row.
      # DataFrame.append was removed in pandas 2.0; pd.concat with the
      # same ignore_index/sort arguments is the equivalent call and
      # likewise leaves the caller's DataFrame untouched.
      wt_row = pd.DataFrame({col: col[0] for col in nt_cols}, index = [0])
      y = pd.concat([y, wt_row], ignore_index = True, sort = False)

      masked_rows = []
      target_rows = []
      for _, row in y.iterrows():
        col_to_obs_edit = {col: row[col] for col in nt_cols}
        masked_rows.append(self.form_masked_edit_vectors(
          editable_pos,
          pos_to_col,
          col_to_obs_edit,
        ))
        target_rows.append(self.form_target_vectors(
          editable_pos,
          pos_to_col,
          col_to_obs_edit,
        ))

      # masked_rows.shape = (
      #   num. unique edits + 1,
      #   num. editable bases,
      #   y_mask_dim
      # )
      tensors_Y.append(torch.Tensor(masked_rows))
      tensors_target.append(torch.Tensor(target_rows))
      timer.update()

    y_mask_dim = tensors_Y[0].shape[-1]
    return tensors_Y, y_mask_dim, tensors_target, editable_index_info
Пример #3
0
    def run_suite(self, test_suite):
        '''
        Execute every test case of ``test_suite`` and return the
        outcome computed for the suite as a whole.

        Steps:

        1. Notify each result logger that the suite is starting.
        2. Run the test cases one at a time, handing each the suite's
           fixtures and collecting the outcomes. Fixtures are not set up
           here, since a test case might override them.
        3. After a fail-fast outcome, either stop (global ``--fail-fast``
           flag) or generate skips for the remaining tests of the suite
           or of the current test list.
        4. Tear down all suite fixtures, log the suite outcome with its
           runtime and close the loggers.
        '''
        for logger in self.result_loggers:
            logger.begin(test_suite)

        # One shared iterator: the skip handling below drains it, which
        # is what makes the main loop stop or jump ahead afterwards.
        suite_iterator = enumerate(test_suite.iter_testlists())
        outcomes = set()

        suite_timer = _util.Timer()
        suite_timer.start()
        for idx, (testlist, testcase) in suite_iterator:
            assert isinstance(testcase, TestCase)
            outcome = self.run_test(testcase, fixtures=test_suite.fixtures)
            outcomes.add(outcome)

            # Nothing more to do unless there is a chance we need to skip
            # the remaining tests.
            if outcome not in Outcome.failfast or idx >= len(test_suite):
                continue

            if config.fail_fast:
                log.bold('Test failed with the --fail-fast flag provided.')
                log.bold('Ignoring remaining tests.')
                break

            if test_suite.fail_fast:
                log.bold('Test failed in a fail_fast TestSuite. Skipping'
                         ' remaining tests.')
                # Everything still pending in the suite gets skipped.
                pending = (remaining
                           for _, (_, remaining) in suite_iterator)
                self._generate_skips(testcase.name, pending)
            elif testlist.fail_fast:
                log.bold('Test failed in a fail_fast TestList. Skipping'
                         ' its remaining items.')
                # Only the rest of the current testlist gets skipped.
                pending = self._remaining_testlist_tests(
                    testcase, testlist, suite_iterator)
                self._generate_skips(testcase.name, pending)

        for fixture in test_suite.fixtures.values():
            fixture.teardown()
        suite_timer.stop()

        suite_outcome = self._suite_outcome(outcomes)
        self._log_outcome(suite_outcome, runtime=suite_timer.runtime())
        for logger in self.result_loggers:
            logger.end_current()

        return suite_outcome
Пример #4
0
 def get_obs_freqs(self, Y):
   '''
     Extract the 'Y' entry of every element of Y as a tensor.

     Returns a list with one tensor per element of Y, each with
     shape: (
       num. unique obs. edits,
     )
   '''
   print('Getting obs freqs...')
   timer = _util.Timer(total = len(Y))
   obs_freqs = []
   for y in Y:
     obs_freqs.append(torch.Tensor(list(y['Y'])))
     timer.update()
   return obs_freqs
Пример #5
0
  def featurize(self, X):
    '''
      Build one feature tensor per target sequence.

      Each x provided is 50-nt, but we care only about the center
      30-nt, which ranges from positions -9 to 20 relative to gRNA.
      One feature row is produced for every editable position listed
      in self.all_nt_cols[i] (column names are parsed as
      int(col[1:])), in ascending position order.

      A row concatenates, depending on the module-level
      hyperparameters:
        'context_feature'     -- one-hot encoding of the window of
                                 'context_radii' nt on either side of
                                 the position
        'fullcontext_feature' -- one-hot encoding of the center 30-nt
        'position_feature'    -- one-hot encoding of the position
                                 within -9 .. 20

      Returns (ftx, x_dim): ftx is a list of tensors with shape
      (num. editable bases, x_dim); x_dim is the last dimension of
      ftx[0]. Raises IndexError if X is empty.
    '''
    ftx = []
    offset = 10
    print('Featurizing X')
    timer = _util.Timer(total = len(X))
    for target_idx, seq in enumerate(X):
      nt_cols = self.all_nt_cols[target_idx]
      editable_pos = sorted(int(col[1:]) for col in nt_cols)
      seq_30nt = seq[10:-10]

      # TO DO: subset to only editable positions
      single_target_ftx = []
      for pos in editable_pos:
        assert len(seq) == 30 + offset * 2, 'Bad offset'

        # Position -9 is index 0 of the 30-nt core; add the flank
        # offset to index into the full 50-nt sequence.
        center = offset + (pos + 9)

        features = []
        if hyperparameters['context_feature'] == True:
          radii = hyperparameters['context_radii']
          features += self.ohe_seq(seq[center - radii : center + radii + 1])

        if hyperparameters['fullcontext_feature'] == True:
          features += self.ohe_seq(seq_30nt)

        if hyperparameters['position_feature'] == True:
          features += self.ohe_position(pos, -9, 20)

        single_target_ftx.append(features)

      # single_target_ftx.shape = (num. editable bases, x_dim)
      ftx.append(torch.Tensor(single_target_ftx))
      timer.update()

    x_dim = ftx[0].shape[-1]
    return ftx, x_dim
Пример #6
0
    def _run_test(self, testobj, fstdout_name, fstderr_name, fixtures):
        '''
        Run a single test object and return its outcome.

        :param testobj: Callable test; invoked as
            ``testobj(fixtures=...)``. Its ``fixtures`` mapping overrides
            entries of ``fixtures`` with the same key, and its ``name``
            is used for logging.
        :param fstdout_name: Passed through to ``_log_outcome``.
        :param fstderr_name: Passed through to ``_log_outcome``.
        :param fixtures: Mapping of suite level fixtures, or ``None``.
            It is copied, never modified.

        :returns: ``Outcome.ERROR`` if any fixture failed to build
            (the test is then not run); otherwise ``Outcome.PASS``,
            ``Outcome.SKIP`` or ``Outcome.FAIL`` depending on what the
            test raised.

        Only the fixtures declared on ``testobj`` are torn down here.
        '''
        if fixtures is None:
            fixtures = {}

        # We'll use a local shallow copy of fixtures to make it easier to
        # cleanup and override suite level fixtures with testcase level ones.
        fixtures = fixtures.copy()
        fixtures.update(testobj.fixtures)

        test_timer = _util.Timer()
        test_timer.start()

        for logger in self.result_loggers:
            logger.begin(testobj)

        def _exc_message(exc):
            # Built-in exceptions have no ``message`` attribute on
            # Python 3; reading it unconditionally raises AttributeError
            # from inside the handler and hides the real failure. Prefer
            # the attribute when it is set, else fall back to str().
            return getattr(exc, 'message', None) or str(exc)

        def _run_test():
            reason = None
            try:
                testobj(fixtures=fixtures)
            except AssertionError as e:
                # A bare ``assert`` carries no text; show the traceback.
                reason = _exc_message(e) or traceback.format_exc()
                outcome = Outcome.FAIL
            except test.TestSkipException as e:
                reason = _exc_message(e)
                outcome = Outcome.SKIP
            except test.TestFailException as e:
                reason = _exc_message(e)
                outcome = Outcome.FAIL
            except Exception:
                # Any other exception counts as a failure, with the full
                # traceback as the reason.
                reason = traceback.format_exc()
                outcome = Outcome.FAIL
            else:
                outcome = Outcome.PASS

            return (outcome, reason)

        # Build any fixtures that haven't been built yet.
        log.debug('Building fixtures for TestCase: %s' % testobj.name)
        failed_builds = self.setup_unbuilt(fixtures.values(),
                                           setup_lazy_init=True)

        if failed_builds:
            reason = ''.join(
                'Failed to build %s\n%s' % (fixture, error)
                for fixture, error in failed_builds)
            outcome = Outcome.ERROR
        else:
            (outcome, reason) = _run_test()

        for fixture in testobj.fixtures.values():
            fixture.teardown()

        test_timer.stop()
        self._log_outcome(outcome,
                          reason=reason,
                          runtime=test_timer.runtime(),
                          fstdout_name=fstdout_name,
                          fstderr_name=fstderr_name)

        for logger in self.result_loggers:
            logger.end_current()

        return outcome