예제 #1
0
 def unique_beams(self):
   ''' Return the distinct beam models, in the order they are first seen.

   For an ImageSweep, get_beam() is called once; for any other imageset,
   get_beam(i) is called for every image index.
   '''
   from dxtbx.imageset import ImageSweep
   from libtbx.containers import OrderedDict
   distinct = OrderedDict()
   for imageset in self._imagesets:
     if isinstance(imageset, ImageSweep):
       beams = [imageset.get_beam()]
     else:
       beams = [imageset.get_beam(index) for index in xrange(len(imageset))]
     # an ordered dict doubles as an insertion-ordered set
     for beam in beams:
       distinct[beam] = None
   return distinct.keys()
예제 #2
0
 def unique_beams(self):
   ''' Return the distinct beam models as a list, in first-seen order.

   For an ImageSweep, get_beam() is called once; for any other imageset,
   get_beam(i) is called for every image index.
   '''
   from dxtbx.imageset import ImageSweep
   from libtbx.containers import OrderedDict
   obj = OrderedDict()
   for iset in self._imagesets:
     if isinstance(iset, ImageSweep):
       obj[iset.get_beam()] = None
     else:
       for i in range(len(iset)):
         obj[iset.get_beam(i)] = None
   # On Python 3 keys() is a view that cannot be indexed; return a real list
   # so the result has the same type as on Python 2.
   return list(obj.keys())
예제 #3
0
    class image_data_cache(object):
      """A bounded FIFO cache of raw image data read from an imageset.

      Data is fetched with ``imageset.get_raw_data(i)`` on first access.
      When the cache is full, the earliest-inserted entry is evicted (cache
      hits do not refresh an entry). A ``size`` of zero or less disables
      caching.
      """

      def __init__(self, imageset, size=10):
        self.imageset = imageset
        self.size = size
        # image index -> raw data, in insertion order
        self._image_data = OrderedDict()

      def __getitem__(self, i):
        if i in self._image_data:
          return self._image_data[i]
        image_data = self.imageset.get_raw_data(i)
        if self.size > 0:
          # remove the oldest entries in the cache; next(iter(...)) works on
          # any ordered mapping on Python 2 and 3, unlike keys()[0]
          while len(self._image_data) >= self.size:
            del self._image_data[next(iter(self._image_data))]
          self._image_data[i] = image_data
        return image_data
예제 #4
0
        class image_data_cache(object):
            """A bounded FIFO cache of raw image data read from an imageset.

            Data is fetched with ``imageset.get_raw_data(i)`` on first access.
            When the cache is full, the earliest-inserted entry is evicted
            (cache hits do not refresh an entry). A ``size`` of zero or less
            disables caching.
            """

            def __init__(self, imageset, size=10):
                self.imageset = imageset
                self.size = size
                # image index -> raw data, in insertion order
                self._image_data = OrderedDict()

            def __getitem__(self, i):
                if i in self._image_data:
                    return self._image_data[i]
                image_data = self.imageset.get_raw_data(i)
                if self.size > 0:
                    # remove the oldest entries in the cache; next(iter(...))
                    # works on any ordered mapping on Python 2 and 3, unlike
                    # keys()[0]
                    while len(self._image_data) >= self.size:
                        del self._image_data[next(iter(self._image_data))]
                    self._image_data[i] = image_data
                return image_data
예제 #5
0
 def _unique_detectors_dict(self):
   ''' Map each distinct detector model to a sequential integer id.

   Ids start at 0 and follow the order in which detectors are first seen
   while walking self._imagesets (get_detector() once per ImageSweep,
   get_detector(i) per image index otherwise).
   '''
   from dxtbx.imageset import ImageSweep
   from libtbx.containers import OrderedDict
   ordered = OrderedDict()
   for imageset in self._imagesets:
     if isinstance(imageset, ImageSweep):
       detectors = [imageset.get_detector()]
     else:
       detectors = [imageset.get_detector(n) for n in range(len(imageset))]
     for detector in detectors:
       ordered[detector] = None
   # second pass: number the detectors in first-seen order
   for detector_id, detector in enumerate(list(ordered)):
     ordered[detector] = detector_id
   return ordered
예제 #6
0
 def _unique_detectors_dict(self):
   ''' Map each distinct detector model to a sequential integer id.

   Ids start at 0 and follow the order in which detectors are first seen
   while walking self._imagesets (get_detector() once per ImageSweep,
   get_detector(i) per image index otherwise).
   '''
   from dxtbx.imageset import ImageSweep
   from libtbx.containers import OrderedDict
   ordered = OrderedDict()
   for imageset in self._imagesets:
     if isinstance(imageset, ImageSweep):
       detectors = [imageset.get_detector()]
     else:
       detectors = [imageset.get_detector(n) for n in xrange(len(imageset))]
     for detector in detectors:
       ordered[detector] = None
   # second pass: number the detectors in first-seen order
   for detector_id, detector in enumerate(list(ordered)):
     ordered[detector] = detector_id
   return ordered
예제 #7
0
 def unique_scans(self):
   ''' Return the distinct scan models as a list, in first-seen order.

   For an ImageSweep, get_scan() is called once and its result recorded.
   For any other imageset, get_scan(i) is tried for every image index;
   None results are skipped, and any exception raised while fetching or
   recording a per-image scan is deliberately ignored (best effort).
   '''
   from dxtbx.imageset import ImageSweep
   from libtbx.containers import OrderedDict
   obj = OrderedDict()
   for iset in self._imagesets:
     if isinstance(iset, ImageSweep):
       obj[iset.get_scan()] = None
     else:
       for i in range(len(iset)):
         try:
           model = iset.get_scan(i)
           if model is not None:
             obj[model] = None
         except Exception:
           # best effort: images whose scan cannot be obtained are skipped
           pass
   # On Python 3 keys() is a view that cannot be indexed; return a real list
   # so the result has the same type as on Python 2.
   return list(obj.keys())
예제 #8
0
 def unique_scans(self):
   ''' Return the distinct scan models, in the order they are first seen.

   For an ImageSweep, get_scan() is called once and its result recorded.
   For any other imageset, get_scan(i) is tried for every image index;
   None results are skipped, and any exception raised while fetching or
   recording a per-image scan is deliberately ignored (best effort).
   '''
   from dxtbx.imageset import ImageSweep
   from libtbx.containers import OrderedDict
   found = OrderedDict()
   for imageset in self._imagesets:
     if isinstance(imageset, ImageSweep):
       found[imageset.get_scan()] = None
       continue
     for index in xrange(len(imageset)):
       try:
         scan = imageset.get_scan(index)
         if scan is not None:
           found[scan] = None
       except Exception:
         pass
   return found.keys()
