Beispiel #1
0
def _UpdateHostConfigByInventoryData(lab_name, data):
  """Syncs the inventory_groups of stored HostConfig entities.

  A HostConfig is built for every host in the inventory data. Hosts without a
  stored entity get the freshly built one written; hosts whose stored
  inventory_groups differ get the stored entity updated in place. Everything
  else is left alone.

  Args:
    lab_name: the lab name.
    data: the InventoryData object.
  """
  configs_from_inventory = {}
  for inventory_host in data.hosts.values():
    configs_from_inventory[inventory_host.name] = (
        _CreateHostConfigEntityFromHostInventory(lab_name, inventory_host))

  config_keys = [
      ndb.Key(datastore_entities.HostConfig, name)
      for name in configs_from_inventory
  ]
  # Missing entities come back falsy from get_multi and are dropped here.
  stored_configs = {
      stored.hostname: stored
      for stored in ndb.get_multi(config_keys)
      if stored
  }

  pending = {}
  for name, fresh in configs_from_inventory.items():
    existing = stored_configs.get(name)
    if not existing:
      logging.debug('Creating host config %s', fresh)
      pending[name] = fresh
      continue
    if existing.inventory_groups != fresh.inventory_groups:
      existing.inventory_groups = fresh.inventory_groups
      logging.debug('Updating host config %s', existing)
      pending[name] = existing
  ndb.put_multi(pending.values())
    def _CheckAccessibilityForHost(self, user_name, host_config, host_account):
        """Decides whether a user may log in to a host account on a host.

        Access is granted as soon as any existing host group of the host
        (including the implicit "all" group) allows it.

        Args:
          user_name: the user name.
          host_config: the HostConfig entity for the host that user requested
            to access.
          host_account: the host account that the user want to login.

        Returns:
          the AclCheckResult api message.
        """
        # every host belong to the "all" group in ansible, thus we should also
        # check accessibility for the "all" group.
        # https://docs.ansible.com/ansible/latest/user_guide/intro_inventory.html#default-groups
        group_names = ["all"] + host_config.inventory_groups
        group_keys = [
            ndb.Key(
                datastore_entities.HostGroupConfig,
                datastore_entities.HostGroupConfig.CreateId(
                    host_config.lab_name, name))
            for name in group_names
        ]
        # Groups without a stored config are skipped; any() stops at the first
        # group that grants access.
        has_access = any(
            self._CheckAccessibilityForHostAccountInHostGroup(
                user_name, host_account, group_config)
            for group_config in ndb.get_multi(group_keys)
            if group_config)
        return api_messages.AclCheckResult(has_access=has_access)
Beispiel #3
0
  def BatchDeleteNotes(self, request):
    """Deletes notes of a host.

    Every requested note is validated before anything is deleted, so a request
    containing a bad id deletes nothing.

    Args:
      request: an API request.
    Request Params:
      hostname: string, the name of a lab host.
      ids: a list of strings, the ids of notes to delete.

    Returns:
      a message_types.VoidMessage object.

    Raises:
      endpoints.BadRequestException: when a requested id has no stored note or
        the stored note belongs to a different host.
    """
    keys = [
        ndb.Key(datastore_entities.Note, entity_id)
        for entity_id in request.ids
    ]
    note_entities = ndb.get_multi(keys)
    for key, note_entity in zip(keys, note_entities):
      if not note_entity or note_entity.hostname != request.hostname:
        # note_entity is None when the id does not exist, so the message must
        # be built from the requested hostname rather than from the entity.
        raise endpoints.BadRequestException(
            "Note<id:{0}> does not exist under host<{1}>.".format(
                key.id(), request.hostname))
    # Delete in one batch instead of issuing one call per key.
    ndb.delete_multi(keys)
    return message_types.VoidMessage()
