Esempio n. 1
0
class PickleSomeKind(ndb.Model):
    """Model whose entities are stored under the Datastore kind "SomeKind".

    Holds a single structured ``other`` property of type PickleOtherKind.
    """

    other = ndb.StructuredProperty(PickleOtherKind)

    @classmethod
    def _get_kind(cls):
        # Report "SomeKind" as the kind name instead of the class name.
        kind_name = "SomeKind"
        return kind_name
Esempio n. 2
0
 class SomeKind(ndb.Model):
     """Model with an integer ``foo`` and a repeated ``OtherKind`` list ``bar``."""

     # Scalar integer value.
     foo = ndb.IntegerProperty()
     # Zero or more structured OtherKind values.
     bar = ndb.StructuredProperty(OtherKind, repeated=True)
Esempio n. 3
0
class User(ndb.Model):
    """Datastore model for a website user, with authentication helpers.

    Besides profile fields, a user entity stores its login sessions and CSRF
    tokens as repeated structured properties. The classmethods below wrap
    their Datastore access in ``client.context()``.
    """

    # Maximum number of CSRF tokens stored per user (enforced by set_csrf_token).
    _MAX_CSRF_TOKENS = 10

    first_name = ndb.StringProperty()
    last_name = ndb.StringProperty()
    email_address = ndb.StringProperty()

    # auth
    email_address_verified = ndb.BooleanProperty(default=False)
    password_hash = ndb.StringProperty()  # bcrypt hash, see create()
    sessions = ndb.StructuredProperty(Session, repeated=True)
    # at most _MAX_CSRF_TOKENS CSRF tokens should be stored
    csrf_tokens = ndb.StructuredProperty(CSRFToken, repeated=True)

    # magic login link
    magic_link_token_hash = ndb.StringProperty()  # SHA-256 hex digest of the token
    magic_link_token_expired = ndb.DateTimeProperty()

    # status
    admin = ndb.BooleanProperty(default=False)
    # if user is suspended, they cannot log in
    suspended = ndb.BooleanProperty(default=False)

    # standard model fields
    created = ndb.DateTimeProperty(auto_now_add=True)
    updated = ndb.DateTimeProperty(auto_now=True)
    deleted = ndb.BooleanProperty(default=False)
    # note that if the object gets "un-deleted", this date will stay here
    deleted_date = ndb.DateTimeProperty()

    # properties (ordered by alphabet)
    @property
    def get_id(self):
        """Return the ID part of this entity's Datastore key."""
        return self.key.id()

    # class methods (ordered by alphabet)
    @classmethod
    def change_password(cls, user, new_password):
        """Replace ``user``'s password with a bcrypt hash of ``new_password``.

        Hashing mirrors create(): bcrypt with ``gensalt(12)``.

        :param user: User entity to update.
        :param new_password: new plaintext password; must be non-empty.
        :raises ValueError: if ``new_password`` is empty or None.
        :return: the updated user.
        """
        if not new_password:
            raise ValueError("New password must not be empty.")

        with client.context():
            user.password_hash = bcrypt.hashpw(password=str.encode(new_password),
                                               salt=bcrypt.gensalt(12))
            user.put()

        return user

    @classmethod
    def create(cls,
               email_address,
               password=None,
               admin=False,
               first_name=None,
               last_name=None):
        """Register a new user unless the email address is already taken.

        NOTE(review): the existence check and the put() are not transactional,
        so two concurrent calls could both create a user for the same address.

        :param email_address: email address to register (matched exactly).
        :param password: optional plaintext password, stored as a bcrypt hash.
        :param admin: whether the new user is an administrator.
        :param first_name: optional first name.
        :param last_name: optional last name.
        :return: tuple (success, user, message). On failure ``user`` is the
            already-registered user with this email address.
        """
        with client.context():
            existing = cls.query(cls.email_address == email_address).get()
            if existing:
                return False, existing, (
                    "User with this email address is already registered. Please go to the "
                    "Login page and try to log in.")

            hashed = None
            if password:
                hashed = bcrypt.hashpw(password=str.encode(password),
                                       salt=bcrypt.gensalt(12))

            user = cls(email_address=email_address,
                       password_hash=hashed,
                       admin=admin,
                       first_name=first_name,
                       last_name=last_name)
            user.put()

            return True, user, "Success"  # success, user, message

    @classmethod
    def delete_session(cls, user, token_hash_five_chars):
        """Remove every session whose token hash starts with the given prefix.

        NOTE(review): an empty prefix matches (and removes) all sessions.

        :param user: User entity to update.
        :param token_hash_five_chars: leading characters of the session token hash.
        :return: the updated user.
        """
        with client.context():
            user.sessions = [
                session for session in user.sessions
                if not session.token_hash.startswith(token_hash_five_chars)
            ]
            user.put()

        return user

    @classmethod
    def delete(cls, user, permanently=False):
        """Delete ``user``, either permanently or as a soft delete.

        :param user: User entity to delete.
        :param permanently: if True remove the entity from Datastore; otherwise
            only mark it as deleted and record ``deleted_date``.
        :return: always True.
        """
        with client.context():
            if permanently:
                user.key.delete()  # removes the entity from Datastore
            else:
                # Soft delete: the entity stays in Datastore, flagged as deleted.
                # permanently_batch_delete() removes it after 30 days.
                user.deleted = True
                user.deleted_date = datetime.datetime.now()
                user.put()

        return True

    @classmethod
    def fetch(cls,
              email_address_verified=True,
              suspended=False,
              deleted=False,
              limit=None,
              cursor=None):
        """Fetch one page of users matching the given status flags.

        :param email_address_verified: required value of ``email_address_verified``.
        :param suspended: required value of ``suspended``.
        :param deleted: required value of ``deleted``.
        :param limit: page size passed to ``fetch_page``.
        :param cursor: start cursor passed to ``fetch_page``.
        :return: tuple (users, next_cursor, more); ``next_cursor`` is a urlsafe
            string, or None when there is no next page.
        """
        with client.context():
            users, next_cursor, more = cls.query(
                cls.email_address_verified == email_address_verified,
                cls.suspended == suspended,
                cls.deleted == deleted).fetch_page(limit, start_cursor=cursor)

            # Locally, pagination can report more=True even when fewer users than
            # ``limit`` came back (or the cursor did not advance); treat a short
            # page as the last one.
            if is_local() and limit and len(users) < limit:
                return users, None, False

            try:
                return users, next_cursor.urlsafe().decode(), more
            except AttributeError:  # no next_cursor (None) -> no further pages
                return users, None, False

    @classmethod
    def generate_session_token(cls, user, request=None):
        """Create a 30-day session for ``user`` and return its plaintext token.

        Only the SHA-256 hex digest of the token is stored. Expired sessions are
        pruned from the user at the same time.

        :param user: User entity to add the session to.
        :param request: optional request object; when given, client IP, platform,
            browser, user agent and App Engine country header are recorded.
        :return: the plaintext session token (e.g. for a cookie).
        """
        with client.context():
            token = secrets.token_hex()
            token_hash = hashlib.sha256(str.encode(token)).hexdigest()
            now = datetime.datetime.now()

            session = Session(token_hash=token_hash,
                              expired=now + datetime.timedelta(days=30))
            if request:  # tests have no access to a "request" object
                session.ip = request.access_route[-1]
                session.platform = request.user_agent.platform
                session.browser = request.user_agent.browser
                session.user_agent = request.user_agent.string
                session.country = request.headers.get("X-AppEngine-Country")

            # New session first, followed by the user's still-valid sessions.
            user.sessions = [session] + [
                item for item in (user.sessions or []) if item.expired > now
            ]
            user.put()

            return token

    @classmethod
    def get_user_by_email(cls, email_address):
        """Return the first user with exactly this email address, or None."""
        with client.context():
            return cls.query(cls.email_address == email_address).get()

    @classmethod
    def get_user_by_password(cls, password_hash):
        """Return the first user whose stored ``password_hash`` equals the argument, or None.

        Note this is an exact equality match on the stored value; it does not
        verify a plaintext password against the bcrypt hash.
        """
        with client.context():
            return cls.query(cls.password_hash == password_hash).get()

    @classmethod
    def get_user_by_id(cls, user_id):
        """Return the user whose key ID is ``int(user_id)``, or None."""
        with client.context():
            return cls.get_by_id(int(user_id))

    @classmethod
    def get_user_by_session_token(cls, session_token):
        """Resolve a plaintext session token to an active, verified user.

        :param session_token: plaintext session token (e.g. from a cookie).
        :return: success boolean (True/False), user object (None on failure),
            message
        """
        with client.context():
            token_hash = hashlib.sha256(str.encode(session_token)).hexdigest()

            user = cls.query(cls.sessions.token_hash == token_hash).get()

            if not user:
                return False, None, "A user with this session token does not exist. Try to log in again."

            if user.deleted:
                logging.warning("Deleted user %s wanted to login.", user.email_address)
                return False, None, "This user has been deleted. Please contact website administrators for more info."

            if user.suspended:
                logging.warning("Suspended user %s wanted to login.", user.email_address)
                return False, None, "This user has been suspended. Please contact website administrators for more info."

            if not user.email_address_verified:
                logging.warning("User with unverified email address %s wanted to login.",
                                user.email_address)
                return False, None, "This user's email address hasn't yet been verified. Please contact website " \
                                    "administrators for more info."

            # Expiration can't be part of the query above: the filter would match
            # ANY session's expiry date, not that of the session with this hash.
            now = datetime.datetime.now()
            if any(session.token_hash == token_hash and session.expired > now
                   for session in user.sessions):
                return True, user, "Success"

            # Reached when the matching session has already expired.
            return False, None, "Unknown error."

    @classmethod
    def is_csrf_token_valid(cls, user, csrf_token):
        """Check and consume a CSRF token submitted for ``user``.

        A matching token is removed from the user (tokens are single-use) and
        counts as valid only if it has not expired.

        :param user: User entity that owns the tokens.
        :param csrf_token: token value received with the request.
        :return: True if ``csrf_token`` matches an unexpired stored token.
        """
        with client.context():
            now = datetime.datetime.now()
            matching = [csrf for csrf in user.csrf_tokens if csrf.token == csrf_token]
            if not matching:
                return False

            user.csrf_tokens = [csrf for csrf in user.csrf_tokens if csrf.token != csrf_token]
            user.put()

            # Expiry must be checked here too: set_csrf_token() only prunes
            # expired tokens when a new token is issued.
            return any(csrf.expired > now for csrf in matching)

    @classmethod
    def is_there_any_admin(cls):
        """Return True if at least one non-deleted admin user exists."""
        with client.context():
            # ``== True``/``== False`` build ndb query filters; ``is`` would not.
            admin = cls.query(cls.admin == True, cls.deleted == False).get()  # noqa: E712
            return admin is not None

    @classmethod
    def permanently_batch_delete(cls):
        """Permanently delete users soft-deleted more than 30 days ago.

        :return: always True.
        """
        with client.context():
            cutoff = datetime.datetime.now() - datetime.timedelta(days=30)
            users_keys = cls.query(cls.deleted == True,  # noqa: E712
                                   cls.deleted_date < cutoff).fetch(keys_only=True)

            ndb.delete_multi(keys=users_keys)
            return True

    @classmethod
    def send_magic_login_link(cls, email_address, locale="en"):
        """Email a 3-hour magic login link to the user with ``email_address``.

        Only the SHA-256 hex digest of the token is stored on the user.

        :param email_address: address of an existing user.
        :param locale: locale used to translate the email subject.
        :return: tuple (success, message).
        """
        token = secrets.token_hex()

        user = cls.get_user_by_email(email_address=email_address)
        if not user:
            return False, "User with this email is not registered yet!"

        with client.context():
            user.magic_link_token_hash = hashlib.sha256(str.encode(token)).hexdigest()
            user.magic_link_token_expired = datetime.datetime.now() + datetime.timedelta(hours=3)
            user.put()

            send_email(
                recipient_email=email_address,
                email_template="emails/login-magic-link.html",
                email_params={"magic_login_token": token},
                email_subject=get_translation(
                    locale=locale,
                    translation_function="magic_link_email_subject"))

        return True, "Success"

    @classmethod
    def set_csrf_token(cls, user):
        """Issue a new 8-hour CSRF token for ``user`` and return it.

        Expired tokens are pruned first. The tokens closest to expiry are then
        evicted so that, including the new one, at most ``_MAX_CSRF_TOKENS``
        are stored.

        :param user: User entity to update.
        :return: the new token string.
        """
        with client.context():
            now = datetime.datetime.now()
            valid_tokens = [csrf for csrf in user.csrf_tokens if csrf.expired > now]

            # Loop rather than evicting once, so a list that already exceeds the
            # cap is also brought back within it.
            while len(valid_tokens) >= cls._MAX_CSRF_TOKENS:
                valid_tokens.remove(min(valid_tokens, key=attrgetter("expired")))

            token = secrets.token_hex()
            valid_tokens.append(CSRFToken(token=token,
                                          expired=now + datetime.timedelta(hours=8)))

            user.csrf_tokens = valid_tokens
            user.put()

            return token

    @classmethod
    def validate_magic_login_token(cls, magic_token, request=None):
        """Exchange a magic-link token for a new session token.

        On success the user's email address is marked as verified and the magic
        token is expired, so it cannot be used again.

        :param magic_token: plaintext token from the magic link.
        :param request: optional request passed on to generate_session_token().
        :return: (True, session_token) on success, (False, message) otherwise.
        """
        with client.context():
            magic_link_token_hash = hashlib.sha256(str.encode(magic_token)).hexdigest()
            user = cls.query(cls.magic_link_token_hash == magic_link_token_hash).get()

            now = datetime.datetime.now()
            expires = user.magic_link_token_expired if user else None
            # A missing expiry date is treated as invalid rather than raising TypeError.
            if expires is None or expires <= now:
                return False, "The magic link is not valid or is expired. Please request a new one."

            user.email_address_verified = True
            user.magic_link_token_expired = now  # make the token expired
            user.put()

        # Must run outside "with client.context()": generate_session_token()
        # opens its own context.
        session_token = cls.generate_session_token(user=user, request=request)

        # session token is stored into a cookie by the handler
        return True, session_token

    # METHODS FOR TESTING PURPOSES ONLY!
    @classmethod
    def _test_mark_email_verified(cls, user):
        """
        FOR TESTING PURPOSES ONLY! Mark ``user``'s email as verified (only when is_local()).
        :param user: User entity to update.
        :return: None
        """
        with client.context():
            if is_local():
                user.email_address_verified = True
                user.put()

    @classmethod
    def _test_change_deleted_date(cls, user, new_date):
        """
        FOR TESTING PURPOSES ONLY! Overwrite ``user.deleted_date`` (only when is_local()).
        :param user: User entity to update.
        :param new_date: datetime to store as ``deleted_date``.
        :return: None
        """
        with client.context():
            if is_local():
                user.deleted_date = new_date
                user.put()

    @classmethod
    def login_password(cls, password):
        """Look up a user by ``password`` and store a fresh magic-link token hash.

        NOTE(review): ``password`` is matched for equality against the stored
        ``password_hash`` (see get_user_by_password), while create() stores a
        salted bcrypt hash, so a plaintext password presumably never matches.
        The generated token is never returned, and the call overwrites any
        pending magic-link token. Consider a lookup by email plus bcrypt.checkpw.

        :param password: value compared against ``password_hash``.
        :return: tuple (success, message).
        """
        token = secrets.token_hex()

        user = cls.get_user_by_password(password_hash=password)
        if not user:
            return False, "User with this password is not registered yet!"

        with client.context():
            user.magic_link_token_hash = hashlib.sha256(str.encode(token)).hexdigest()
            user.magic_link_token_expired = datetime.datetime.now() + datetime.timedelta(hours=3)
            user.put()
            return True, "Success"

    @classmethod
    def post(cls):
        """Password-change entry point; not implemented, use change_password().

        :raises NotImplementedError: always.
        """
        # The former body used ``client.context`` without calling it, so it always
        # crashed. Had it run, it would have replaced the class-level
        # ``User.password_hash`` property with an entity. Fail explicitly instead.
        raise NotImplementedError(
            "User.post() is not implemented; use User.change_password(user, new_password).")