예제 #9
0
class loop(DictMixin):
    """An ordered mapping of CIF data names to columns of string values.

    Columns are stored as ``flex.std_string`` in ``_columns`` (see
    ``__setitem__`` for the conversion). Look-ups are case-insensitive:
    ``keys_lower`` maps each lower-cased data name to the name as stored in
    ``_columns``.
    """

    def __init__(self, header=None, data=None):
        """Create a loop.

        :param header: optional sequence of data names; each becomes an empty
            column.
        :param data: with ``header``, a flat row-major sequence of values
            whose length must be a multiple of ``len(header)``; without
            ``header``, a dict of data name -> column values.
        """
        self._columns = OrderedDict()
        # lower-cased data name -> data name as stored in self._columns
        self.keys_lower = {}
        if header is not None:
            for key in header:
                self.setdefault(key, flex.std_string())
            if data is not None:
                # the number of data items must be an exact multiple of the number of headers
                assert len(data) % len(
                    header) == 0, "Wrong number of data items for loop"
                n_rows = len(data) // len(header)
                n_columns = len(header)
                for i in range(n_rows):
                    self.add_row(
                        [data[i * n_columns + j] for j in range(n_columns)])
        elif header is None and data is not None:
            assert isinstance(data, dict) or isinstance(data, OrderedDict)
            self.add_columns(data)
            self.keys_lower = dict([(key.lower(), key)
                                    for key in self._columns.keys()])

    def __setitem__(self, key, value):
        """Set the column named ``key`` (case-insensitive) to ``value``.

        ``value`` is converted to ``flex.std_string``: flex.int/flex.double
        arrays via ``as_string()``; other values by trying
        ``flex.int(value)`` then ``flex.double(value)``, falling back to
        ``flex.std_string(value)``.

        :raises Sorry: if ``key`` does not match ``tag_re``.
        """
        if not re.match(tag_re, key):
            raise Sorry("%s is not a valid data name" % key)
        if len(self) > 0:
            assert len(value) == self.size()
        if not isinstance(value, flex.std_string):
            for flex_numeric_type in (flex.int, flex.double):
                if isinstance(value, flex_numeric_type):
                    value = value.as_string()
                    # already converted: don't feed the string array back
                    # into the remaining numeric constructors
                    break
                try:
                    value = flex_numeric_type(value).as_string()
                except TypeError:
                    continue
                else:
                    break
            if not isinstance(value, flex.std_string):
                value = flex.std_string(value)
        # value must be a mutable type
        assert hasattr(value, '__setitem__')
        # Data names are case-insensitive: replace an existing column in place
        # (keeping its stored spelling and position) rather than adding a
        # second column that differs only by case.
        key = self.keys_lower.get(key.lower(), key)
        self._columns[key] = value
        self.keys_lower[key.lower()] = key

    def __getitem__(self, key):
        return self._columns[self.keys_lower[key.lower()]]

    def __delitem__(self, key):
        del self._columns[self.keys_lower[key.lower()]]
        del self.keys_lower[key.lower()]

    def keys(self):
        return self._columns.keys()

    def __repr__(self):
        return repr(OrderedDict(self.iteritems()))

    def name(self):
        """Return the common prefix of the data names, minus trailing '_'/'.'."""
        return common_substring(self.keys()).rstrip('_').rstrip('.')

    def size(self):
        """Return the length of the longest column (0 if there are none)."""
        size = 0
        for column in self.values():
            size = max(size, len(column))
        return size

    def n_rows(self):
        return self.size()

    def n_columns(self):
        return len(self.keys())

    def add_row(self, row, default_value="?"):
        """Append one row.

        :param row: either a dict keyed by data name (missing names get
            ``default_value``) or a sequence with one value per column, in
            column order. Values are stored as ``str(value)``.
        """
        if isinstance(row, dict):
            for key in self:
                if key in row:
                    self[key].append(str(row[key]))
                else:
                    self[key].append(default_value)
        else:
            assert len(row) == len(self)
            for i, key in enumerate(self):
                self[key].append(str(row[i]))

    def add_column(self, key, values):
        if self.size() != 0:
            assert len(values) == self.size()
        self[key] = values
        self.keys_lower[key.lower()] = key

    def add_columns(self, columns):
        assert isinstance(columns, dict) or isinstance(columns, OrderedDict)
        for key, value in columns.iteritems():
            self.add_column(key, value)

    def update_column(self, key, values):
        assert type(key) == type(""), "first argument is column key string"
        if self.size() != 0:
            assert len(
                values) == self.size(), "len(values) %d != self.size() %d" % (
                    len(values),
                    self.size(),
                )
        self[key] = values
        self.keys_lower[key.lower()] = key

    def delete_row(self, index):
        assert index < self.n_rows()
        for column in self._columns.values():
            del column[index]

    def __copy__(self):
        # shallow: the new loop shares the column arrays with this one
        new = loop()
        new._columns = self._columns.copy()
        new.keys_lower = self.keys_lower.copy()
        return new

    copy = __copy__

    def __deepcopy__(self, memo):
        new = loop()
        new._columns = copy.deepcopy(self._columns, memo)
        new.keys_lower = copy.deepcopy(self.keys_lower, memo)
        return new

    def deepcopy(self):
        return copy.deepcopy(self)

    def show(self,
             out=None,
             indent="  ",
             indent_row=None,
             fmt_str=None,
             align_columns=True):
        """Print the loop as a CIF ``loop_`` to ``out`` (default sys.stdout).

        Row layout:
          * ``fmt_str`` given: each row is printed as ``fmt_str % row``;
            columns that convert to flex.int/flex.double are converted first
            (on a deep copy).
          * else if ``align_columns``: values are passed through
            ``format_value`` and padded to a common width per column;
            columns whose non-'.'/'?' values all parse as doubles are
            right-aligned, others left-aligned.
          * otherwise: values are passed through ``format_value`` and
            separated by single spaces.

        The stored column data is not modified.
        """
        assert self.n_rows() > 0 and self.n_columns() > 0, "keys: %s %d %d" % (
            self.keys(),
            self.n_rows(),
            self.n_columns(),
        )
        if out is None:
            out = sys.stdout
        if indent_row is None:
            indent_row = indent
        assert indent.strip() == ""
        assert indent_row.strip() == ""
        print >> out, "loop_"
        for k in self.keys():
            print >> out, indent + k
        values = self._columns.values()
        range_len_values = range(len(values))
        if fmt_str is not None:
            # Pretty printing:
            #   The user is responsible for providing a valid format string.
            #   Values are not quoted - it is the user's responsibility to place
            #   appropriate quotes in the format string if a particular value may
            #   contain spaces.
            values = copy.deepcopy(values)
            for i, v in enumerate(values):
                for flex_numeric_type in (flex.int, flex.double):
                    if not isinstance(v, flex_numeric_type):
                        try:
                            values[i] = flex_numeric_type(v)
                        except ValueError:
                            continue
                        else:
                            break
            for i in range(self.size()):
                print >> out, fmt_str % tuple(
                    [values[j][i] for j in range_len_values])
        elif align_columns:
            # Format into copies so that displaying the loop does not rewrite
            # the stored values as a side effect.
            values = [flex.std_string([format_value(item) for item in v])
                      for v in values]
            fmt_str = []
            for v in values:
                # exclude semicolon text fields from column width calculation
                v_ = flex.std_string(item for item in v if "\n" not in item)
                width = v_.max_element_length()
                # See if column contains only number, '.' or '?'
                # right-align numerical columns, left-align everything else
                v = v.select(~((v == ".") | (v == "?")))
                try:
                    flex.double(v)
                except ValueError:
                    width *= -1
                fmt_str.append("%%%is" % width)
            fmt_str = indent_row + "  ".join(fmt_str)
            for i in range(self.size()):
                print >> out, (fmt_str %
                               tuple([values[j][i]
                                      for j in range_len_values])).rstrip()
        else:
            for i in range(self.size()):
                values_to_print = [
                    format_value(values[j][i]) for j in range_len_values
                ]
                print >> out, ' '.join([indent] + values_to_print)

    def __str__(self):
        s = StringIO()
        self.show(out=s)
        return s.getvalue()

    def iterrows(self):
        """ Warning! Still super-slow! """
        keys = self.keys()
        s_values = self.values()
        range_len_self = range(len(self))
        # range is 1% faster than xrange in this particular place.
        # tuple (s_values...) is slightly faster than list
        for j in range(self.size()):
            yield OrderedDict(
                zip(keys, (s_values[i][j] for i in range_len_self)))

    def find_row(self, kv_dict):
        """Return every row (as an OrderedDict) whose value for each key in
        ``kv_dict`` equals the corresponding value in ``kv_dict``.

        Keys of ``kv_dict`` must appear (case-sensitively) in ``self.keys()``.
        """
        self_keys = self.keys()
        for k in kv_dict.keys():
            assert k in self_keys
        result = []
        s_values = self.values()
        range_len_self = range(len(self))
        for i in range(self.size()):
            goodrow = True
            for k, v in kv_dict.iteritems():
                if self[k][i] != v:
                    goodrow = False
                    break
            if goodrow:
                result.append(
                    OrderedDict(
                        zip(self_keys,
                            [s_values[j][i] for j in range_len_self])))
        return result

    def sort(self, key=None, reverse=False):
        """Reorder the columns by sorting the (name, column) items."""
        self._columns = OrderedDict(
            sorted(self._columns.items(), key=key, reverse=reverse))

    def order(self, order):
        """Keep only the columns named in ``order``, in that order.

        Columns not listed in ``order`` are dropped. Names are looked up
        case-sensitively in the stored column names (KeyError if absent).
        """
        tmp = OrderedDict()
        for o in order:
            tmp[o] = self._columns[o]
        self._columns = tmp
        # keep the case-insensitive index in sync with the remaining columns
        self.keys_lower = dict([(key.lower(), key) for key in tmp.keys()])

    def __eq__(self, other):
        if (len(self) != len(other) or self.size() != other.size()
                or self.keys() != other.keys()):
            return False
        for value, other_value in zip(self.values(), other.values()):
            if (value == other_value).count(True) != len(value):
                return False
        return True