Beispiel #4
0
def _UpdateHostConfigs(host_config_pbs, cluster_config_pb, lab_config_pb):
  """Writes host configs that differ from what is stored in ndb.

  A HostConfig entity is built for each host config proto and compared with
  the stored entity of the same hostname; only entities that are not equal
  are written.

  Args:
    host_config_pbs: a list of host config protos.
    cluster_config_pb: the cluster config proto.
    lab_config_pb: the lab config proto.
  """
  logging.debug('Updating host configs for <lab: %s, cluster: %s.>',
                lab_config_pb.lab_name, cluster_config_pb.cluster_name)
  keys = [
      ndb.Key(datastore_entities.HostConfig, pb.hostname)
      for pb in host_config_pbs
  ]
  changed = []
  # get_multi keeps the order of the keys, so stored entities line up with
  # their protos.
  for stored, pb in zip(ndb.get_multi(keys), host_config_pbs):
    candidate = datastore_entities.HostConfig.FromMessage(
        lab_config_util.HostConfig(pb, cluster_config_pb, lab_config_pb))
    if _CheckConfigEntitiesEqual(stored, candidate):
      continue
    logging.debug('Updating host config entity: %s.', candidate)
    changed.append(candidate)
  ndb.put_multi(changed)
  logging.debug('Host configs updated.')
Beispiel #5
0
def _UpdateClusterConfigs(cluster_configs):
  """Writes cluster configs that differ from what is stored in ndb.

  When several configs share a cluster name and differ from the stored
  entity, a warning is logged and the last one wins.

  Args:
    cluster_configs: a list of cluster configs proto.
  """
  logging.debug('Updating cluster configs.')
  keys = {
      ndb.Key(datastore_entities.ClusterConfig, config.cluster_name)
      for config in cluster_configs
  }
  stored_by_name = {
      stored.cluster_name: stored
      for stored in ndb.get_multi(keys)
      if stored
  }

  pending = {}
  for config in cluster_configs:
    name = config.cluster_name
    candidate = datastore_entities.ClusterConfig.FromMessage(config)
    if _CheckConfigEntitiesEqual(stored_by_name.get(name), candidate):
      continue
    logging.debug('Updating cluster config entity: %s.', candidate)
    if name in pending:
      logging.warning('%s has duplicated configs.', name)
    pending[name] = candidate
  ndb.put_multi(pending.values())
  logging.debug('Cluster configs updated.')
def GetTasks(task_ids):
    """Fetches the task entities for the given ids from the task store.

    Args:
      task_ids: a list of task ids.
    Returns:
      a list of task entities.
    """
    task_keys = []
    for task_id in task_ids:
        task_keys.append(_Key(task_id))
    return ndb.get_multi(task_keys)
Beispiel #7
0
def _ValidateTestPlan(test_plan):
    """Validates a test plan before it is accepted.

    Checks, in order: the cron expression, the plan-level device actions, each
    test run config (its test and device actions), and the build channel of
    each test resource pipe.

    Args:
      test_plan: a ndb_models.TestPlan object.

    Raises:
      endpoints.BadRequestException: if the cron expression is invalid, or a
        referenced device action, test or build channel cannot be found.
    """
    cron_exp = test_plan.cron_exp
    if cron_exp and not _IsValidCronExpression(cron_exp):
        message = (
            'Invalid cron expression (%s). Cron expression should be ordered as '
            'minute, hour, day of month, month, day of week.') % cron_exp
        raise endpoints.BadRequestException(message)

    plan_action_keys = test_plan.before_device_action_keys
    plan_device_actions = ndb.get_multi(plan_action_keys)
    if not all(plan_device_actions):
        raise endpoints.BadRequestException(
            'Cannot find some device actions: %s -> %s' %
            (plan_action_keys, plan_device_actions))

    for run_config in test_plan.test_run_configs:
        test_key = run_config.test_key
        if not test_key.get():
            raise endpoints.BadRequestException(
                'test %s does not exist' % test_key.id())

        config_action_keys = run_config.before_device_action_keys
        config_device_actions = ndb.get_multi(config_action_keys)
        if not all(config_device_actions):
            raise endpoints.BadRequestException(
                'Cannot find some device actions: %s -> %s' %
                (config_action_keys, config_device_actions))
        # Plan-level and config-level actions are validated together.
        combined_actions = plan_device_actions + config_device_actions
        test_kicker.ValidateDeviceActions(combined_actions)

    for pipe in test_plan.test_resource_pipes:
        locator = build.BuildLocator.ParseUrl(pipe.url)
        if not locator:
            continue
        channel_key = mtt_messages.ConvertToKey(
            ndb_models.BuildChannelConfig, locator.build_channel_id)
        if not channel_key.get():
            raise endpoints.BadRequestException(
                'build channel %s does not exist' % locator.build_channel_id)