Esempio n. 4
0
class Bug(ndb.Model):
  """Bug entity.

  Datastore model for a vulnerability. update_from_vulnerability() and
  to_vulnerability() convert to and from ``vulnerability_pb2.Vulnerability``.
  """
  OSV_ID_PREFIX = 'OSV-'

  # Status of the bug.
  status = ndb.IntegerProperty()
  # Timestamp when Bug was allocated.
  timestamp = ndb.DateTimeProperty()
  # When the entry was last edited.
  last_modified = ndb.DateTimeProperty()
  # The source identifier.
  # For OSS-Fuzz, this oss-fuzz:<ClusterFuzz testcase ID>.
  # For others this is <source>:<path/to/source>.
  source_id = ndb.StringProperty()
  # Main repo url.
  repo_url = ndb.StringProperty()
  # The main fixed commit.
  fixed = ndb.StringProperty()
  # The main regressing commit.
  regressed = ndb.StringProperty()
  # Additional affected commit ranges derived from the main fixed and regressed
  # commits.
  additional_commit_ranges = ndb.StructuredProperty(CommitRange, repeated=True)
  # List of affected versions.
  affected = ndb.StringProperty(repeated=True)
  # List of normalized versions for fuzzy matching.
  affected_fuzzy = ndb.StringProperty(repeated=True)
  # OSS-Fuzz issue ID.
  issue_id = ndb.StringProperty()
  # Project/package name for the bug.
  project = ndb.StringProperty()
  # Package ecosystem for the project.
  ecosystem = ndb.StringProperty()
  # Summary for the bug.
  summary = ndb.StringProperty()
  # Vulnerability details.
  details = ndb.StringProperty()
  # Severity of the bug.
  severity = ndb.StringProperty(validator=_check_valid_severity)
  # Whether or not the bug is public (OSS-Fuzz only).
  public = ndb.BooleanProperty()
  # Reference URLs.
  reference_urls = ndb.StringProperty(repeated=True)
  # Search indices (auto-populated)
  search_indices = ndb.StringProperty(repeated=True)
  # Whether or not the bug has any affected versions (auto-populated).
  has_affected = ndb.BooleanProperty()
  # Sort key.
  sort_key = ndb.StringProperty()
  # Source of truth for this Bug.
  source_of_truth = ndb.IntegerProperty(default=SourceOfTruth.INTERNAL)

  def id(self):
    """Get the bug ID.

    Key IDs that start with a digit get the ``OSV-`` prefix re-added, because
    OSV-allocated IDs are stored without it (see get_by_id).
    """
    if re.match(r'^\d+', self.key.id()):
      return self.OSV_ID_PREFIX + self.key.id()

    return self.key.id()

  @classmethod
  def get_by_id(cls, vuln_id, *args, **kwargs):
    """Overridden get_by_id to handle OSV allocated IDs."""
    # OSV allocated bug IDs are stored without the prefix.
    if vuln_id.startswith(cls.OSV_ID_PREFIX):
      vuln_id = vuln_id[len(cls.OSV_ID_PREFIX):]

    return super().get_by_id(vuln_id, *args, **kwargs)

  def _pre_put_hook(self):
    """Pre-put hook for populating derived and search fields.

    Rebuilds ``search_indices`` (project, full key ID and its '-'-separated
    parts), ``has_affected``, ``affected_fuzzy`` and ``sort_key``. Sets
    ``last_modified`` via utcnow() when it is unset.
    """
    key_id = self.key.id()
    key_parts = key_id.split('-')

    search_indices = []
    if self.project:
      search_indices.append(self.project)
    search_indices.append(key_id)
    search_indices.extend(key_parts)
    self.search_indices = search_indices

    self.has_affected = bool(self.affected)
    self.affected_fuzzy = bug.normalize_tags(self.affected)

    if len(key_parts) > 1:
      # Zero-pad the second ID component to 7 characters so numeric
      # components of up to 7 digits sort numerically as strings.
      self.sort_key = key_parts[0] + '-' + key_parts[1].zfill(7)
    else:
      # An ID without '-' has no second component to pad (indexing
      # key_parts[1] used to raise IndexError); sort by the raw ID.
      self.sort_key = key_id

    if not self.last_modified:
      self.last_modified = utcnow()

  def update_from_vulnerability(self, vulnerability):
    """Set fields from vulnerability.

    The first GIT range fills ``regressed``/``fixed``/``repo_url``. Every
    later GIT range becomes an ``additional_commit_ranges`` entry. Non-GIT
    ranges are ignored.
    """
    self.summary = vulnerability.summary
    self.details = vulnerability.details
    self.severity = (
        vulnerability_pb2.Vulnerability.Severity.Name(vulnerability.severity))
    self.reference_urls = list(vulnerability.references)
    self.last_modified = vulnerability.modified.ToDatetime()
    self.project = vulnerability.package.name
    self.ecosystem = vulnerability.package.ecosystem
    self.affected = list(vulnerability.affects.versions)

    # Rebuild from scratch: appending to the existing list made repeated
    # updates accumulate duplicate ranges.
    self.additional_commit_ranges = []
    found_first = False
    for affected_range in vulnerability.affects.ranges:
      if affected_range.type != vulnerability_pb2.AffectedRange.Type.GIT:
        continue

      if found_first:
        self.additional_commit_ranges.append(
            CommitRange(
                introduced_in=affected_range.introduced,
                fixed_in=affected_range.fixed))
      else:
        self.regressed = affected_range.introduced
        self.fixed = affected_range.fixed
        self.repo_url = affected_range.repo
        found_first = True

  def to_vulnerability(self):
    """Convert to Vulnerability proto.

    INVALID bugs are emitted with no ``affects``, details 'INVALID' and
    severity NONE. ``created``/``modified`` are left unset when the
    corresponding datetime is missing.
    """
    package = vulnerability_pb2.Package(
        name=self.project, ecosystem=self.ecosystem)

    affects = vulnerability_pb2.Affects(versions=self.affected)
    affects.ranges.add(
        type=vulnerability_pb2.AffectedRange.Type.GIT,
        repo=self.repo_url,
        introduced=self.regressed,
        fixed=self.fixed)
    for additional_range in self.additional_commit_ranges:
      affects.ranges.add(
          type=vulnerability_pb2.AffectedRange.Type.GIT,
          repo=self.repo_url,
          introduced=additional_range.introduced_in,
          fixed=additional_range.fixed_in)

    if self.severity:
      severity = vulnerability_pb2.Vulnerability.Severity.Value(self.severity)
    else:
      severity = vulnerability_pb2.Vulnerability.Severity.NONE

    details = self.details
    if self.status == bug.BugStatus.INVALID:
      affects = None
      details = 'INVALID'
      severity = vulnerability_pb2.Vulnerability.Severity.NONE

    if self.last_modified:
      modified = timestamp_pb2.Timestamp()
      modified.FromDatetime(self.last_modified)
    else:
      modified = None

    if self.timestamp:
      created = timestamp_pb2.Timestamp()
      created.FromDatetime(self.timestamp)
    else:
      # Mirror ``modified``: leave the field unset instead of crashing on None.
      created = None

    result = vulnerability_pb2.Vulnerability(
        id=self.id(),
        created=created,
        modified=modified,
        summary=self.summary,
        details=details,
        package=package,
        severity=severity,
        affects=affects,
        references=self.reference_urls)

    return result