예제 #10
0
class cif(DictMixin):
    """An ordered mapping of data block names to ``block`` objects.

    Look-ups are case-insensitive: ``keys_lower`` maps each lower-cased
    block name to the name as stored in ``blocks``.
    """

    def __init__(self, blocks=None):
        """:param blocks: optional mapping or sequence of (name, block) pairs
            copied into ``self.blocks`` (not checked by ``__setitem__``)."""
        if blocks is not None:
            self.blocks = OrderedDict(blocks)
        else:
            self.blocks = OrderedDict()
        self.keys_lower = dict([(key.lower(), key)
                                for key in self.blocks.keys()])

    def __setitem__(self, key, value):
        """Store data block ``value`` under the case-insensitive name ``key``.

        :raises Sorry: if ``'_' + key`` does not match ``tag_re``.
        """
        assert isinstance(value, block)
        if not re.match(tag_re, '_' + key):
            raise Sorry("%s is not a valid data block name" % key)
        # Block names are case-insensitive: reuse the stored spelling of an
        # existing block so it is replaced in place rather than duplicated.
        key = self.keys_lower.get(key.lower(), key)
        self.blocks[key] = value
        self.keys_lower[key.lower()] = key

    def get(self, key, default=None):
        # keys_lower yields the block name as stored in self.blocks
        stored_key = self.keys_lower.get(key.lower())
        if (stored_key is None):
            return default
        return self.blocks.get(stored_key, default)

    def __getitem__(self, key):
        result = self.get(key)
        if (result is None):
            raise KeyError('Unknown CIF data block name: "%s"' % key)
        return result

    def __delitem__(self, key):
        del self.blocks[self.keys_lower[key.lower()]]
        del self.keys_lower[key.lower()]

    def keys(self):
        return self.blocks.keys()

    def __repr__(self):
        return repr(OrderedDict(self.iteritems()))

    def __copy__(self):
        # shallow: the new cif shares the block objects with this one
        return cif(self.blocks.copy())

    copy = __copy__

    def __deepcopy__(self, memo):
        return cif(copy.deepcopy(self.blocks, memo))

    def deepcopy(self):
        return copy.deepcopy(self)

    def show(self,
             out=None,
             indent="  ",
             indent_row=None,
             data_name_field_width=34,
             loop_format_strings=None,
             align_columns=True):
        """Print every block as ``data_<name>`` followed by ``block.show``."""
        if out is None:
            out = sys.stdout
        for name, block in self.items():
            print >> out, "data_%s" % name
            block.show(out=out,
                       indent=indent,
                       indent_row=indent_row,
                       data_name_field_width=data_name_field_width,
                       loop_format_strings=loop_format_strings,
                       align_columns=align_columns)

    def __str__(self):
        s = StringIO()
        self.show(out=s)
        return s.getvalue()

    def validate(self,
                 dictionary,
                 show_warnings=True,
                 error_handler=None,
                 out=None):
        """Validate each data block against ``dictionary``.

        For every block a fresh handler of the same class as
        ``error_handler`` (default ``validation.ErrorHandler``) is installed
        on ``dictionary``; it is shown on ``out`` if it recorded any errors
        or warnings.

        :returns: the handler used for the last block validated, or the
            initial handler when there are no blocks.
        """
        if out is None: out = sys.stdout
        from iotbx.cif import validation
        if error_handler is None:
            error_handler = validation.ErrorHandler()
        for block in self.blocks.values():
            error_handler = error_handler.__class__()
            dictionary.set_error_handler(error_handler)
            block.validate(dictionary)
            if error_handler.error_count or error_handler.warning_count:
                error_handler.show(show_warnings=show_warnings, out=out)
        return error_handler

    def sort(self, recursive=False, key=None, reverse=False):
        """Sort blocks by their (name, block) items; optionally sort each
        block's contents too (``key`` is not passed down)."""
        self.blocks = OrderedDict(
            sorted(self.blocks.items(), key=key, reverse=reverse))
        if recursive:
            for b in self.blocks.values():
                b.sort(recursive=recursive, reverse=reverse)