Beispiel #8
0
 def _ListLabs(self, request):
   """Lists one page of labs without an owner filter.

   Labs are ordered by key. Some labs don't have config; their LabConfig
   lookup result is passed to ToMessage as returned by get_multi.
   """
   lab_query = datastore_entities.LabInfo.query().order(
       datastore_entities.LabInfo.key)
   labs, prev_cursor, next_cursor = datastore_util.FetchPage(
       lab_query, request.count, page_cursor=request.cursor)
   lab_configs = ndb.get_multi(
       [ndb.Key(datastore_entities.LabConfig, lab.lab_name) for lab in labs])
   lab_infos = []
   for lab, lab_config in zip(labs, lab_configs):
     lab_infos.append(datastore_entities.ToMessage(lab, lab_config))
   return api_messages.LabInfoCollection(
       lab_infos=lab_infos,
       more=bool(next_cursor),
       next_cursor=next_cursor,
       prev_cursor=prev_cursor)
def _DoUpdateDevicesInNDB(reported_devices, event):
  """Writes the reported devices of a host to ndb.

  Device keys are built as children of the HostInfo key for event.hostname.

  Args:
    reported_devices: device serial to device data mapping.
    event: the event have hostname, cluster info and timestamp.
  """
  device_keys = [
      ndb.Key(
          datastore_entities.HostInfo, event.hostname,
          datastore_entities.DeviceInfo, serial)
      for serial in reported_devices.keys()
  ]
  pending = []
  # If the device doesn't exist, the corresponding entry will be None.
  for stored_device, key in zip(ndb.get_multi(device_keys), device_keys):
    device_data = reported_devices.get(key.id())
    pending.extend(_UpdateDeviceInNDB(stored_device, key, device_data, event))
  ndb.put_multi(pending)
def _DoUpdateGoneDevicesInNDB(missing_device_keys, timestamp):
  """Marks the given devices as GONE in ndb.

  Args:
    missing_device_keys: keys of the devices that are no longer reported.
    timestamp: the time of the event that no longer reports the devices.
  """
  entities_to_update = []
  devices = ndb.get_multi(missing_device_keys)
  for device in devices:
    if device is None:
      # get_multi yields None for a key with no stored entity; there is
      # nothing to mark as gone, and dereferencing it would raise.
      continue
    if device.timestamp and device.timestamp > timestamp:
      # The stored device is newer than this event.
      logging.debug("Ignore outdated event.")
      continue
    if (device.state == common.DeviceState.GONE and
        device.timestamp and
        device.timestamp >= timestamp - ONE_MONTH):
      # Already GONE and last updated within ONE_MONTH: no change needed.
      logging.debug("Ignore gone device.")
      continue
    if device.timestamp and device.timestamp < timestamp - ONE_MONTH:
      # Not updated for more than ONE_MONTH: hide the device.
      device.hidden = True
      device.timestamp = timestamp
    device_state_history, device_history = _UpdateDeviceState(
        device, common.DeviceState.GONE, timestamp)
    entities_to_update.append(device)
    if device_state_history:
      entities_to_update.append(device_state_history)
    if device_history:
      entities_to_update.append(device_history)
  ndb.put_multi(entities_to_update)