Esempio n. 5
0
class Bug(ndb.Model):
    """Bug entity."""
    OSV_ID_PREFIX = 'OSV-'
    # Very large fake version to use when there is no fix available.
    _NOT_FIXED_SEMVER = '999999.999999.999999'

    # Display ID as used by the source database. The full qualified database that
    # OSV tracks this as may be different.
    db_id = ndb.StringProperty()
    # Other IDs this bug is known as.
    aliases = ndb.StringProperty(repeated=True)
    # Related IDs.
    related = ndb.StringProperty(repeated=True)
    # Status of the bug.
    status = ndb.IntegerProperty()
    # Timestamp when Bug was allocated.
    timestamp = ndb.DateTimeProperty()
    # When the entry was last edited.
    last_modified = ndb.DateTimeProperty()
    # When the entry was withdrawn.
    withdrawn = ndb.DateTimeProperty()
    # The source identifier.
    # For OSS-Fuzz, this oss-fuzz:<ClusterFuzz testcase ID>.
    # For others this is <source>:<path/to/source>.
    source_id = ndb.StringProperty()
    # The main fixed commit (from bisection).
    fixed = ndb.StringProperty(default='')
    # The main regressing commit (from bisection).
    regressed = ndb.StringProperty(default='')
    # All affected ranges. TODO(ochang): To be removed.
    affected_ranges = ndb.StructuredProperty(AffectedRange, repeated=True)
    # List of affected versions.
    affected = ndb.TextProperty(repeated=True)
    # List of normalized versions indexed for fuzzy matching.
    affected_fuzzy = ndb.StringProperty(repeated=True)
    # OSS-Fuzz issue ID.
    issue_id = ndb.StringProperty()
    # Package URL for this package.
    purl = ndb.StringProperty(repeated=True)
    # Project/package name for the bug.
    project = ndb.StringProperty(repeated=True)
    # Package ecosystem for the project.
    ecosystem = ndb.StringProperty(repeated=True)
    # Summary for the bug.
    summary = ndb.TextProperty()
    # Vulnerability details.
    details = ndb.TextProperty()
    # Severity of the bug.
    severity = ndb.StringProperty(validator=_check_valid_severity)
    # Whether or not the bug is public (OSS-Fuzz only).
    public = ndb.BooleanProperty()
    # Reference URL types (dict of url -> type).
    reference_url_types = ndb.JsonProperty()
    # Search indices (auto-populated)
    search_indices = ndb.StringProperty(repeated=True)
    # Whether or not the bug has any affected versions (auto-populated).
    has_affected = ndb.BooleanProperty()
    # Source of truth for this Bug.
    source_of_truth = ndb.IntegerProperty(default=SourceOfTruth.INTERNAL)
    # Whether the bug is fixed (indexed for querying).
    is_fixed = ndb.BooleanProperty()
    # Database specific.
    database_specific = ndb.JsonProperty()
    # Ecosystem specific.
    ecosystem_specific = ndb.JsonProperty()
    # Normalized SEMVER fixed indexes for querying.
    semver_fixed_indexes = ndb.StringProperty(repeated=True)
    # Affected packages and versions.
    affected_packages = ndb.LocalStructuredProperty(AffectedPackage,
                                                    repeated=True)
    # The source of this Bug.
    source = ndb.StringProperty()

    def id(self):
        """Return this bug's public ID.

        ``db_id`` wins when set. Otherwise the datastore key ID is used, with
        the ``OSV-`` prefix re-added when the key ID starts with a digit.
        """
        if self.db_id:
            return self.db_id

        # TODO(ochang): Remove once all existing bugs have IDs migrated.
        key_id = self.key.id()
        needs_prefix = re.match(r'^\d+', key_id) is not None
        return self.OSV_ID_PREFIX + key_id if needs_prefix else key_id

    @property
    def repo_url(self):
        """First non-empty ``repo_url`` found across all affected ranges, else None."""
        candidate_urls = (
            affected_range.repo_url
            for affected_package in self.affected_packages
            for affected_range in affected_package.ranges
        )
        return next((url for url in candidate_urls if url), None)

    @classmethod
    def get_by_id(cls, vuln_id, *args, **kwargs):
        """Look up a bug by its public ID (overrides ndb's get_by_id).

        An exact ``db_id`` match is tried first. Failing that, the ``OSV-``
        prefix is stripped and a regular key lookup is done.
        """
        by_db_id = cls.query(cls.db_id == vuln_id).get()
        if by_db_id:
            return by_db_id

        # TODO(ochang): Remove once all existing bugs have IDs migrated.
        key_id = vuln_id
        if key_id.startswith(cls.OSV_ID_PREFIX):
            key_id = key_id[len(cls.OSV_ID_PREFIX):]

        return super().get_by_id(key_id, *args, **kwargs)

    def _tokenize(self, value):
        """Tokenize ``value`` for search indexing.

        Returns the lowercased value split on runs of non-word characters,
        followed by the whole lowercased value. Empty tokens, which ``re.split``
        produces for leading or trailing separators (e.g. ``'@angular/core'``),
        are dropped. Returns ``[]`` for empty or None values.
        """
        if not value:
            return []

        value_lower = value.lower()
        tokens = [token for token in re.split(r'\W+', value_lower) if token]
        return tokens + [value_lower]

    def _pre_put_hook(self):
        """Pre-put hook for populating search indices and other derived fields.

        Before each put this:
          * rebuilds ``search_indices`` from the tokenized bug ID and, when
            ``affected_packages`` is non-empty, the package names/ecosystems;
          * recomputes ``project``/``ecosystem``/``purl`` from
            ``affected_packages`` (only when that list is non-empty; otherwise
            their previous values are left untouched);
          * recomputes ``affected_fuzzy``, ``semver_fixed_indexes``,
            ``has_affected`` and ``is_fixed``;
          * sets ``last_modified`` via ``utcnow()`` if unset, and derives
            ``source`` from ``source_id`` when ``source_id`` is set;
          * assigns a key if the entity does not have one yet.

        Raises:
          ValueError: if no source or DB ID is set, or if
            get_source_repository() finds nothing for the source.
        """
        search_indices = set()

        # NOTE(review): self.id() runs before the db_id check below; with no
        # db_id and no key it raises AttributeError instead of the ValueError.
        search_indices.update(self._tokenize(self.id()))

        if self.affected_packages:
            # Denormalized per-package fields, skipping empty values.
            self.project = [
                pkg.package.name for pkg in self.affected_packages
                if pkg.package.name
            ]
            self.ecosystem = [
                pkg.package.ecosystem for pkg in self.affected_packages
                if pkg.package.ecosystem
            ]
            self.purl = [
                pkg.package.purl for pkg in self.affected_packages
                if pkg.package.purl
            ]

            for project in self.project:
                search_indices.update(self._tokenize(project))

            for ecosystem in self.ecosystem:
                search_indices.update(self._tokenize(ecosystem))

        # Sorted so the stored list is deterministic.
        self.search_indices = sorted(list(search_indices))

        # Reset derived indexes; they are recomputed from scratch below.
        self.affected_fuzzy = []
        self.semver_fixed_indexes = []
        self.has_affected = False
        self.is_fixed = False

        for affected_package in self.affected_packages:
            # Indexes used for querying by exact version.
            self.affected_fuzzy.extend(
                bug.normalize_tags(affected_package.versions))
            self.has_affected |= bool(affected_package.versions)

            for affected_range in affected_package.ranges:
                fixed_version = None
                for event in affected_range.events:
                    # Index used to query by fixed/unfixed.
                    # If several 'fixed' events exist, the last one is kept.
                    if event.type == 'fixed':
                        self.is_fixed = True
                        fixed_version = event.value

                if affected_range.type == 'SEMVER':
                    # Indexes used for querying by semver. Unfixed ranges use a
                    # very large placeholder version.
                    fixed = fixed_version or self._NOT_FIXED_SEMVER
                    self.semver_fixed_indexes.append(
                        semver_index.normalize(fixed))

                # SEMVER/ECOSYSTEM ranges count as affected even without an
                # explicit version list.
                self.has_affected |= (affected_range.type
                                      in ('SEMVER', 'ECOSYSTEM'))

        if not self.last_modified:
            self.last_modified = utcnow()

        if self.source_id:
            self.source, _ = sources.parse_source_id(self.source_id)

        if not self.source:
            raise ValueError('Source not specified for Bug.')

        if not self.db_id:
            raise ValueError('DB ID not specified for Bug.')

        if not self.key:  # pylint: disable=access-member-before-definition
            source_repo = get_source_repository(self.source)
            if not source_repo:
                raise ValueError(f'Invalid source {self.source}')

            # Use db_id directly when it already carries the repository's
            # db_prefix; otherwise namespace it as '<source>:<db_id>'.
            if source_repo.db_prefix and self.db_id.startswith(
                    source_repo.db_prefix):
                key_id = self.db_id
            else:
                key_id = f'{self.source}:{self.db_id}'

            self.key = ndb.Key(Bug, key_id)

    def _update_from_pre_0_8(self, vulnerability):
        """Populate ``affected_packages`` from a pre-0.8 schema vulnerability.

        Only the first affected package is rewritten; one is created when the
        list is empty. Its ranges are rebuilt by grouping the legacy
        ``affects.ranges`` pairs by (range type name, repo) into flat,
        de-duplicated event lists.
        """
        packages = self.affected_packages
        if packages:
            target = packages[0]
        else:
            target = AffectedPackage()
            packages.append(target)

        legacy_package = vulnerability.package
        target.package = Package(name=legacy_package.name,
                                 ecosystem=legacy_package.ecosystem,
                                 purl=legacy_package.purl)

        as_dict = sources.vulnerability_to_dict(vulnerability)
        for specific_field in ('database_specific', 'ecosystem_specific'):
            if getattr(vulnerability, specific_field):
                setattr(target, specific_field, as_dict[specific_field])

        target.versions = list(vulnerability.affects.versions)
        target.ranges = []

        def add_once(bucket, event):
            # Keep only the first occurrence of each event, preserving order.
            if event not in bucket:
                bucket.append(event)

        grouped_events = {}
        for legacy_range in vulnerability.affects.ranges:
            group_key = (
                vulnerability_pb2.AffectedRange.Type.Name(legacy_range.type),
                legacy_range.repo)
            bucket = grouped_events.setdefault(group_key, [])

            # 0.7 allowed an empty 'introduced'; 0.8 spells that as '0'.
            add_once(bucket,
                     AffectedEvent(type='introduced',
                                   value=legacy_range.introduced or '0'))
            if legacy_range.fixed:
                add_once(bucket,
                         AffectedEvent(type='fixed', value=legacy_range.fixed))

        for (range_type, repo_url), bucket in grouped_events.items():
            new_range = AffectedRange2(type=range_type, events=bucket)
            if range_type == 'GIT' and repo_url:
                new_range.repo_url = repo_url
            target.ranges.append(new_range)

    def update_from_vulnerability(self, vulnerability):
        """Copy fields from a ``vulnerability_pb2.Vulnerability`` onto this bug.

        The bug ID is not set. ``last_modified``, ``timestamp`` and
        ``withdrawn`` are only overwritten when the matching proto field is
        present. Vulnerabilities with no ``affected`` entries are delegated to
        the pre-0.8 converter. Otherwise ``affected_packages`` is rebuilt from
        scratch.
        """
        self.summary = vulnerability.summary
        self.details = vulnerability.details
        self.reference_url_types = {
            reference.url: vulnerability_pb2.Reference.Type.Name(reference.type)
            for reference in vulnerability.references
        }

        timestamp_fields = (('modified', 'last_modified'),
                            ('published', 'timestamp'),
                            ('withdrawn', 'withdrawn'))
        for proto_field, bug_field in timestamp_fields:
            if vulnerability.HasField(proto_field):
                setattr(self, bug_field,
                        getattr(vulnerability, proto_field).ToDatetime())

        self.aliases = list(vulnerability.aliases)
        self.related = list(vulnerability.related)

        if not vulnerability.affected:
            self._update_from_pre_0_8(vulnerability)
            return

        self.affected_packages = []
        for source_package in vulnerability.affected:
            current = AffectedPackage()
            current.package = Package(
                name=source_package.package.name,
                ecosystem=source_package.package.ecosystem,
                purl=source_package.package.purl)
            current.ranges = []

            for source_range in source_package.ranges:
                current_range = AffectedRange2(
                    type=vulnerability_pb2.Range.Type.Name(source_range.type),
                    repo_url=source_range.repo,
                    events=[])

                for evt in source_range.events:
                    # Only the first non-empty of introduced/fixed/limit is kept.
                    for event_type in ('introduced', 'fixed', 'limit'):
                        event_value = getattr(evt, event_type)
                        if event_value:
                            current_range.events.append(
                                AffectedEvent(type=event_type, value=event_value))
                            break

                current.ranges.append(current_range)

            current.versions = list(source_package.versions)
            for specific_field in ('database_specific', 'ecosystem_specific'):
                specific_struct = getattr(source_package, specific_field)
                if specific_struct:
                    setattr(current, specific_field,
                            json_format.MessageToDict(
                                specific_struct,
                                preserving_proto_field_name=True))

            self.affected_packages.append(current)

    def _get_pre_0_8_affects(self):
        """Build the pre-0.8 schema ``Affects`` message for this entity.

        Only the first entry of ``self.affected_packages`` is used. For each of
        its ranges, the flattened events are sorted with ``sorted_events`` and
        regrouped into ``[introduced, fixed)`` pairs:

        * An introduced value of '0' is emitted as an empty string.
        * For non-GIT ranges, an "introduced" that arrives while another one is
          still unmatched is ignored, and a "fixed" with no pending
          "introduced" is dropped.
        * For GIT ranges, every "introduced" is kept: a pending one is closed
          with an empty "fixed" before the next starts, and a "fixed" with no
          pending "introduced" still produces a range.
        * A trailing unmatched "introduced" yields a range with an empty
          "fixed".

        Event types other than "introduced" and "fixed" are not represented.

        Returns:
          A ``vulnerability_pb2.Affects`` message.
        """
        # TODO(ochang): Remove this once all consumers are migrated.
        package_entry = self.affected_packages[0]
        affects = vulnerability_pb2.Affects(versions=package_entry.versions)

        def build_range(range_entry, introduced, fixed):
            # Convert one [introduced, fixed) pair of `range_entry` into a
            # pre-0.8 AffectedRange message.
            return vulnerability_pb2.AffectedRange(
                type=vulnerability_pb2.AffectedRange.Type.Value(
                    range_entry.type),
                repo=range_entry.repo_url,
                introduced=introduced,
                fixed=fixed)

        for range_entry in package_entry.ranges:
            is_git = range_entry.type == 'GIT'
            pending_introduced = None

            for event in sorted_events(package_entry.package.ecosystem,
                                       range_entry.type, range_entry.events):
                if event.type == 'introduced':
                    if is_git and pending_introduced is not None:
                        # GIT keeps every "introduced", even overlapping ones.
                        affects.ranges.append(
                            build_range(range_entry, pending_introduced, ''))
                        pending_introduced = None

                    if pending_introduced is None:
                        # Otherwise an overlapping "introduced" is redundant.
                        pending_introduced = ('' if event.value == '0' else
                                              event.value)
                elif event.type == 'fixed':
                    if pending_introduced is None and not is_git:
                        # No prior "introduced": skip this invalid entry.
                        continue

                    affects.ranges.append(
                        build_range(range_entry, pending_introduced,
                                    event.value))
                    pending_introduced = None

            if pending_introduced is not None:
                affects.ranges.append(
                    build_range(range_entry, pending_introduced, ''))

        return affects

    def to_vulnerability(self, include_source=False, v0_7=True, v0_8=False):
        """Convert to Vulnerability proto.

        Args:
          include_source: When True and the entity's source repository has a
            ``link``, a ``source`` URL is added to ``database_specific``.
          v0_7: Fill the pre-0.8 schema fields (``package``, ``affects`` and
            the top-level ``ecosystem_specific``/``database_specific``) from
            the first affected package.
          v0_8: Fill ``affected`` with one entry per affected package.

        Returns:
          A ``vulnerability_pb2.Vulnerability``. For INVALID bugs, ``affects``
          and ``affected`` are dropped and ``details`` is set to 'INVALID'.
          ``modified``, ``withdrawn`` and ``published`` are left unset when
          the corresponding entity field is empty.
        """

        def to_timestamp(value):
            # Return a Timestamp proto for `value`, or None when it is unset
            # (Timestamp.FromDatetime does not accept None).
            if not value:
                return None
            converted = timestamp_pb2.Timestamp()
            converted.FromDatetime(value)
            return converted

        package = None
        ecosystem_specific = None
        database_specific = None
        affected = []
        affects = None

        source_link = None
        if self.source and include_source:
            source_repo = get_source_repository(self.source)
            if source_repo and source_repo.link:
                source_link = source_repo.link + sources.source_path(
                    source_repo, self)

        def to_affected(affected_package):
            # Convert one AffectedPackage entity into a 0.8 schema Affected.
            ranges = []
            for affected_range in affected_package.ranges:
                # Each stored event sets the Event field named by event.type.
                events = [
                    vulnerability_pb2.Event(**{event.type: event.value})
                    for event in affected_range.events
                ]
                ranges.append(
                    vulnerability_pb2.Range(
                        type=vulnerability_pb2.Range.Type.Value(
                            affected_range.type),
                        repo=affected_range.repo_url,
                        events=events))

            current = vulnerability_pb2.Affected(
                package=vulnerability_pb2.Package(
                    name=affected_package.package.name,
                    ecosystem=affected_package.package.ecosystem,
                    purl=affected_package.package.purl),
                ranges=ranges,
                versions=affected_package.versions)

            if affected_package.database_specific:
                current.database_specific.update(
                    affected_package.database_specific)

            if source_link:
                current.database_specific.update({'source': source_link})

            if affected_package.ecosystem_specific:
                current.ecosystem_specific.update(
                    affected_package.ecosystem_specific)

            return current

        if self.affected_packages:
            if v0_7:
                # The pre-0.8 schema only supports a single package, so we take
                # the first.
                first_package = self.affected_packages[0]

                package = vulnerability_pb2.Package(
                    name=first_package.package.name,
                    ecosystem=first_package.package.ecosystem,
                    purl=first_package.package.purl)

                affects = self._get_pre_0_8_affects()
                if first_package.ecosystem_specific:
                    ecosystem_specific = first_package.ecosystem_specific
                if first_package.database_specific:
                    database_specific = first_package.database_specific

            if v0_8:
                affected = [
                    to_affected(affected_package)
                    for affected_package in self.affected_packages
                ]

        details = self.details
        if self.status == bug.BugStatus.INVALID:
            affects = None
            affected = None
            details = 'INVALID'

        modified = to_timestamp(self.last_modified)
        withdrawn = to_timestamp(self.withdrawn)
        # Guarded like modified/withdrawn so entities without a timestamp
        # still convert instead of failing in FromDatetime.
        published = to_timestamp(self.timestamp)

        references = [
            vulnerability_pb2.Reference(
                url=url, type=vulnerability_pb2.Reference.Type.Value(url_type))
            for url, url_type in (self.reference_url_types or {}).items()
        ]

        result = vulnerability_pb2.Vulnerability(id=self.id(),
                                                 published=published,
                                                 modified=modified,
                                                 aliases=self.aliases,
                                                 related=self.related,
                                                 withdrawn=withdrawn,
                                                 summary=self.summary,
                                                 details=details,
                                                 package=package,
                                                 affects=affects,
                                                 affected=affected,
                                                 references=references)

        if ecosystem_specific:
            result.ecosystem_specific.update(ecosystem_specific)

        if database_specific:
            result.database_specific.update(database_specific)

        if source_link:
            result.database_specific.update({'source': source_link})

        return result
Esempio n. 6
0
 class SomeKind(ndb.Model):
     """NDB model with one ``entry`` property holding a structured OtherKind."""

     entry = ndb.StructuredProperty(OtherKind)
Esempio n. 7
0
 class SomeKind(ndb.Model):
     """NDB model with one ``sub_model`` property holding a structured OtherKind."""

     sub_model = ndb.StructuredProperty(OtherKind)