예제 #11
0
class loop(DictMixin):
  """An ordered mapping of CIF data names to columns of string values.

  Columns are stored as ``flex.std_string`` in ``_columns`` (see
  ``__setitem__`` for the conversion). Look-ups are case-insensitive:
  ``keys_lower`` maps each lower-cased data name to the name as stored in
  ``_columns``.
  """

  def __init__(self, header=None, data=None):
    """Create a loop.

    :param header: optional sequence of data names; each becomes an empty
      column.
    :param data: with ``header``, a flat row-major sequence of values whose
      length must be a multiple of ``len(header)``; without ``header``, a
      dict of data name -> column values.
    """
    self._columns = OrderedDict()
    # lower-cased data name -> data name as stored in self._columns
    self.keys_lower = {}
    if header is not None:
      for key in header:
        self.setdefault(key, flex.std_string())
      if data is not None:
        # the number of data items must be an exact multiple of the number of headers
        assert len(data) % len(header) == 0, "Wrong number of data items for loop"
        n_rows = len(data)//len(header)
        n_columns = len(header)
        for i in range(n_rows):
          self.add_row([data[i*n_columns+j] for j in range(n_columns)])
    elif header is None and data is not None:
      assert isinstance(data, dict) or isinstance(data, OrderedDict)
      self.add_columns(data)
      self.keys_lower = dict(
        [(key.lower(), key) for key in self._columns.keys()])

  def __setitem__(self, key, value):
    """Set the column named ``key`` (case-insensitive) to ``value``.

    ``value`` is converted to ``flex.std_string``: flex.int/flex.double
    arrays via ``as_string()``; other values by trying ``flex.int(value)``
    then ``flex.double(value)``, falling back to ``flex.std_string(value)``.

    :raises Sorry: if ``key`` does not match ``tag_re``.
    """
    if not re.match(tag_re, key):
      raise Sorry("%s is not a valid data name" %key)
    if len(self) > 0:
      assert len(value) == self.size()
    if not isinstance(value, flex.std_string):
      for flex_numeric_type in (flex.int, flex.double):
        if isinstance(value, flex_numeric_type):
          value = value.as_string()
          # already converted: don't feed the string array back into the
          # remaining numeric constructors
          break
        try:
          value = flex_numeric_type(value).as_string()
        except TypeError:
          continue
        else:
          break
      if not isinstance(value, flex.std_string):
        value = flex.std_string(value)
    # value must be a mutable type
    assert hasattr(value, '__setitem__')
    # Data names are case-insensitive: replace an existing column in place
    # (keeping its stored spelling and position) rather than adding a second
    # column that differs only by case.
    key = self.keys_lower.get(key.lower(), key)
    self._columns[key] = value
    self.keys_lower[key.lower()] = key

  def __getitem__(self, key):
    return self._columns[self.keys_lower[key.lower()]]

  def __delitem__(self, key):
    del self._columns[self.keys_lower[key.lower()]]
    del self.keys_lower[key.lower()]

  def keys(self):
    return self._columns.keys()

  def __repr__(self):
    return repr(OrderedDict(self.iteritems()))

  def name(self):
    """Return the common prefix of the data names, minus trailing '_'/'.'."""
    return common_substring(self.keys()).rstrip('_').rstrip('.')

  def size(self):
    """Return the length of the longest column (0 if there are none)."""
    size = 0
    for column in self.values():
      size = max(size, len(column))
    return size

  def n_rows(self):
    return self.size()

  def n_columns(self):
    return len(self.keys())

  def add_row(self, row, default_value="?"):
    """Append one row.

    :param row: either a dict keyed by data name (missing names get
      ``default_value``) or a sequence with one value per column, in column
      order. Values are stored as ``str(value)``.
    """
    if isinstance(row, dict):
      for key in self:
        if key in row:
          self[key].append(str(row[key]))
        else:
          self[key].append(default_value)
    else:
      assert len(row) == len(self)
      for i, key in enumerate(self):
        self[key].append(str(row[i]))

  def add_column(self, key, values):
    if self.size() != 0:
      assert len(values) == self.size()
    self[key] = values
    self.keys_lower[key.lower()] = key

  def add_columns(self, columns):
    assert isinstance(columns, dict) or isinstance(columns, OrderedDict)
    for key, value in columns.iteritems():
      self.add_column(key, value)

  def update_column(self, key, values):
    assert type(key)==type(""), "first argument is column key string"
    if self.size() != 0:
      assert len(values) == self.size(), "len(values) %d != self.size() %d" % (
        len(values),
        self.size(),
        )
    self[key] = values
    self.keys_lower[key.lower()] = key

  def delete_row(self, index):
    assert index < self.n_rows()
    for column in self._columns.values():
      del column[index]

  def __copy__(self):
    # shallow: the new loop shares the column arrays with this one
    new = loop()
    new._columns = self._columns.copy()
    new.keys_lower = self.keys_lower.copy()
    return new

  copy = __copy__

  def __deepcopy__(self, memo):
    new = loop()
    new._columns = copy.deepcopy(self._columns, memo)
    new.keys_lower = copy.deepcopy(self.keys_lower, memo)
    return new

  def deepcopy(self):
    return copy.deepcopy(self)

  def show(self, out=None, indent="  ", indent_row=None, fmt_str=None, align_columns=True):
    """Print the loop as a CIF ``loop_`` to ``out`` (default sys.stdout).

    Row layout:
      * ``fmt_str`` given: each row is printed as ``fmt_str % row``; columns
        that convert to flex.int/flex.double are converted first (on a deep
        copy).
      * else if ``align_columns``: values are passed through ``format_value``
        and padded to a common width per column; columns whose non-'.'/'?'
        values all parse as doubles are right-aligned, others left-aligned.
      * otherwise: values are passed through ``format_value`` and separated
        by single spaces.

    The stored column data is not modified.
    """
    assert self.n_rows() > 0 and self.n_columns() > 0, "keys: %s %d %d" % (
      self.keys(),
      self.n_rows(),
      self.n_columns(),
      )
    if out is None:
      out = sys.stdout
    if indent_row is None:
      indent_row = indent
    assert indent.strip() == ""
    assert indent_row.strip() == ""
    print >> out, "loop_"
    for k in self.keys():
      print >> out, indent + k
    values = self._columns.values()
    if fmt_str is not None:
      # Pretty printing:
      #   The user is responsible for providing a valid format string.
      #   Values are not quoted - it is the user's responsibility to place
      #   appropriate quotes in the format string if a particular value may
      #   contain spaces.
      values = copy.deepcopy(values)
      for i, v in enumerate(values):
        for flex_numeric_type in (flex.int, flex.double):
          if not isinstance(v, flex_numeric_type):
            try:
              values[i] = flex_numeric_type(v)
            except ValueError:
              continue
            else:
              break
      for i in range(self.size()):
        print >> out, fmt_str % tuple([values[j][i] for j in range(len(values))])
    elif align_columns:
      # Format into copies so that displaying the loop does not rewrite the
      # stored values as a side effect.
      values = [flex.std_string([format_value(item) for item in v])
                for v in values]
      fmt_str = []
      for v in values:
        # exclude semicolon text fields from column width calculation
        v_ = flex.std_string(item for item in v if "\n" not in item)
        width = v_.max_element_length()
        # See if column contains only number, '.' or '?'
        # right-align numerical columns, left-align everything else
        v = v.select(~( (v == ".") | (v == "?") ))
        try:
          flex.double(v)
        except ValueError:
          width *= -1
        fmt_str.append("%%%is" %width)
      fmt_str = indent_row + "  ".join(fmt_str)
      for i in range(self.size()):
        print >> out, (fmt_str %
                       tuple([values[j][i]
                              for j in range(len(values))])).rstrip()
    else:
      for i in range(self.size()):
        values_to_print = [format_value(values[j][i]) for j in range(len(values))]
        print >> out, ' '.join([indent] + values_to_print)

  def __str__(self):
    s = StringIO()
    self.show(out=s)
    return s.getvalue()

  def iterrows(self):
    """Yield each row as an OrderedDict of data name -> value."""
    keys = self.keys()
    # fetch the column list once instead of once per cell
    columns = self.values()
    for j in range(self.size()):
      yield OrderedDict(zip(keys, [column[j] for column in columns]))

  def sort(self, key=None, reverse=False):
    """Reorder the columns by sorting the (name, column) items."""
    self._columns = OrderedDict(
      sorted(self._columns.items(), key=key, reverse=reverse))

  def order(self, order):
    """Keep only the columns named in ``order``, in that order.

    Columns not listed in ``order`` are dropped. Names are looked up
    case-sensitively in the stored column names (KeyError if absent).
    """
    tmp = OrderedDict()
    for o in order:
      tmp[o]=self._columns[o]
    self._columns = tmp
    # keep the case-insensitive index in sync with the remaining columns
    self.keys_lower = dict([(key.lower(), key) for key in tmp.keys()])

  def __eq__(self, other):
    if (len(self) != len(other) or
        self.size() != other.size() or
        self.keys() != other.keys()):
      return False
    for value, other_value in zip(self.values(), other.values()):
      if (value == other_value).count(True) != len(value):
        return False
    return True