Beispiel #11
0
  def BatchGetNotes(self, request):
    """Returns the requested notes that belong to the requested host.

    Ids without a stored note, and notes stored under another hostname, are
    silently left out of the result.

    Args:
      request: an API request.
    Request Params:
      hostname: string, the name of a lab host.
      ids: a list of strings, the ids of notes to batch get.

    Returns:
      an api_messages.NoteCollection object.
    """
    note_keys = []
    for note_id in request.ids:
      note_keys.append(ndb.Key(datastore_entities.Note, note_id))

    note_msgs = []
    for note in ndb.get_multi(note_keys):
      if note and note.hostname == request.hostname:
        note_msgs.append(datastore_entities.ToMessage(note))
    return api_messages.NoteCollection(
        notes=note_msgs, more=False, next_cursor=None, prev_cursor=None)
Beispiel #12
0
def CreateTestRun(labels,
                  test_run_config,
                  test_plan_key=None,
                  rerun_context=None,
                  rerun_configs=None,
                  sequence_id=None):
  """Creates, stores and enqueues a test run.

  The test, its device actions and the resolved test resources are
  snapshotted into the new TestRun entity. When rerun_configs is given, a
  TestRunSequence is stored first and the run is attached to it.

  Args:
    labels: labels for the test run
    test_run_config: a ndb_models.TestRunConfig object.
    test_plan_key: a ndb_models.TestPlan key.
    rerun_context: rerun parameters containing parent ID or context filename.
    rerun_configs: a list of configs to use for reruns
    sequence_id: id of the sequence the run should belong to
  Returns:
    a ndb_models.TestRun object.
  Raises:
    ValueError: some of given parameters are invalid: the test cannot be
      found, module sharding is requested but not usable, a device action is
      missing, or both sequence_id and rerun_configs are given.
    errors.TestResourceError: if some of test resources don't have URLs.
  """
  # Set defaults for null test_run_config fields for backward compatibility.
  if test_run_config.use_parallel_setup is None:
    test_run_config.use_parallel_setup = True

  test = test_run_config.test_key.get()
  if not test:
    raise ValueError('cannot find test %s' % test_run_config.test_key)

  # Module sharding needs module settings on the test and at least one test
  # package among the run's test resources.
  if test_run_config.sharding_mode == ndb_models.ShardingMode.MODULE:
    if (not test.module_config_pattern or not test.module_execution_args):
      raise ValueError(
          'test "%s" does not support module sharding: '
          'module_config_pattern or module_exeuction_args not defined' % (
              test.name))
    test_package_urls = [
        r.cache_url
        for r in test_run_config.test_resource_objs
        if r.test_resource_type == ndb_models.TestResourceType.TEST_PACKAGE]
    if not test_package_urls:
      raise ValueError(
          'cannot use module sharding: '
          'no test package is found in test resources')

  # NOTE(review): extend() modifies the fetched test entity in place; it is
  # not written back here, only embedded in the TestRun below.
  node_config = ndb_models.GetNodeConfig()
  test.env_vars.extend(node_config.env_vars)

  # Store snapshot of test run's actions
  before_device_actions = ndb.get_multi(
      test_run_config.before_device_action_keys)

  if not all(before_device_actions):
    raise ValueError(
        'Cannot find some device actions: %s -> %s' % (
            test_run_config.before_device_action_keys, before_device_actions))
  ValidateDeviceActions(before_device_actions)

  test_run_actions = [
      ref.ToAction() for ref in test_run_config.test_run_action_refs
  ]

  # Collect the resource definitions of the test and of every device action
  # (on a copy, so test.test_resource_defs itself is not extended).
  test_resource_defs = test.test_resource_defs[:]
  for device_action in before_device_actions:
    test_resource_defs += device_action.test_resource_defs
  test_resource_map = _ConvertToTestResourceMap(test_resource_defs)
  # Override test resource URLs based on
  # node_config.test_resource_default_download_urls.
  for pair in node_config.test_resource_default_download_urls:
    if pair.name not in test_resource_map:
      continue
    test_resource_map[pair.name].url = pair.value
  # Resources from the run config are applied after, and so override, the
  # node defaults.
  test_resources = build.FindTestResources(test_run_config.test_resource_objs)
  for r in test_resources:
    if r.name not in test_resource_map:
      logging.warning(
          'Test resource %s is not needed for this test run; ignoring', r)
      continue
    test_resource_map[r.name].url = r.url
    test_resource_map[r.name].cache_url = r.cache_url
  # Check every test resource has a valid URL.
  for test_resource in six.itervalues(test_resource_map):
    if not test_resource.url:
      raise errors.TestResourceError(
          'No URL for test resource %s' % test_resource.name)

  # Determine previous test context from previous test run
  prev_test_run_key, prev_test_context = _GetRerunInfo(test, rerun_context)

  # Rerun configs start a new sequence; an existing sequence id conflicts.
  if rerun_configs:
    if sequence_id:
      raise ValueError(
          'Cannot create test run with both sequence id %s and rerun configs %s'
          % (sequence_id, rerun_configs))
    # NOTE(review): insert() mutates the caller's rerun_configs list in place.
    rerun_configs.insert(0, test_run_config)
    sequence_id = str(uuid.uuid4())
    test_run_sequence = ndb_models.TestRunSequence(
        state=ndb_models.TestRunSequenceState.RUNNING,
        test_run_configs=rerun_configs,
        finished_test_run_ids=[])
    test_run_sequence.key = ndb.Key(ndb_models.TestRunSequence, sequence_id)
    test_run_sequence.put()

  # Create and enqueue test run
  test_run = ndb_models.TestRun(
      id=str(uuid.uuid4()),
      prev_test_run_key=prev_test_run_key,
      labels=labels,
      test_plan_key=test_plan_key,
      test=test,
      test_run_config=test_run_config,
      test_resources=list(test_resource_map.values()),
      prev_test_context=prev_test_context,
      state=ndb_models.TestRunState.PENDING,
      before_device_actions=before_device_actions,
      test_run_actions=test_run_actions,
      sequence_id=sequence_id)
  test_run.put()
  test_run_id = test_run.key.id()
  logging.info('Test run %s created', test_run_id)
  event_log.Info(test_run, 'Test run created')
  EnqueueTestRun(test_run_id)
  return test_run
    def testSyncHarnessImageMetadata_NoExistingEntity(self, mock_requests,
                                                      mock_util_datetime,
                                                      mock_syncer_datetime,
                                                      mock_auth):
        """Syncing with no stored metadata creates one entity per digest."""
        now = datetime.datetime(2020, 12, 24)
        created = datetime.datetime(2020, 12, 10)
        epoch = datetime.datetime(1970, 1, 1)
        created_ms = str(int((created - epoch).total_seconds() * 1000))

        # Freeze "now" for both modules, but keep the real timestamp parser.
        for mocked_datetime in (mock_util_datetime, mock_syncer_datetime):
            mocked_datetime.datetime.utcnow.return_value = now
        mock_syncer_datetime.datetime.utcfromtimestamp = (
            datetime.datetime.utcfromtimestamp)

        sha1_tags = [
            '111111',
            'golden',
            'canary',
            'golden_tradefed_image_20201210_1200_RC00',
        ]
        sha2_tags = [
            '2222222',
            'golden_tradefed_image_20201210_0600_RC00',
        ]
        sha3_tags = [
            '3333333',
            'staging',
        ]
        # The response gets its own copies so the expectations stay untouched.
        mock_requests.get().json.return_value = {
            'manifest': {
                'sha1': {
                    'tag': list(sha1_tags),
                    'timeCreatedMs': created_ms,
                },
                'sha2': {
                    'tag': list(sha2_tags),
                    'timeCreatedMs': created_ms,
                },
                'sha3': {
                    'tag': list(sha3_tags),
                    'timeCreatedMs': created_ms,
                },
            }
        }

        harness_image_metadata_syncer.SyncHarnessImageMetadata()

        image_names = (
            'gcr.io/dockerized-tradefed/tradefed:sha1',
            'gcr.io/dockerized-tradefed/tradefed:sha2',
            'gcr.io/dockerized-tradefed/tradefed:sha3',
        )
        keys = [
            ndb.Key(datastore_entities.TestHarnessImageMetadata, image_name)
            for image_name in image_names
        ]
        first, second, third = ndb.get_multi(keys)

        # (entity, digest, harness version, current tags, historical tags)
        expectations = (
            (first, 'sha1', '111111', sha1_tags, ['golden']),
            (second, 'sha2', '2222222', sha2_tags, ['golden']),
            (third, 'sha3', '3333333', sha3_tags, []),
        )
        for entity, digest, version, current, historical in expectations:
            self.assertEqual(digest, entity.digest)
            self.assertEqual(version, entity.test_harness_version)
            self.assertEqual(created, entity.create_time)
            self.assertEqual(now, entity.sync_time)
            self.assertCountEqual(current, entity.current_tags)
            if historical:
                self.assertCountEqual(historical, entity.historical_tags)
            else:
                self.assertEmpty(entity.historical_tags)