예제 #12
0
class cif(DictMixin):
  """An ordered mapping of data block names to ``block`` objects.

  Look-ups are case-insensitive: ``keys_lower`` maps each lower-cased block
  name to the name as stored in ``blocks``.
  """

  def __init__(self, blocks=None):
    """:param blocks: optional mapping or sequence of (name, block) pairs
      copied into ``self.blocks`` (not checked by ``__setitem__``)."""
    if blocks is not None:
      self.blocks = OrderedDict(blocks)
    else:
      self.blocks = OrderedDict()
    self.keys_lower = dict([(key.lower(), key) for key in self.blocks.keys()])

  def __setitem__(self, key, value):
    """Store data block ``value`` under the case-insensitive name ``key``.

    :raises Sorry: if ``'_' + key`` does not match ``tag_re``.
    """
    assert isinstance(value, block)
    if not re.match(tag_re, '_'+key):
      raise Sorry("%s is not a valid data block name" %key)
    # Block names are case-insensitive: reuse the stored spelling of an
    # existing block so it is replaced in place rather than duplicated.
    key = self.keys_lower.get(key.lower(), key)
    self.blocks[key] = value
    self.keys_lower[key.lower()] = key

  def get(self, key, default=None):
    # keys_lower yields the block name as stored in self.blocks
    stored_key = self.keys_lower.get(key.lower())
    if (stored_key is None):
      return default
    return self.blocks.get(stored_key, default)

  def __getitem__(self, key):
    result = self.get(key)
    if (result is None):
      raise KeyError('Unknown CIF data block name: "%s"' % key)
    return result

  def __delitem__(self, key):
    del self.blocks[self.keys_lower[key.lower()]]
    del self.keys_lower[key.lower()]

  def keys(self):
    return self.blocks.keys()

  def __repr__(self):
    return repr(OrderedDict(self.iteritems()))

  def __copy__(self):
    # shallow: the new cif shares the block objects with this one
    return cif(self.blocks.copy())

  copy = __copy__

  def __deepcopy__(self, memo):
    return cif(copy.deepcopy(self.blocks, memo))

  def deepcopy(self):
    return copy.deepcopy(self)

  def show(self, out=None, indent="  ", indent_row=None,
           data_name_field_width=34,
           loop_format_strings=None):
    """Print every block as ``data_<name>`` followed by ``block.show``."""
    if out is None:
      out = sys.stdout
    for name, block in self.items():
      print >> out, "data_%s" %name
      block.show(
        out=out, indent=indent, indent_row=indent_row,
        data_name_field_width=data_name_field_width,
        loop_format_strings=loop_format_strings)

  def __str__(self):
    s = StringIO()
    self.show(out=s)
    return s.getvalue()

  def validate(self, dictionary, show_warnings=True, error_handler=None, out=None):
    """Validate each data block against ``dictionary``.

    For every block a fresh handler of the same class as ``error_handler``
    (default ``validation.ErrorHandler``) is installed on ``dictionary``; it
    is shown on ``out`` if it recorded any errors or warnings.

    :returns: the handler used for the last block validated, or the initial
      handler when there are no blocks.
    """
    if out is None: out = sys.stdout
    from iotbx.cif import validation
    if error_handler is None:
      error_handler = validation.ErrorHandler()
    for block in self.blocks.values():
      error_handler = error_handler.__class__()
      dictionary.set_error_handler(error_handler)
      block.validate(dictionary)
      if error_handler.error_count or error_handler.warning_count:
        error_handler.show(show_warnings=show_warnings, out=out)
    return error_handler

  def sort(self, recursive=False, key=None, reverse=False):
    """Sort blocks by their (name, block) items; optionally sort each block's
    contents too (``key`` is not passed down)."""
    self.blocks = OrderedDict(sorted(self.blocks.items(), key=key, reverse=reverse))
    if recursive:
      for b in self.blocks.values():
        b.sort(recursive=recursive, reverse=reverse)
예제 #13
0
class multi_crystal_analysis(object):
  '''Compare several datasets that share one set of unmerged intensities.

  On construction the data are split per dataset (via separate_unmerged),
  each dataset is resolution-filtered and merged, and then:

  * a relative anomalous CC bar chart is written to racc.png (only when
    the intensities have an anomalous flag);
  * pairwise intensity CCs are clustered hierarchically and written to
    correlation_matrix.png, dendrogram.png and intensity_clusters.json;
  * a table of the clusters is written to intensity_clustering.txt.
  '''

  def __init__(self, unmerged_intensities, batches_all, n_bins=20, d_min=None,
               id_to_batches=None):
    '''
    unmerged_intensities: unmerged intensity array with sigmas; reflections
      with sigma <= 0 are discarded.
    batches_all: per-reflection batch numbers, parallel to
      unmerged_intensities (filtered in step with it).
    n_bins: number of resolution bins set up on the binner.
    d_min: optional high-resolution cutoff applied to every dataset.
    id_to_batches: passed through to separate_unmerged.
    '''

    # drop reflections with non-positive sigmas, keeping batches in step
    sel = unmerged_intensities.sigmas() > 0
    unmerged_intensities = unmerged_intensities.select(sel)
    batches_all = batches_all.select(sel)

    unmerged_intensities.setup_binner(n_bins=n_bins)
    unmerged_intensities.show_summary()
    self.unmerged_intensities = unmerged_intensities
    self.merged_intensities = unmerged_intensities.merge_equivalents().array()

    separate = separate_unmerged(
      unmerged_intensities, batches_all, id_to_batches=id_to_batches)
    self.intensities = separate.intensities
    self.batches = separate.batches
    run_id_to_batch_id = separate.run_id_to_batch_id
    self.individual_merged_intensities = OrderedDict()
    for k in self.intensities.keys():
      self.intensities[k] = self.intensities[k].resolution_filter(d_min=d_min)
      self.batches[k] = self.batches[k].resolution_filter(d_min=d_min)
      self.individual_merged_intensities[k] = self.intensities[k].merge_equivalents().array()

    if run_id_to_batch_id is not None:
      # list() so the labels can be indexed by the plotting code
      labels = list(run_id_to_batch_id.values())
    else:
      labels = None
    racc = self.relative_anomalous_cc()
    if racc is not None:
      self.plot_relative_anomalous_cc(racc, labels=labels)
    correlation_matrix, linkage_matrix = self.compute_correlation_coefficient_matrix()

    self._cluster_dict = self.to_dict(correlation_matrix, linkage_matrix)

    self.plot_cc_matrix(correlation_matrix, linkage_matrix, labels=labels)

    self.write_output()

  def to_dict(self, correlation_matrix, linkage_matrix):
    '''Describe every non-leaf node of the clustering tree.

    Returns a dict mapping a 1-based cluster id (linkage-matrix row index
    + 1) to {'datasets': sorted 1-based dataset numbers in the cluster,
    'height': the merge distance from that linkage row}.
    correlation_matrix is accepted for interface compatibility but unused.
    '''

    from scipy.cluster import hierarchy
    tree = hierarchy.to_tree(linkage_matrix, rd=False)
    n_merges = len(linkage_matrix)

    d = {}

    # http://w3facility.org/question/scipy-dendrogram-to-json-for-d3-js-tree-visualisation/
    # https://gist.github.com/mdml/7537455

    def add_node(node):
      if node.is_leaf(): return
      # scipy numbers the cluster formed at linkage row i as n_leaves + i,
      # and n_leaves == n_merges + 1
      cluster_id = node.get_id() - n_merges - 1
      row = linkage_matrix[cluster_id]
      d[cluster_id+1] = {
        'datasets': [i+1 for i in sorted(node.pre_order())],
        'height': row[2],
      }

      # Recursively add the current node's children
      if node.left: add_node(node.left)
      if node.right: add_node(node.right)

    add_node(tree)

    return d

  def relative_anomalous_cc(self):
    '''Correlate each dataset's anomalous differences with the full set's.

    Returns a flex.double of CCs in individual_merged_intensities order,
    or None when the intensities have no anomalous flag.
    '''
    if not self.unmerged_intensities.anomalous_flag():
      return None
    d_min = min([ma.d_min() for ma in self.intensities.values()])
    racc = flex.double()
    full_set_anom_diffs = self.merged_intensities.anomalous_differences()
    for i_wedge in self.individual_merged_intensities.keys():
      ma_i = self.individual_merged_intensities[i_wedge].resolution_filter(d_min=d_min)
      anom_i = ma_i.anomalous_differences()
      anom_cc = anom_i.correlation(full_set_anom_diffs, assert_is_similar_symmetry=False).coefficient()
      racc.append(anom_cc)
    return racc

  @staticmethod
  def _sorted_bar_labels(perm, labels=None):
    '''Return tick labels for bars drawn in ``perm`` order.

    Without labels, the 1-based dataset numbers are used. Supplied labels
    (one per dataset, in dataset order) are reordered by ``perm`` so each
    label stays with its own bar.
    '''
    if labels is None:
      return ["%.0f" %(j+1) for j in perm]
    assert len(labels) == len(perm)
    return [labels[j] for j in perm]

  def plot_relative_anomalous_cc(self, racc, labels=None):
    '''Save a bar chart of the relative anomalous CCs to racc.png.

    Bars are drawn in flex.sort_permutation(racc) order; labels, if given,
    must have one entry per dataset in dataset order.
    '''
    perm = flex.sort_permutation(racc)
    fig = pyplot.figure(dpi=1200, figsize=(16,12))
    pyplot.bar(range(len(racc)), list(racc.select(perm)))
    # the bars are permuted, so the tick labels must be permuted the same way
    labels = self._sorted_bar_labels(perm, labels)
    pyplot.xticks([i+0.5 for i in range(len(racc))], labels)
    locs, tick_labels = pyplot.xticks()
    pyplot.setp(tick_labels, rotation=70)
    pyplot.xlabel("Dataset")
    pyplot.ylabel("Relative anomalous correlation coefficient")
    fig.savefig("racc.png")
    # pyplot keeps every figure alive until it is explicitly closed
    pyplot.close(fig)

  def compute_correlation_coefficient_matrix(self):
    '''Pairwise intensity CCs between datasets and their clustering.

    Returns (correlation_matrix, linkage_matrix): an n x n flex.double of
    CCs, computed at the lowest common d_min, and the scipy average-linkage
    matrix built from the distances 1 - CC.
    '''
    from scipy.cluster import hierarchy
    import scipy.spatial.distance as ssd

    correlation_matrix = flex.double(
      flex.grid(len(self.intensities), len(self.intensities)))

    d_min = min([ma.d_min() for ma in self.intensities.values()])

    # resolution-filter each dataset once instead of once per pair
    filtered = [(k, ma.resolution_filter(d_min=d_min))
                for k, ma in self.individual_merged_intensities.items()]

    # NOTE(review): the dataset keys are used directly as matrix indices, so
    # they are assumed to be the integers 0..n-1 -- confirm in separate_unmerged
    for i_wedge, ma_i in filtered:
      for j_wedge, ma_j in filtered:
        if j_wedge < i_wedge: continue
        cc_ij = ma_i.correlation(ma_j).coefficient()
        correlation_matrix[(i_wedge,j_wedge)] = cc_ij
        correlation_matrix[(j_wedge,i_wedge)] = cc_ij

    diffraction_dissimilarity = 1-correlation_matrix

    dist_mat = diffraction_dissimilarity.as_numpy_array()

    # convert the redundant n*n square matrix form into a condensed nC2 array
    dist_mat = ssd.squareform(dist_mat) # distArray[{n choose 2}-{n-i choose 2} + (j-i-1)] is the distance between points i and j

    # alternatives: 'single', 'complete', 'weighted'
    method = 'average'

    linkage_matrix = hierarchy.linkage(dist_mat, method=method)

    return correlation_matrix, linkage_matrix

  def plot_cc_matrix(self, correlation_matrix, linkage_matrix, labels=None):
    '''Write the clustering plots and the plotly-style JSON description.

    Outputs correlation_matrix.png (dendrogram beside the CC heatmap with
    rows/columns in dendrogram leaf order), dendrogram.png, and
    intensity_clusters.json (heatmap + dendrogram traces, layout, and the
    cluster dict computed by to_dict).
    '''
    from scipy.cluster import hierarchy

    # Compute and plot dendrogram.
    fig = pyplot.figure(dpi=1200, figsize=(16,12))
    axdendro = fig.add_axes([0.09,0.1,0.2,0.8])
    Z = hierarchy.dendrogram(linkage_matrix,
                             color_threshold=0.05,
                             orientation='right')
    axdendro.set_xticks([])
    axdendro.set_yticks([])

    # Plot distance matrix, reordered to match the dendrogram leaves.
    axmatrix = fig.add_axes([0.3,0.1,0.6,0.8])
    index = Z['leaves']
    D = correlation_matrix.as_numpy_array()
    D = D[index,:]
    D = D[:,index]
    im = axmatrix.matshow(D, aspect='auto', origin='lower')
    axmatrix.yaxis.tick_right()
    if labels is not None:
      axmatrix.xaxis.tick_bottom()
      axmatrix.set_xticks(list(range(len(labels))))
      axmatrix.set_xticklabels([labels[i] for i in index], rotation=70)
      axmatrix.yaxis.set_ticks([])

    # Plot colorbar.
    axcolor = fig.add_axes([0.91,0.1,0.02,0.8])
    pyplot.colorbar(im, cax=axcolor)

    # Save figure, then release it (pyplot keeps figures until closed).
    fig.savefig('correlation_matrix.png')
    pyplot.close(fig)

    fig = pyplot.figure(dpi=1200, figsize=(16,12))

    if labels is None:
      labels = ['%i' %(i+1) for i in range(len(self.intensities))]

    ddict = hierarchy.dendrogram(linkage_matrix,
                                 #truncate_mode='lastp',
                                 color_threshold=0.05,
                                 labels=labels,
                                 #leaf_rotation=90,
                                 show_leaf_counts=False)
    locs, tick_labels = pyplot.xticks()
    pyplot.setp(tick_labels, rotation=70)
    fig.savefig('dendrogram.png')
    pyplot.close(fig)

    import copy
    y2_dict = scipy_dendrogram_to_plotly_json(ddict) # above heatmap
    x2_dict = copy.deepcopy(y2_dict) # left of heatmap, rotated
    for d in y2_dict['data']:
      d['yaxis'] = 'y2'
      d['xaxis'] = 'x2'

    for d in x2_dict['data']:
      x = d['x']
      y = d['y']
      d['x'] = y
      d['y'] = x
      d['yaxis'] = 'y3'
      d['xaxis'] = 'x3'

    ccdict = {
      'data': [{
        'name': 'correlation_matrix',
        'x': list(range(D.shape[0])),
        'y': list(range(D.shape[1])),
        'z': D.tolist(),
        'type': 'heatmap',
        'colorbar': {
          'title': 'Correlation coefficient',
          'titleside': 'right',
          #'x': 0.96,
          #'y': 0.9,
          #'titleside': 'top',
          #'xanchor': 'right',
          'xpad': 0,
          #'yanchor': 'top'
        },
        'colorscale': 'Jet',
        'xaxis': 'x',
        'yaxis': 'y',
      }],

      'layout': {
        'autosize': False,
        'bargap': 0,
        'height': 1000,
        'hovermode': 'closest',
        'margin': {
          'r': 20,
          't': 50,
          'autoexpand': True,
          'l': 20
          },
        'showlegend': False,
        'title': 'Dendrogram Heatmap',
        'width': 1000,
        'xaxis': {
          'domain': [0.2, 0.9],
          'mirror': 'allticks',
          'showgrid': False,
          'showline': False,
          'showticklabels': True,
          'tickmode': 'array',
          'ticks': '',
          'ticktext': y2_dict['layout']['xaxis']['ticktext'],
          'tickvals': list(range(len(y2_dict['layout']['xaxis']['ticktext']))),
          'tickangle': 300,
          'title': '',
          'type': 'linear',
          'zeroline': False
        },
        'yaxis': {
          'domain': [0, 0.78],
          'anchor': 'x',
          'mirror': 'allticks',
          'showgrid': False,
          'showline': False,
          'showticklabels': True,
          'tickmode': 'array',
          'ticks': '',
          'ticktext': y2_dict['layout']['xaxis']['ticktext'],
          'tickvals': list(range(len(y2_dict['layout']['xaxis']['ticktext']))),
          'title': '',
          'type': 'linear',
          'zeroline': False
        },
        'xaxis2': {
          'domain': [0.2, 0.9],
          'anchor': 'y2',
          'showgrid': False,
          'showline': False,
          'showticklabels': False,
          'zeroline': False
        },
        'yaxis2': {
          'domain': [0.8, 1],
          'anchor': 'x2',
          'showgrid': False,
          'showline': False,
          'zeroline': False
        },
        'xaxis3': {
          'domain': [0.0, 0.1],
          'anchor': 'y3',
          'range': [max(max(d['x']) for d in x2_dict['data']), 0],
          'showgrid': False,
          'showline': False,
          'tickangle': 300,
          'zeroline': False
        },
        'yaxis3': {
          'domain': [0, 0.78],
          'anchor': 'x3',
          'showgrid': False,
          'showline': False,
          'showticklabels': False,
          'zeroline': False
        },
      }
    }
    d = ccdict
    d['data'].extend(y2_dict['data'])
    d['data'].extend(x2_dict['data'])

    d['clusters'] = self._cluster_dict

    import json
    with open('intensity_clusters.json', 'wb') as f:
      json.dump(d, f, indent=2)


  def write_output(self):
    '''Write a table of all clusters to intensity_clustering.txt.'''

    rows = [["cluster_id", "# datasets", "height", "datasets"]]
    for cid in sorted(self._cluster_dict.keys()):
      cluster = self._cluster_dict[cid]
      datasets = cluster['datasets']
      rows.append([str(cid), str(len(datasets)),
                   '%.2f' %cluster['height'],
                   ' '.join(str(ds) for ds in datasets)])

    with open('intensity_clustering.txt', 'wb') as f:
      from libtbx import table_utils
      print >> f, table_utils.format(
        rows, has_header=True, prefix="|", postfix="|")