Beispiel #14
0
  def ListHosts(self, request):
    """Fetches a list of hosts.

    A list-valued filter that names exactly one value is pushed into the
    datastore query. Every list-valued filter is also applied in memory by
    _PostFilter, which is what handles filters naming several values.

    Args:
      request: an API request.

    Returns:
      a HostInfoCollection object.

    Raises:
      endpoints.BadRequestException: when only one of "timestamp" and
        "timestamp_operator" is set.
    """
    if ((request.timestamp and not request.timestamp_operator) or
        (not request.timestamp and request.timestamp_operator)):
      raise endpoints.BadRequestException(
          '"timestamp" and "timestamp_operator" must be set at the same time.')
    query = datastore_entities.HostInfo.query()
    if request.lab_name:
      query = query.filter(
          datastore_entities.HostInfo.lab_name == request.lab_name)

    if request.assignee:
      query = query.filter(
          datastore_entities.HostInfo.assignee == request.assignee)

    if request.is_bad is not None:
      query = query.filter(datastore_entities.HostInfo.is_bad == request.is_bad)

    if not request.include_hidden:
      # The ndb filter is built from the == operator on the property, so this
      # cannot be rewritten as "is False" or "not ...".
      query = query.filter(
          datastore_entities.HostInfo.hidden == False)  # pylint: disable=singleton-comparison
    if request.flated_extra_info:
      query = query.filter(datastore_entities.HostInfo.flated_extra_info ==
                           request.flated_extra_info)

    if len(request.host_groups) == 1:
      query = query.filter(
          datastore_entities.HostInfo.host_group == request.host_groups[0])
    if len(request.hostnames) == 1:
      query = query.filter(
          datastore_entities.HostInfo.hostname == request.hostnames[0])
    test_harnesses = request.test_harness + request.test_harnesses
    if len(test_harnesses) == 1:
      query = query.filter(
          datastore_entities.HostInfo.test_harness == test_harnesses[0])
    if len(request.test_harness_versions) == 1:
      query = query.filter(
          datastore_entities.HostInfo.test_harness_version ==
          request.test_harness_versions[0])
    if len(request.pools) == 1:
      query = query.filter(
          datastore_entities.HostInfo.pools == request.pools[0])
    if len(request.host_states) == 1:
      query = query.filter(
          datastore_entities.HostInfo.host_state == request.host_states[0])
    if len(request.recovery_states) == 1:
      query = query.filter(
          datastore_entities.HostInfo.recovery_state
          == request.recovery_states[0])

    # Hostnames whose update state is one of the requested ones; consulted by
    # _PostFilter below.
    hostnames_with_requested_update_states = set()
    if request.host_update_states:
      update_state_query = datastore_entities.HostUpdateState.query().filter(
          datastore_entities.HostUpdateState.state.IN(
              request.host_update_states))
      hostnames_with_requested_update_states = set(
          update_state.hostname for update_state in update_state_query.fetch(
              projection=[datastore_entities.HostUpdateState.hostname]))

    def _PostFilter(host):
      """Returns True when the host passes every requested filter."""
      if request.host_groups and host.host_group not in request.host_groups:
        return
      if request.hostnames and host.hostname not in request.hostnames:
        return
      if (test_harnesses and
          host.test_harness not in test_harnesses):
        return
      if (request.test_harness_versions and
          host.test_harness_version not in request.test_harness_versions):
        return
      if request.pools and not set(host.pools).intersection(set(request.pools)):
        return
      if request.host_states and host.host_state not in request.host_states:
        return
      if (request.recovery_states and
          host.recovery_state not in request.recovery_states):
        return
      if request.timestamp:
        if not host.timestamp:
          return
        # Only reject here; a host that passes the timestamp check must still
        # go through the update-state filter below.
        if not _CheckTimestamp(
            host.timestamp, request.timestamp_operator, request.timestamp):
          return
      if request.host_update_states:
        if host.hostname not in hostnames_with_requested_update_states:
          return
      return True

    if request.timestamp:
      query = query.order(
          datastore_entities.HostInfo.timestamp,
          datastore_entities.HostInfo.key)
    else:
      query = query.order(datastore_entities.HostInfo.key)

    hosts, prev_cursor, next_cursor = datastore_util.FetchPage(
        query, request.count, request.cursor, result_filter=_PostFilter)

    host_update_state_keys = [
        ndb.Key(datastore_entities.HostUpdateState, host.hostname)
        for host in hosts]
    host_update_states = ndb.get_multi(host_update_state_keys)
    host_infos = []
    for host, host_update_state in zip(hosts, host_update_states):
      devices = []
      if request.include_devices:
        device_query = datastore_entities.DeviceInfo.query(ancestor=host.key)
        if not request.include_hidden:
          device_query = device_query.filter(
              datastore_entities.DeviceInfo.hidden == False)  # pylint: disable=singleton-comparison
        devices = device_query.fetch()
      host_infos.append(datastore_entities.ToMessage(
          host, devices=devices,
          host_update_state_entity=host_update_state))
    return api_messages.HostInfoCollection(
        host_infos=host_infos,
        more=bool(next_cursor),
        next_cursor=next_cursor,
        prev_cursor=prev_cursor)