예제 #14
0
파일: XInfo.py 프로젝트: xia2/xia2
class XInfo(object):
  '''A class to represent all of the input to the xia2dpa system, with
  enough information to allow structure solution, as parsed from a
  .xinfo file, an example of which is in the source code.'''

  def __init__(self, xinfo_file, sweep_ids=None, sweep_ranges=None):
    '''Build the project description by parsing ``xinfo_file``.

    sweep_ids: optional list of sweep names to keep; compared
      case-insensitively (they are lower-cased here).
    sweep_ranges: optional list with one entry per sweep id, stored as that
      sweep's 'start_end'; only allowed together with sweep_ids.
    '''

    # containers filled in by the parser
    self._project = None
    self._crystals = OrderedDict()

    lowered_ids = None
    if sweep_ids is not None:
      lowered_ids = [name.lower() for name in sweep_ids]
    if sweep_ranges is not None:
      assert lowered_ids is not None
      assert len(lowered_ids) == len(sweep_ranges)
    self._sweep_ids = lowered_ids
    self._sweep_ranges = sweep_ranges

    self._parse_project(xinfo_file)
    self._validate()

  def get_output(self):
    '''Generate a string representation of the project.

    The text is a "Project <name>" line followed, for each crystal in
    order, by a "Crystal <name>" line and that crystal's get_output()
    text. There is no trailing newline.
    '''
    # NOTE(review): _parse_crystal stores each crystal as a plain dict, which
    # has no get_output() method, so this appears to raise AttributeError
    # whenever a crystal has been parsed -- confirm the intended value type.
    pieces = ['Project %s' % self._project]
    for crystal_name, crystal in self._crystals.items():
      pieces.append('Crystal %s' % crystal_name)
      pieces.append('%s' % crystal.get_output())
    return '\n'.join(pieces)

  def get_project(self):
    '''Return the name from the BEGIN PROJECT record (None if none was read).'''
    project_name = self._project
    return project_name

  def get_crystals(self):
    '''Return the ordered crystal-name -> crystal-record mapping.

    This is the live internal mapping, not a copy.
    '''
    crystals = self._crystals
    return crystals

  def _validate(self):
    '''Validation hook for the parsed project structure.

    NOTE(review): currently a stub -- no checks are performed and True is
    always returned, although the intent was to raise on malformed input.
    '''
    is_valid = True
    return is_valid

  def _parse_project(self, xinfo_file):
    '''Parse & validate the contents of the .xinfo file. This parses the
    project element (i.e. the whole thing..)

    Blank lines and lines starting with '!' or '#' are ignored. The
    BEGIN PROJECT record sets the project name, which must match the name
    on the END PROJECT record. Each BEGIN CRYSTAL ... END CRYSTAL span
    (both records included) is handed to _parse_crystal, which handles
    everything inside it.

    Raises RuntimeError if the END PROJECT name does not match, or if a
    BEGIN CRYSTAL record has no matching END CRYSTAL record.
    '''

    # load the file, keeping only non-blank, non-comment records; the
    # with-block closes the handle promptly
    with open(xinfo_file, 'r') as fh:
      stripped = [line.strip() for line in fh]
    project_records = [r for r in stripped if r and r[0] not in ('!', '#')]

    n_records = len(project_records)
    i = 0
    while i < n_records:
      record = project_records[i]
      if 'BEGIN PROJECT' in record:
        self._project = record.replace('BEGIN PROJECT', '').strip()
      if 'END PROJECT' in record:
        if self._project != record.replace('END PROJECT', '').strip():
          raise RuntimeError('error parsing END PROJECT record')

      # the trailing space keeps BEGIN/END CRYSTAL_DATA from matching here
      if 'BEGIN CRYSTAL ' in record:
        crystal_records = [record]
        while True:
          i += 1
          if i >= n_records:
            raise RuntimeError(
              'missing END CRYSTAL record for "%s"' % crystal_records[0])
          record = project_records[i]
          crystal_records.append(record)
          if 'END CRYSTAL ' in record:
            break

        # everything inside the crystal is handled by _parse_crystal;
        # the outer loop resumes after the END CRYSTAL record
        self._parse_crystal(crystal_records)

      i += 1

  def _parse_crystal(self, crystal_records):
    '''Parse the interesting information out of the crystal
    description.'''

    crystal = ''

    for i in range(len(crystal_records)):
      record = crystal_records[i]
      if 'BEGIN CRYSTAL ' in record:

        # we should only ever have one of these records in
        # a call to this method

        if crystal != '':
          raise RuntimeError, 'error in BEGIN CRYSTAL record'

        crystal = record.replace('BEGIN CRYSTAL ', '').strip()
        if crystal in self._crystals:
          raise RuntimeError, 'crystal %s already exists' % \
                crystal

        # cardinality:
        #
        # sequence - exactly one, a long string
        # wavelengths - a dictionary of data structures keyed by the
        #               wavelength id
        # sweeps - a dictionary of data structures keyed by the
        #          sweep id
        # ha_info - exactly one dictionary containing the heavy atom
        #           information

        self._crystals[crystal] = {
          'sequence':'',
          'wavelengths': OrderedDict(),
          'samples': OrderedDict(),
          'sweeps': OrderedDict(),
          'ha_info': OrderedDict(),
          'crystal_data': OrderedDict()
        }

      # next look for interesting stuff in the data structure...
      # starting with the sequence

      if 'BEGIN AA_SEQUENCE' in record:
        sequence = ''
        i += 1
        record = crystal_records[i]
        while record != 'END AA_SEQUENCE':
          if not '#' in record or '!' in record:
            sequence += record.strip()

          i += 1
          record = crystal_records[i]

        if self._crystals[crystal]['sequence'] != '':
          raise RuntimeError, 'error two SEQUENCE records found'

        self._crystals[crystal]['sequence'] = sequence

      # look for heavy atom information

      if 'BEGIN HA_INFO' in record:
        i += 1
        record = crystal_records[i]
        while record != 'END HA_INFO':
          key = record.split()[0].lower()
          value = record.split()[1]
          # things which are numbers are integers...
          if 'number' in key:
            value = int(value)
          self._crystals[crystal]['ha_info'][key] = value
          i += 1
          record = crystal_records[i]

      if 'BEGIN SAMPLE' in record:
        sample = record.replace('BEGIN SAMPLE ', '').strip()
        i += 1
        record = crystal_records[i]
        while not 'END SAMPLE' in record:
          i += 1
          record = crystal_records[i]
        self._crystals[crystal]['samples'][sample] = {}

      # look for wavelength definitions
      # FIXME need to check that there are not two wavelength
      # definitions with the same numerical value for the wavelength -
      # unless this is some way of handling RIP? maybe a NOFIXME.

      # look for data blocks

      if 'BEGIN CRYSTAL_DATA' in record:
        i += 1
        record = crystal_records[i]
        while not 'END CRYSTAL_DATA' in record:
          key = record.split()[0].lower()
          value = record.replace(record.split()[0], '').strip()
          self._crystals[crystal]['crystal_data'][key] = value
          i += 1
          record = crystal_records[i]

      if 'BEGIN WAVELENGTH ' in record:
        wavelength = record.replace('BEGIN WAVELENGTH ', '').strip()

        # check that this is a new wavelength definition
        if wavelength in self._crystals[crystal]['wavelengths']:
          raise RuntimeError, \
                'wavelength %s already exists for crystal %s' % \
                (wavelength, crystal)

        self._crystals[crystal]['wavelengths'][wavelength] = { }
        i += 1
        record = crystal_records[i]

        # populate this with interesting things
        while not 'END WAVELENGTH' in record:

          # deal with a nested WAVELENGTH_STATISTICS block

          if 'BEGIN WAVELENGTH_STATISTICS' in record:
            self._crystals[crystal]['wavelengths'][
                wavelength]['statistics'] = { }
            i += 1
            record = crystal_records[i]
            while not 'END WAVELENGTH_STATISTICS' in record:
              key, value = tuple(record.split())
              self._crystals[crystal]['wavelengths'][
                  wavelength]['statistics'][
                  key.lower()] = float(value)
              i += 1
              record = crystal_records[i]

          # else deal with the usual tokens

          key = record.split()[0].lower()

          if key == 'resolution':

            lst = record.split()

            if len(lst) < 2 or len(lst) > 3:
              raise RuntimeError, 'resolution dmin [dmax]'

            if len(lst) == 2:
              dmin = float(lst[1])

              self._crystals[crystal]['wavelengths'][
                  wavelength]['dmin'] = dmin

            else:
              dmin = min(map(float, lst[1:]))
              dmax = max(map(float, lst[1:]))

              self._crystals[crystal]['wavelengths'][
                  wavelength]['dmin'] = dmin

              self._crystals[crystal]['wavelengths'][
                  wavelength]['dmax'] = dmax

            i += 1
            record = crystal_records[i]
            continue

          if len(record.split()) == 1:
            raise RuntimeError, 'missing value for token %s' % \
                  record.split()[0]

          try:
            value = float(record.split()[1])
          except ValueError, e:
            value = record.replace(record.split()[0], '').strip()

          self._crystals[crystal]['wavelengths'][
              wavelength][key] = value
          i += 1
          record = crystal_records[i]

      # next look for sweeps, checking that the wavelength
      # definitions match up...

      if 'BEGIN SWEEP' in record:
        sweep = record.replace('BEGIN SWEEP', '').strip()

        if self._sweep_ids is not None and sweep.lower() not in self._sweep_ids:
          continue

        elif self._sweep_ranges is not None:
          start_end = self._sweep_ranges[self._sweep_ids.index(sweep.lower())]
        else:
          start_end = None

        if sweep in self._crystals[crystal]['sweeps']:
          raise RuntimeError, \
                'sweep %s already exists for crystal %s' % \
                (sweep, crystal)

        self._crystals[crystal]['sweeps'][sweep] = { }
        self._crystals[crystal]['sweeps'][sweep][
            'excluded_regions'] = []

        if start_end is not None:
          self._crystals[crystal]['sweeps'][sweep][
            'start_end'] = start_end

        # in here I expect to find IMAGE, DIRECTORY, WAVELENGTH
        # and optionally BEAM

        # FIXME 30/OCT/06 this may not be the case, for instance
        # if an INTEGRATED_REFLECTION_FILE record is in there...
        # c/f XProject.py, XSweep.py

        i += 1
        record = crystal_records[i]

        # populate this with interesting things
        while not 'END SWEEP' in record:
          # allow for WAVELENGTH_ID (bug # 2358)
          if 'WAVELENGTH_ID' == record.split()[0]:
            record = record.replace('WAVELENGTH_ID',
                                    'WAVELENGTH')

          if 'WAVELENGTH' == record.split()[0]:
            wavelength = record.replace('WAVELENGTH', '').strip()
            if not wavelength in self._crystals[crystal]['wavelengths'].keys():
              raise RuntimeError, \
                    'wavelength %s unknown for crystal %s' % \
                    (wavelength, crystal)
            self._crystals[crystal]['sweeps'][sweep]['wavelength'] = wavelength

          elif 'SAMPLE' == record.split()[0]:
            sample = record.replace('SAMPLE ', '').strip()
            if not sample in self._crystals[crystal]['samples'].keys():
              raise RuntimeError, \
                  'sample %s unknown for crystal %s' % (sample, crystal)
            self._crystals[crystal]['sweeps'][sweep]['sample'] = sample

          elif 'BEAM' == record.split()[0]:
            beam = map(float, record.split()[1:])
            self._crystals[crystal]['sweeps'][sweep]['beam'] = beam

          elif 'DISTANCE' == record.split()[0]:
            distance = float(record.split()[1])
            self._crystals[crystal]['sweeps'][sweep]['distance'] = distance

          elif 'EPOCH' == record.split()[0]:
            epoch = int(record.split()[1])
            self._crystals[crystal]['sweeps'][sweep]['epoch'] = epoch

          elif 'REVERSEPHI' == record.split()[0]:
            self._crystals[crystal]['sweeps'][sweep]['reversephi'] = True

          elif 'START_END' == record.split()[0]:
            if 'start_end' not in self._crystals[crystal]['sweeps'][sweep]:
              start_end = map(int, record.split()[1:])
              if len(start_end) != 2:
                raise RuntimeError, \
                      'START_END start end, not "%s"' % record
              self._crystals[crystal]['sweeps'][sweep]['start_end'] = start_end

          elif 'EXCLUDE' == record.split()[0]:
            if record.split()[1].upper() == 'ICE':
              self._crystals[crystal]['sweeps'][sweep]['ice'] = True
            else:
              excluded_region = map(float, record.split()[1:])
              if len(excluded_region) != 2:
                raise RuntimeError, \
                      'EXCLUDE upper lower, not "%s". \
                       eg. EXCLUDE 2.28 2.22' % record
              if excluded_region[0] <= excluded_region[1]:
                raise RuntimeError, \
                      'EXCLUDE upper lower, where upper \
                       must be greater than lower (not "%s").\n\
                       eg. EXCLUDE 2.28 2.22' % record
              self._crystals[crystal]['sweeps'][sweep]['excluded_regions'].append(
                excluded_region)

          else:
            key = record.split()[0]
            value = record.replace(key, '').strip()
            self._crystals[crystal]['sweeps'][sweep][key] = value

          i += 1
          record = crystal_records[i]

      # now look for one-record things

      if 'SCALED_MERGED_REFLECTION_FILE' in record:
        self._crystals[crystal][
            'scaled_merged_reflection_file'] = \
            record.replace('SCALED_MERGED_REFLECTION_FILE',
                           '').strip()

      if 'REFERENCE_REFLECTION_FILE' in record:
        self._crystals[crystal][
            'reference_reflection_file'] = \
            record.replace('REFERENCE_REFLECTION_FILE',
                           '').strip()

      if 'FREER_FILE' in record:

        # free file also needs to be used for indexing reference to
        # make any sense at all...

        self._crystals[crystal][
            'freer_file'] = record.replace('FREER_FILE', '').strip()
        self._crystals[crystal][
            'reference_reflection_file'] = \
            record.replace('FREER_FILE', '').strip()

      # user assigned spacegroup and cell constants
      if 'USER_SPACEGROUP' in record:
        self._crystals[crystal][
            'user_spacegroup'] = record.replace(
            'USER_SPACEGROUP', '').strip()

      if 'USER_CELL' in record:
        self._crystals[crystal][
            'user_cell'] = tuple(map(float, record.split()[1:]))