Beispiel #15
0
  def BatchUpdateHostMetadata(self, request):
    """Sets the test harness image in the HostMetadata of multiple hosts.

    Hosts that are not enabled for UI updates, or that the user does not own,
    are skipped; all other hosts are updated, and the skipped hosts are then
    reported through an exception.

    Args:
      request: an API request.
    Request Params:
      hostname: list of strings, the name of hosts.
      test_harness_image: string, the url to test harness image.
      user: string, the user sending the request.

    Returns:
      a message_types.VoidMessage object.

    Raises:
      endpoints.BadRequestException, when request does not match existing hosts.
    """
    host_configs = ndb.get_multi(
        [ndb.Key(datastore_entities.HostConfig, name)
         for name in request.hostnames])
    host_metadatas = ndb.get_multi(
        [ndb.Key(datastore_entities.HostMetadata, name)
         for name in request.hostnames])

    denied_hosts = []
    disabled_hosts = []
    pending = []
    rows = zip(request.hostnames, host_configs, host_metadatas)
    for hostname, config, metadata in rows:
      if not (config and config.enable_ui_update):
        disabled_hosts.append(hostname)
        continue
      if request.user not in config.owners:
        denied_hosts.append(hostname)
        continue
      metadata = metadata or datastore_entities.HostMetadata(
          id=hostname, hostname=hostname)
      images_match = harness_image_metadata_syncer.AreHarnessImagesEqual(
          metadata.test_harness_image, request.test_harness_image)
      if not images_match:
        # The image changes: send a pending update-state event for the host.
        device_manager.HandleDeviceSnapshotWithNDB(
            host_event.HostEvent(
                time=datetime.datetime.utcnow(),
                type=_HOST_UPDATE_STATE_CHANGED_EVENT_NAME,
                hostname=hostname,
                host_update_state=_HOST_UPDATE_STATE_PENDING,
                data={"host_update_target_image": request.test_harness_image}))
      metadata.populate(test_harness_image=request.test_harness_image)
      pending.append(metadata)
    ndb.put_multi(pending)

    if not (denied_hosts or disabled_hosts):
      return message_types.VoidMessage()

    # Report skipped hosts only after the permitted ones have been written.
    problems = []
    if denied_hosts:
      problems.append(
          "Request user %s is not in the owner list of hosts [%s]. "
          % (request.user, ", ".join(denied_hosts)))
    if disabled_hosts:
      problems.append("Hosts [%s] are not enabled to be updated from UI. "
                      % ", ".join(disabled_hosts))
    raise endpoints.BadRequestException("".join(problems))
Beispiel #16
0
  def BatchUpdateNotesWithPredefinedMessage(self, request):
    """Creates or updates host notes that share one predefined message.

    All notes get the same user, message, event time, offline reason and
    recovery action. A note event is published for every saved note, and a
    host info history snapshot is created for every note requested without
    an id.

    Args:
      request: an API request.

    Returns:
      an api_messages.NoteCollection object.

    Raises:
      endpoints.BadRequestException: when the offline reason or the recovery
        action is rejected by note_manager.
    """
    now = datetime.datetime.utcnow()

    notes_to_save = []
    for requested in request.notes:
      note = datastore_util.GetOrCreateEntity(
          datastore_entities.Note,
          entity_id=None if requested.id is None else int(requested.id),
          hostname=requested.hostname,
          type=common.NoteType.HOST_NOTE)
      note.populate(
          user=request.user,
          message=request.message,
          timestamp=now,
          event_time=request.event_time)
      notes_to_save.append(note)
    note_count = len(notes_to_save)

    try:
      offline_reason = note_manager.PreparePredefinedMessageForNote(
          common.PredefinedMessageType.HOST_OFFLINE_REASON,
          message_id=request.offline_reason_id,
          lab_name=request.lab_name,
          content=request.offline_reason,
          delta_count=note_count)
    except note_manager.InvalidParameterError as err:
      raise endpoints.BadRequestException("Invalid offline reason: [%s]" % err)
    if offline_reason:
      for note in notes_to_save:
        note.offline_reason = offline_reason.content
      offline_reason.put()

    try:
      recovery_action = note_manager.PreparePredefinedMessageForNote(
          common.PredefinedMessageType.HOST_RECOVERY_ACTION,
          message_id=request.recovery_action_id,
          lab_name=request.lab_name,
          content=request.recovery_action,
          delta_count=note_count)
    except note_manager.InvalidParameterError as err:
      raise endpoints.BadRequestException("Invalid recovery action: [%s]" % err)
    if recovery_action:
      for note in notes_to_save:
        note.recovery_action = recovery_action.content
      recovery_action.put()

    # Write the notes, then read them back before building the messages.
    saved_keys = ndb.put_multi(notes_to_save)
    note_msgs = []
    for saved_note in ndb.get_multi(saved_keys):
      note_msg = datastore_entities.ToMessage(saved_note)
      note_msgs.append(note_msg)
      note_manager.PublishMessage(
          api_messages.NoteEvent(note=note_msg, lab_name=request.lab_name),
          common.PublishEventType.HOST_NOTE_EVENT)

    for requested, saved_key in zip(request.notes, saved_keys):
      if requested.id:
        continue
      # If ids are not provided, then a new note is created, we should create
      # a history snapshot.
      device_manager.CreateAndSaveHostInfoHistoryFromHostNote(
          requested.hostname, saved_key.id())

    return api_messages.NoteCollection(
        notes=note_msgs, more=False, next_cursor=None, prev_cursor=None)