def _FetchInstance(self, args):
        """Resolve args.instance_name to a zonal reference and fetch the instance.

        The zone comes from args.zone; when it is absent the resolver uses the
        instance zone scope lister to determine it.

        Returns:
          A two-element tuple: the resolved instance reference and the instance
          resource returned by BaseSSHCLIHelper.GetInstance.
        """
        api_holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
        compute_client = api_holder.client
        helper = ssh_utils.BaseSSHCLIHelper()

        zone_lister = flags.GetInstanceZoneScopeLister(compute_client)
        resolved = flags.SSH_INSTANCE_RESOLVER.ResolveResources(
            [args.instance_name], scope.ScopeEnum.ZONE, args.zone,
            api_holder.resources, scope_lister=zone_lister)
        target_ref = resolved[0]

        fetched_instance = helper.GetInstance(compute_client, target_ref)
        return target_ref, fetched_instance
예제 #2
0
  def _GetTargetArgs(self, args):
    """Resolve the target instance and collect the values needed to reach it.

    Returns:
      A tuple (project, zone, instance_name, interface_name, port):
      project and zone come from the resolved instance reference, the instance
      name and internal network interface name come from the fetched instance
      resource, and port is args.instance_port.
    """
    api_holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
    compute_client = api_holder.client
    helper = ssh_utils.BaseSSHCLIHelper()

    zone_lister = flags.GetInstanceZoneScopeLister(compute_client)
    target_ref = flags.SSH_INSTANCE_RESOLVER.ResolveResources(
        [args.instance_name],
        scope.ScopeEnum.ZONE,
        args.zone,
        api_holder.resources,
        scope_lister=zone_lister)[0]
    resource = helper.GetInstance(compute_client, target_ref)

    internal_nic = ssh_utils.GetInternalInterface(resource)
    return (target_ref.project,
            target_ref.zone,
            resource.name,
            internal_nic.name,
            args.instance_port)
예제 #3
0
    def Run(self, args):
        """Default run method implementation.

        Runs sosreport on the target VM. It downloads sosreport onto the VM if
        needed, ensures the reports directory exists, runs the report and prints
        where it was written. When args.download_dir is set, it also copies the
        report locally and prints the local path.
        """
        super(SosReport, self).Run(args)
        self._use_accounts_service = False

        # Gather the gcloud-side objects the remote helpers need.
        holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
        instance = self.GetInstance(holder, args)
        user = args.user or ssh.GetDefaultSshUsername()
        ssh_helper = ssh_utils.BaseSSHCLIHelper()
        ssh_helper.Run(args)

        # Shared state handed to every soshelper call below.
        context = dict(
            args=args,
            instance=instance,
            ssh_helper=ssh_helper,
            user=user,
            python_path=args.python_path,
        )
        install_path = args.sosreport_install_path
        reports_path = args.reports_path

        # Download sosreport onto the VM when needed (normally the first run).
        soshelper.ObtainSosreport(context, install_path)

        # Make sure the directory that will receive the reports exists.
        log.out.Print(
            "Creating the path where reports will be written if needed.")
        soshelper.CreatePath(context, reports_path)

        soshelper.RunSosreport(context, install_path, reports_path)

        # Locate and announce the freshly generated report on the VM.
        remote_report = soshelper.ObtainReportFilename(context, reports_path)
        log.status.Print('Report generated into "{report_path}".'.format(
            report_path=remote_report))

        if not args.download_dir:
            return

        # Copy the report back to the requested local directory.
        local_report = soshelper.CopyReportFile(context, args.download_dir,
                                                remote_report)
        log.status.Print('Successfully downloaded report to "{report_path}"'.format(
            report_path=local_report))
예제 #4
0
    def SSHToInstance(self, args, instance):
        """Helper to manage authentication followed by SSH to the instance.

        Connects to the instance's external (NAT) IP as the default SSH user,
        using the default key file. It waits for SSH keys to propagate and then
        runs ssh, optionally with agent and port forwarding (args.forward_ports).

        Args:
          args: parsed command arguments; normalized via _DefaultArgsForSSH and
            mutated to use the default SSH key file.
          instance: the instance resource to connect to.

        Exits the process with the remote command's status when it is non-zero.
        Failures of ssh itself surface as ssh.CommandError raised by Run.
        """
        args = self._DefaultArgsForSSH(args)

        nat_ip = ssh_utils.GetExternalIPAddress(instance)
        log.status.Print(
            'Trying to SSH to VM with NAT IP:{}'.format(nat_ip))
        remote = ssh.Remote(nat_ip, ssh.GetDefaultSshUsername())
        # Must be set before helper.Run(args) so the helper picks this key up.
        args.ssh_key_file = ssh.Keys.DEFAULT_KEY_FILE

        helper = ssh_utils.BaseSSHCLIHelper()
        helper.Run(args)
        key_path = helper.keys.key_file

        user = ssh_utils.GetUserAndInstance(args.name)[0]
        host_keys = self._GetHostKeyFromInstance(args.zone, helper, instance)
        options = self._GetSSHOptions(args.name, helper, instance, host_keys)
        self._WaitForSSHKeysToPropagate(helper, remote, key_path, user,
                                        instance, options)

        # ctpu appears to forward additional ports on the TPU node; only the
        # agent plus ports 6006/8888 are forwarded here until that is understood.
        if args.forward_ports:
            extra_flags = [
                '-A', '-L', '6006:localhost:6006', '-L', '8888:localhost:8888'
            ]
        else:
            extra_flags = []

        command = ssh.SSHCommand(
            remote=remote,
            identity_file=key_path,
            options=options,
            extra_flags=extra_flags)
        # ssh-level failures raise ssh.CommandError; a non-zero status here is
        # the exit code of the remote command itself.
        status = command.Run(
            helper.env,
            force_connect=properties.VALUES.ssh.putty_force_connect.GetBool())
        if status:
            sys.exit(status)
예제 #5
0
    def Run(self, args):
        """See ssh_utils.BaseSSHCLICommand.Run.

        Resolves the instance named in args.user_host and picks its internal or
        external IP (per self._use_internal_ip). It builds the ssh command, makes
        sure the user's key is on the instance (unless --plain or OS Login is in
        use) and then runs ssh. With --dry-run the command line is printed instead.

        Raises:
          ssh_utils.NetworkError: if waiting for a newly added key times out.
        Exits the process with the remote command's status when it is non-zero.
        """
        holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
        client = holder.client

        helper = ssh_utils.BaseSSHCLIHelper()
        helper.Run(args)
        user, instance_name = ssh_utils.GetUserAndInstance(args.user_host)
        zone_lister = instance_flags.GetInstanceZoneScopeLister(client)
        instance_ref = instance_flags.SSH_INSTANCE_RESOLVER.ResolveResources(
            [instance_name], compute_scope.ScopeEnum.ZONE, args.zone,
            holder.resources, scope_lister=zone_lister)[0]
        instance = helper.GetInstance(client, instance_ref)
        project = helper.GetProject(client, instance_ref.project)

        # OS Login may override the remote user unless --plain was requested.
        use_oslogin = False
        if not args.plain:
            user, use_oslogin = helper.CheckForOsloginAndGetUser(
                instance, project, user, self.ReleaseTrack())

        if self._use_internal_ip:
            address_of = ssh_utils.GetInternalIPAddress
        else:
            address_of = ssh_utils.GetExternalIPAddress
        ip_address = address_of(instance)
        remote = ssh.Remote(ip_address, user)

        # --plain leaves key and ssh config selection entirely to the user.
        identity_file, options = None, None
        if not args.plain:
            identity_file = helper.keys.key_file
            options = helper.GetConfig(ssh_utils.HostKeyAlias(instance),
                                       args.strict_host_key_checking)

        extra_flags = ssh.ParseAndSubstituteSSHFlags(args, remote, ip_address)
        remainder = list(args.ssh_args) if args.ssh_args else []

        # args.command is split on single spaces; None means "no remote command".
        command_list = args.command.split(' ') if args.command else None
        tty = containers.GetTty(args.container, command_list)
        remote_command = containers.GetRemoteCommand(args.container,
                                                     command_list)

        cmd = ssh.SSHCommand(
            remote,
            identity_file=identity_file,
            options=options,
            extra_flags=extra_flags,
            remote_command=remote_command,
            tty=tty,
            remainder=remainder)
        if args.dry_run:
            log.out.Print(' '.join(cmd.Build(helper.env)))
            return

        keys_newly_added = False
        if not (args.plain or use_oslogin):
            keys_newly_added = helper.EnsureSSHKeyExists(
                client, remote.user, instance, project)

        if keys_newly_added:
            # NOTE(review): max_wait_ms receives a *_SEC constant; confirm the
            # unit expected by SSHPoller matches the constant's value.
            poller = ssh.SSHPoller(
                remote,
                identity_file=identity_file,
                options=options,
                extra_flags=extra_flags,
                max_wait_ms=ssh_utils.SSH_KEY_PROPAGATION_TIMEOUT_SEC)
            log.status.Print('Waiting for SSH key to propagate.')
            # TODO(b/35355795): Don't force_connect
            try:
                poller.Poll(helper.env, force_connect=True)
            except retry.WaitException:
                raise ssh_utils.NetworkError()

        if self._use_internal_ip:
            helper.PreliminarilyVerifyInstance(instance.id, remote,
                                               identity_file, options)

        # Exit without raising so no extra "ERROR" line is printed; ssh's own
        # output already explains the failure.
        return_code = cmd.Run(helper.env, force_connect=True)
        if return_code:
            sys.exit(return_code)
예제 #6
0
    def Run(self, args):
        """Default run method implementation.

        Traceroutes from this machine to every matching VM. With
        --reverse-traceroute it also traceroutes from each VM back to this
        machine's address (args.external_route_ip, or one obtained through the
        VM). Once an address has been obtained it is reused for every later VM.
        Unless --dry-run is set, the user is prompted before anything runs.
        """
        super(Routes, self).Run(args)

        self._use_accounts_service = False

        holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
        registry = holder.resources
        helper = ssh_utils.BaseSSHCLIHelper()
        helper.Run(args)

        # Kept on self so the per-instance helper methods can reach them.
        self._args = args
        self._ssh_helper = helper

        project = properties.VALUES.core.project.GetOrFail()
        filters = _RoutesArgs.GetFilters(args)
        instances = _RoutesQueries.ObtainInstances(
            args.names,
            service=self.compute.instances,
            project=project,
            zones=args.zones,
            filters=filters,
            http=self.http,
            batch_url=self.batch_url)

        user = args.user or ssh.GetDefaultSshUsername()

        dry_run = args.dry_run
        reverse_traceroute = args.reverse_traceroute
        traceroute_args = args.traceroute_args
        external_route_ip = args.external_route_ip

        internal_helpers.PrintHeader(instances)
        prompt = 'The following VMs will be tracerouted.'
        if instances and not dry_run and not console_io.PromptContinue(prompt):
            return
        # Flush so the prompt does not show up after the instance output.
        log.out.flush()

        for instance in instances:
            title = 'Checking instance %s' % instance.name
            log.out.Print(title)
            log.out.Print('-' * len(title))

            try:
                self.TracerouteInstance(instance, traceroute_args, dry_run,
                                        registry)
            except exceptions.ToolException as err:
                log.error('Error routing to instance')
                log.error(str(err))
                continue

            if reverse_traceroute:
                try:
                    if not self.CheckTraceroute(instance, user, dry_run,
                                                registry):
                        log.out.Print(
                            'Please make sure traceroute is installed in PATH to move on.'
                        )
                    else:
                        if not external_route_ip:
                            external_route_ip = self.ObtainSelfIp(
                                instance, user, dry_run, registry)
                        if not external_route_ip:
                            log.out.Print(
                                'Unable to obtain self ip. Aborting.')
                        else:
                            self.ReverseTracerouteInstance(
                                instance, user, external_route_ip,
                                traceroute_args, dry_run, registry)
                except ssh.CommandError as err:
                    log.error(str(err))

            log.out.Print('')  # Blank line between instances.
예제 #7
0
  def Run(self, args):
    """See ssh_utils.BaseSSHCLICommand.Run.

    Operates in one of two modes:
      * on-prem: both --network and --region were specified. args.user_host is
        split into a user and an IP, --plain is forced on, and the tunnel
        arguments come from iap_tunnel.CreateOnPremSshTunnelArgs.
      * instance: args.user_host names a Compute Engine instance, which is
        resolved and fetched. The address used depends on IAP tunnelling and
        --internal-ip, and the remote user may be overridden by OS Login.

    In instance mode --troubleshoot only runs troubleshooting and returns
    without connecting. With --dry-run the ssh command line is printed and
    nothing is executed.

    Raises:
      ssh_utils.NetworkError: if polling for SSH key propagation times out.
      ssh.CommandError: if ssh itself fails; in instance mode a recommendation
        message is printed before re-raising.
    Exits the process with the remote command's status when it is non-zero.
    """

    # "on-prem" means the target is addressed by network/region + IP rather
    # than by instance name; such targets always use --plain.
    on_prem = (
        args.IsKnownAndSpecified('network') and
        args.IsKnownAndSpecified('region'))
    if on_prem:
      args.plain = True

    # These two lines are needed to ensure reauth is performed as needed, even
    # for on-prem, which doesn't use the resulting variables.
    holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
    client = holder.client

    ssh_helper = ssh_utils.BaseSSHCLIHelper()
    ssh_helper.Run(args)
    oslogin_state = ssh.OsloginState()

    # NOTE: in on-prem mode the names instance, instance_ref, instance_name,
    # project, host_keys and expiration are never bound. Every later use is
    # guarded either by `args.plain` (forced True above) or by `not on_prem`,
    # so keep those guards intact when editing below.
    if on_prem:
      user, ip = ssh_utils.GetUserAndInstance(args.user_host)
      remote = ssh.Remote(ip, user)

      iap_tunnel_args = iap_tunnel.CreateOnPremSshTunnelArgs(
          args, self.ReleaseTrack(), ip)
      instance_address = ip
      internal_address = ip
    else:
      user, instance_name = ssh_utils.GetUserAndInstance(args.user_host)
      instance_ref = instance_flags.SSH_INSTANCE_RESOLVER.ResolveResources(
          [instance_name], compute_scope.ScopeEnum.ZONE, args.zone,
          holder.resources,
          scope_lister=instance_flags.GetInstanceZoneScopeLister(client))[0]
      instance = ssh_helper.GetInstance(client, instance_ref)
      project = ssh_helper.GetProject(client, instance_ref.project)
      host_keys = ssh_helper.GetHostKeysFromGuestAttributes(
          client, instance_ref, instance, project)
      iap_tunnel_args = iap_tunnel.CreateSshTunnelArgs(
          args, self.ReleaseTrack(), instance_ref,
          ssh_utils.GetExternalInterface(instance, no_raise=True))

      internal_address = ssh_utils.GetInternalIPAddress(instance)

      # --troubleshoot: print a header, run diagnostics, and stop here without
      # opening an SSH session.
      if args.troubleshoot:
        log.status.Print(TROUBLESHOOT_HEADER.format(
            instance_ref, args.zone or instance_ref.zone,
            datetime.datetime.now()
        ))
        RunTroubleshooting(project, args.zone or instance_ref.zone,
                           instance, iap_tunnel_args)
        return

      # host_keys may be None; only an empty (but non-None) result is logged.
      if not host_keys and host_keys is not None:
        log.debug('Unable to retrieve host keys from instance metadata. '
                  'Continuing.')
      expiration, expiration_micros = ssh_utils.GetSSHKeyExpirationFromArgs(
          args)
      if args.plain:
        oslogin_state.oslogin_enabled = False
      else:
        public_key = ssh_helper.keys.GetPublicKey().ToEntry(
            include_comment=True)
        # If there is an '@' symbol in the user_host arg, the user is requesting
        # to connect as a specific user. This may get overridden by OS Login.
        username_requested = '@' in args.user_host
        oslogin_state = ssh.GetOsloginState(
            instance, project, user, public_key, expiration_micros,
            self.ReleaseTrack(), username_requested=username_requested)
        user = oslogin_state.user

      log.debug(oslogin_state)

      if iap_tunnel_args:
        # IAP Tunnel only uses instance_address for the purpose of --ssh-flag
        # substitution. In this case, dest_addr doesn't do much, it just matches
        # against entries in the user's ssh_config file. It's best to use
        # something unique to avoid false positive matches, thus we use
        # HostKeyAlias.
        instance_address = internal_address
        dest_addr = ssh_utils.HostKeyAlias(instance)
      elif args.internal_ip:
        instance_address = internal_address
        dest_addr = instance_address
      else:
        instance_address = ssh_utils.GetExternalIPAddress(instance)
        dest_addr = instance_address
      remote = ssh.Remote(dest_addr, user)

    # identity_file_list will be None if security keys are not enabled.
    identity_file_list = ssh.WriteSecurityKeys(oslogin_state)
    identity_file = None
    options = None
    # Only reached in instance mode: on-prem forces args.plain = True.
    if not args.plain:
      if not identity_file_list:
        identity_file = ssh_helper.keys.key_file
      options = ssh_helper.GetConfig(ssh_utils.HostKeyAlias(instance),
                                     args.strict_host_key_checking,
                                     host_keys_to_add=host_keys)

    extra_flags = ssh.ParseAndSubstituteSSHFlags(args, remote, instance_address,
                                                 internal_address)
    remainder = []

    if args.ssh_args:
      remainder.extend(args.ssh_args)

    # Transform args.command into arg list or None if no command
    command_list = args.command.split(' ') if args.command else None
    tty = containers.GetTty(args.container, command_list)
    remote_command = containers.GetRemoteCommand(args.container, command_list)

    # Do not include default port since that will prevent users from
    # specifying a custom port (b/121998342).
    ssh_cmd_args = {'remote': remote,
                    'identity_file': identity_file,
                    'options': options,
                    'extra_flags': extra_flags,
                    'remote_command': remote_command,
                    'tty': tty,
                    'iap_tunnel_args': iap_tunnel_args,
                    'remainder': remainder,
                    'identity_list': identity_file_list}

    cmd = ssh.SSHCommand(**ssh_cmd_args)

    if args.dry_run:
      # Add quotes around any arguments that contain spaces.
      log.out.Print(' '.join('"{0}"'.format(arg) if ' ' in arg else arg
                             for arg in cmd.Build(ssh_helper.env)))
      return

    # Raise errors if instance requires a security key but the local
    # environment doesn't support them. This is after the 'dry-run' because
    # we want to allow printing the command regardless.
    if self.enable_security_keys:
      ssh_utils.ConfirmSecurityKeyStatus(oslogin_state)

    # Keys are pushed to metadata only when neither --plain nor OS Login
    # applies (on-prem always takes the first branch via args.plain).
    if args.plain or oslogin_state.oslogin_enabled:
      keys_newly_added = False
    else:
      keys_newly_added = ssh_helper.EnsureSSHKeyExists(
          client, remote.user, instance, project, expiration=expiration)

    if keys_newly_added:
      poller = ssh_utils.CreateSSHPoller(remote, identity_file, options,
                                         iap_tunnel_args,
                                         extra_flags=extra_flags)
      log.status.Print('Waiting for SSH key to propagate.')
      # TODO(b/35355795): Don't force_connect
      try:
        poller.Poll(
            ssh_helper.env,
            force_connect=properties.VALUES.ssh.putty_force_connect.GetBool())
      except retry.WaitException:
        raise ssh_utils.NetworkError()

    # The `not on_prem` guard is required: `instance` is unbound on-prem.
    if args.internal_ip and not on_prem:
      ssh_helper.PreliminarilyVerifyInstance(instance.id, remote, identity_file,
                                             options)

    # Errors from SSH itself result in an ssh.CommandError being raised
    try:
      return_code = cmd.Run(
          ssh_helper.env,
          force_connect=properties.VALUES.ssh.putty_force_connect.GetBool())
    except ssh.CommandError as e:
      # instance_name / instance_ref / project exist only in instance mode.
      if not on_prem:
        log.status.Print(self.createRecommendMessage(args, instance_name,
                                                     instance_ref, project))
      raise e

    if return_code:
      # This is the return code of the remote command.  Problems with SSH itself
      # will result in ssh.CommandError being raised above.
      sys.exit(return_code)
예제 #8
0
  def Run(self, args):
    """See ssh_utils.BaseSSHCLICommand.Run.

    Resolves the instance named in args.user_host and, when self.get_host_keys
    is set, fetches its host keys from guest attributes. It determines the
    remote user (via OS Login unless --plain) and picks the destination address
    (IAP tunnel alias, internal IP or external IP). It then ensures the user's
    key is on the instance when needed and runs ssh.

    With --dry-run the ssh command line is printed instead of executed.
    Arguments containing spaces are wrapped in double quotes so each one stays
    a single argument when the line is pasted into a shell. Embedded double
    quotes and other shell metacharacters are not escaped.

    Raises:
      ssh_utils.NetworkError: if polling for SSH key propagation times out.
    Exits the process with the remote command's status when it is non-zero.
    """
    holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
    client = holder.client

    ssh_helper = ssh_utils.BaseSSHCLIHelper()
    ssh_helper.Run(args)
    user, instance_name = ssh_utils.GetUserAndInstance(args.user_host)
    instance_ref = instance_flags.SSH_INSTANCE_RESOLVER.ResolveResources(
        [instance_name], compute_scope.ScopeEnum.ZONE, args.zone,
        holder.resources,
        scope_lister=instance_flags.GetInstanceZoneScopeLister(client))[0]
    instance = ssh_helper.GetInstance(client, instance_ref)
    project = ssh_helper.GetProject(client, instance_ref.project)
    if self.get_host_keys:
      host_keys = ssh_helper.GetHostKeysFromGuestAttributes(
          client, instance_ref)
      if not host_keys:
        log.warning('Unable to retrieve host keys from instance metadata. '
                    'Continuing.')
    else:
      host_keys = {}
    expiration, expiration_micros = ssh_utils.GetSSHKeyExpirationFromArgs(args)
    if args.plain:
      use_oslogin = False
    else:
      # OS Login may override the requested remote user.
      public_key = ssh_helper.keys.GetPublicKey().ToEntry(include_comment=True)
      user, use_oslogin = ssh.CheckForOsloginAndGetUser(
          instance, project, user, public_key, expiration_micros,
          self.ReleaseTrack())

    iap_tunnel_args = iap_tunnel.SshTunnelArgs.FromArgs(
        args, self.ReleaseTrack(), instance_ref,
        ssh_utils.GetExternalInterface(instance, no_raise=True))

    internal_address = ssh_utils.GetInternalIPAddress(instance)

    if iap_tunnel_args:
      # IAP Tunnel only uses instance_address for the purpose of --ssh-flag
      # substitution. In this case, dest_addr doesn't do much, it just matches
      # against entries in the user's ssh_config file. It's best to use
      # something unique to avoid false positive matches, thus we use
      # HostKeyAlias.
      instance_address = internal_address
      dest_addr = ssh_utils.HostKeyAlias(instance)
    elif args.internal_ip:
      instance_address = internal_address
      dest_addr = instance_address
    else:
      instance_address = ssh_utils.GetExternalIPAddress(instance)
      dest_addr = instance_address
    remote = ssh.Remote(dest_addr, user)

    identity_file = None
    options = None
    if not args.plain:
      identity_file = ssh_helper.keys.key_file
      options = ssh_helper.GetConfig(ssh_utils.HostKeyAlias(instance),
                                     args.strict_host_key_checking,
                                     host_keys_to_add=host_keys)

    extra_flags = ssh.ParseAndSubstituteSSHFlags(args, remote, instance_address,
                                                 internal_address)
    remainder = []

    if args.ssh_args:
      remainder.extend(args.ssh_args)

    # Transform args.command into arg list or None if no command
    command_list = args.command.split(' ') if args.command else None
    tty = containers.GetTty(args.container, command_list)
    remote_command = containers.GetRemoteCommand(args.container, command_list)

    # Do not include default port since that will prevent users from
    # specifying a custom port (b/121998342).
    ssh_cmd_args = {'remote': remote,
                    'identity_file': identity_file,
                    'options': options,
                    'extra_flags': extra_flags,
                    'remote_command': remote_command,
                    'tty': tty,
                    'iap_tunnel_args': iap_tunnel_args,
                    'remainder': remainder}

    cmd = ssh.SSHCommand(**ssh_cmd_args)

    if args.dry_run:
      # Quote arguments containing spaces (e.g. a key path under a directory
      # with spaces) so the printed command keeps its argument boundaries.
      log.out.Print(' '.join('"{0}"'.format(arg) if ' ' in arg else arg
                             for arg in cmd.Build(ssh_helper.env)))
      return

    if args.plain or use_oslogin:
      keys_newly_added = False
    else:
      keys_newly_added = ssh_helper.EnsureSSHKeyExists(
          client, remote.user, instance, project, expiration=expiration)

    if keys_newly_added:
      poller = ssh_utils.CreateSSHPoller(remote, identity_file, options,
                                         iap_tunnel_args,
                                         extra_flags=extra_flags)
      log.status.Print('Waiting for SSH key to propagate.')
      # TODO(b/35355795): Don't force_connect
      try:
        poller.Poll(ssh_helper.env, force_connect=True)
      except retry.WaitException:
        raise ssh_utils.NetworkError()

    if args.internal_ip:
      ssh_helper.PreliminarilyVerifyInstance(instance.id, remote, identity_file,
                                             options)

    # Errors from SSH itself result in an ssh.CommandError being raised
    return_code = cmd.Run(ssh_helper.env, force_connect=True)
    if return_code:
      # This is the return code of the remote command.  Problems with SSH itself
      # will result in ssh.CommandError being raised above.
      sys.exit(return_code)
예제 #9
0
    def Run(self, args):
        """See ssh_utils.BaseSSHCLICommand.Run.

        Resolves the instance named in args.user_host and builds the ssh
        command, optionally routed through an IAP tunnel when
        args.tunnel_through_iap is set. When a new key is added to the
        instance, it waits for the key to propagate before running ssh. With
        --dry-run the command line is printed instead of executed.

        Once the IAP tunnel listener has been started, it is stopped on every
        exit path. That covers dry run, any exception while adding keys,
        polling or verifying the instance, and normal completion of ssh.

        Raises:
          ssh_utils.NetworkError: if polling for SSH key propagation times out.
        Exits the process with the remote command's status when it is non-zero.
        """
        holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
        client = holder.client

        ssh_helper = ssh_utils.BaseSSHCLIHelper()
        ssh_helper.Run(args)
        user, instance_name = ssh_utils.GetUserAndInstance(args.user_host)
        instance_ref = instance_flags.SSH_INSTANCE_RESOLVER.ResolveResources(
            [instance_name],
            compute_scope.ScopeEnum.ZONE,
            args.zone,
            holder.resources,
            scope_lister=instance_flags.GetInstanceZoneScopeLister(client))[0]
        instance = ssh_helper.GetInstance(client, instance_ref)
        project = ssh_helper.GetProject(client, instance_ref.project)
        if args.plain:
            use_oslogin = False
        else:
            # OS Login may override the requested remote user.
            public_key = ssh_helper.keys.GetPublicKey().ToEntry(
                include_comment=True)
            user, use_oslogin = ssh.CheckForOsloginAndGetUser(
                instance, project, user, public_key, self.ReleaseTrack())
        if args.internal_ip:
            ip_address = ssh_utils.GetInternalIPAddress(instance)
        else:
            ip_address = ssh_utils.GetExternalIPAddress(instance)

        remote = ssh.Remote(ip_address, user)

        identity_file = None
        options = None
        if not args.plain:
            identity_file = ssh_helper.keys.key_file
            options = ssh_helper.GetConfig(ssh_utils.HostKeyAlias(instance),
                                           args.strict_host_key_checking)

        extra_flags = ssh.ParseAndSubstituteSSHFlags(args, remote, ip_address)
        remainder = []

        if args.ssh_args:
            remainder.extend(args.ssh_args)

        # Transform args.command into arg list or None if no command
        command_list = args.command.split(' ') if args.command else None
        tty = containers.GetTty(args.container, command_list)
        remote_command = containers.GetRemoteCommand(args.container,
                                                     command_list)

        ip_type = (ip.IpTypeEnum.INTERNAL
                   if args.internal_ip else ip.IpTypeEnum.EXTERNAL)
        tunnel_helper = None
        interface = None
        if hasattr(args, 'tunnel_through_iap') and args.tunnel_through_iap:
            tunnel_helper, interface = ssh_utils.CreateIapTunnelHelper(
                args, instance_ref, instance, ip_type)
            tunnel_helper.StartListener()

        # From here on the tunnel listener (if any) is running. The finally
        # clause below stops it however this method exits, so an exception
        # from key setup, polling or verification cannot leak it.
        try:
            target_remote = remote
            port = ssh_utils.DEFAULT_SSH_PORT
            if tunnel_helper:
                # ssh connects to the local end of the tunnel instead.
                target_remote = ssh.Remote('localhost', user)
                port = tunnel_helper.GetLocalPort()

            cmd = ssh.SSHCommand(target_remote,
                                 port=str(port),
                                 identity_file=identity_file,
                                 options=options,
                                 extra_flags=extra_flags,
                                 remote_command=remote_command,
                                 tty=tty,
                                 remainder=remainder)
            if args.dry_run:
                log.out.Print(' '.join(cmd.Build(ssh_helper.env)))
                return

            if args.plain or use_oslogin:
                keys_newly_added = False
            else:
                keys_newly_added = ssh_helper.EnsureSSHKeyExists(
                    client, remote.user, instance, project)

            if keys_newly_added:
                # The poller gets its own tunnel listener that accepts
                # multiple connections.
                poller_tunnel_helper = None
                if tunnel_helper:
                    poller_tunnel_helper, _ = ssh_utils.CreateIapTunnelHelper(
                        args, instance_ref, instance, ip_type,
                        interface=interface)
                    poller_tunnel_helper.StartListener(
                        accept_multiple_connections=True)
                try:
                    poller = ssh_utils.CreateSSHPoller(remote,
                                                       identity_file,
                                                       options,
                                                       poller_tunnel_helper,
                                                       extra_flags=extra_flags)

                    log.status.Print('Waiting for SSH key to propagate.')
                    # TODO(b/35355795): Don't force_connect
                    poller.Poll(ssh_helper.env, force_connect=True)
                except retry.WaitException:
                    raise ssh_utils.NetworkError()
                finally:
                    if poller_tunnel_helper:
                        poller_tunnel_helper.StopListener()

            if args.internal_ip and not tunnel_helper:
                # The IAP Tunnel connection uses instance name and network interface name,
                # so do not need to additionally verify the instance.  Also, the
                # SSHCommand used within the function does not support IAP Tunnels.
                ssh_helper.PreliminarilyVerifyInstance(instance.id, remote,
                                                       identity_file, options)

            # Errors from SSH itself result in an ssh.CommandError being raised
            return_code = cmd.Run(ssh_helper.env, force_connect=True)
        finally:
            if tunnel_helper:
                tunnel_helper.StopListener()

        if return_code:
            # This is the return code of the remote command.  Problems with SSH itself
            # will result in ssh.CommandError being raised above.
            sys.exit(return_code)
예제 #10
0
    def Run(self, args):
        dst = ssh.FileReference.FromPath(args.destination)
        srcs = [ssh.FileReference.FromPath(src) for src in args.sources]
        ssh.SCPCommand.Verify(srcs, dst, single_remote=True)
        if dst.remote:
            tpu_name = dst.remote.host
        else:
            tpu_name = srcs[0].remote.host

        # If zone is not set, retrieve the one from the config.
        if args.zone is None:
            args.zone = properties.VALUES.compute.zone.Get(required=True)

        # Retrieve the node.
        tpu = tpu_utils.TPUNode(self.ReleaseTrack())
        node = tpu.Get(tpu_name, args.zone)
        if not tpu_utils.IsTPUVMNode(node):
            raise exceptions.BadArgumentException(
                'TPU',
                'this command is only available for Cloud TPU VM nodes. To access '
                'this node, please see '
                'https://cloud.google.com/tpu/docs/creating-deleting-tpus.')

        worker_ips = tpu_ssh_utils.ParseWorkerFlag(args.worker,
                                                   node.networkEndpoints,
                                                   args.internal_ip)

        if len(worker_ips) > 1 and srcs[0].remote:
            raise exceptions.InvalidArgumentException(
                '--worker',
                'cannot target multiple workers while copying files to '
                'client.')

        tpu_ssh_utils.ValidateTPUState(node.state,
                                       tpu.messages.Node.StateValueValuesEnum)

        # Retrieve GuestAttributes.
        single_pod_worker = len(
            node.networkEndpoints) > 1 and len(worker_ips) == 1
        if single_pod_worker:
            # Retrieve only that worker's GuestAttributes.
            worker_id = list(worker_ips)[0]
            guest_attributes_response = tpu.GetGuestAttributes(
                tpu_name, args.zone, six.text_type((worker_id)))
            host_key_suffixes = tpu_ssh_utils.GetHostKeySuffixes(
                guest_attributes_response.guestAttributes,
                len(node.networkEndpoints), worker_id)
        else:
            # Retrieve the GuestAttributes for all workers in that TPU.
            guest_attributes_response = tpu.GetGuestAttributes(
                tpu_name, args.zone)
            host_key_suffixes = tpu_ssh_utils.GetHostKeySuffixes(
                guest_attributes_response.guestAttributes)

        # Generate the public key.
        ssh_helper = ssh_utils.BaseSSHCLIHelper()
        ssh_helper.Run(args)
        public_key = ssh_helper.keys.GetPublicKey().ToEntry()

        remote = dst.remote or srcs[0].remote
        if not dst.remote:  # Make sure all remotes point to the same ref.
            for src in srcs:
                src.remote = remote

        if remote.user:
            username_requested = True
        else:
            username_requested = False
            remote.user = ssh.GetDefaultSshUsername(warn_on_account_user=True)

        project = tpu_utils.GetProject(self.ReleaseTrack(), ssh_helper)

        if not args.plain:
            # If there is an '@' symbol in the user_host arg, the user is requesting
            # to connect as a specific user. This may get overridden by OS Login.
            _, expiration_micros = ssh_utils.GetSSHKeyExpirationFromArgs(args)
            oslogin_state = ssh.GetOsloginState(
                None,
                project,
                remote.user,
                public_key,
                expiration_micros,
                self.ReleaseTrack(),
                username_requested=username_requested,
                instance_enable_oslogin=tpu_ssh_utils.TpuHasOsLoginEnabled(
                    node))
            remote.user = oslogin_state.user

        # Format the key correctly.
        public_key = '{1}:{0} {1}'.format(public_key, remote.user)
        if not args.plain and not args.dry_run:
            tpu_ssh_utils.AddSSHKeyIfNeeded(project, tpu, node, tpu_name,
                                            args.zone, public_key)

        identity_file = None
        if not args.plain:
            identity_file = ssh_helper.keys.key_file
            # If the user's key is not in the SSH agent, the command will stall. We
            # want to verify it is added before proceeding, and raise an error if it
            # is not.
            if not args.dry_run and len(worker_ips) > 1:
                tpu_ssh_utils.VerifyKeyInAgent(identity_file)

        extra_flags = []

        if args.scp_flag:
            extra_flags.extend(args.scp_flag)

        instance_names = {}
        if (args.IsKnownAndSpecified('tunnel_through_iap')
                and args.tunnel_through_iap):
            # Retrieve the instance names from the GuestAttributes.
            for worker in worker_ips:
                # The GuestAttributes will only have one entry if we're targeting a
                # single worker.
                index = 0 if single_pod_worker else worker
                instance_name = tpu_ssh_utils.GetFromGuestAttributes(
                    guest_attributes_response.guestAttributes, index,
                    'hostname')
                if instance_name is None:
                    log.status.Print('Failed to connect to TPU.')
                    log.status.Print(tpu_ssh_utils.IAP_TROUBLESHOOTING_HELP)
                    raise tpu_exceptions.IapTunnelingUnavailable()
                instance_names[worker] = instance_name

        ssh_threads = []
        exit_statuses = [None] * len(worker_ips)
        for worker, ips in worker_ips.items():
            options = None
            if not args.plain:
                options = ssh_helper.GetConfig(
                    tpu_ssh_utils.GetInstanceID(node.id, worker,
                                                host_key_suffixes),
                    args.strict_host_key_checking, None)

            iap_tunnel_args = None
            if (args.IsKnownAndSpecified('tunnel_through_iap')
                    and args.tunnel_through_iap):
                # Retrieve the instance name from the GuestAttributes.
                instance_name = instance_names[worker]
                iap_tunnel_args = tpu_ssh_utils.CreateSshTunnelArgs(
                    args, self.ReleaseTrack(), project, args.zone,
                    instance_name)

            remote.host = ips.ip_address
            cmd = ssh.SCPCommand(srcs,
                                 dst,
                                 identity_file=identity_file,
                                 options=options,
                                 recursive=args.recurse,
                                 compress=args.compress,
                                 extra_flags=extra_flags,
                                 iap_tunnel_args=iap_tunnel_args)

            if args.dry_run:
                log.out.Print(' '.join(cmd.Build(ssh_helper.env)))
                continue

            if len(worker_ips) > 1:
                # Run the command on multiple workers concurrently.
                ssh_threads.append(
                    threading.Thread(
                        target=tpu_ssh_utils.AttemptRunWithRetries,
                        args=('SCP', worker, exit_statuses, cmd,
                              ssh_helper.env, None, True, SCPRunCmd)))
                ssh_threads[-1].start()
            else:
                # Run on a single worker.
                tpu_ssh_utils.AttemptRunWithRetries('SCP', worker,
                                                    exit_statuses, cmd,
                                                    ssh_helper.env, None,
                                                    False, SCPRunCmd)

        if len(worker_ips) > 1:
            # Wait for all the threads to complete.
            for i in range(len(ssh_threads)):
                ssh_threads[i].join()

            # Exit with a nonzero status, if any.
            # This ensures that if any command failed on a worker, we don't end up
            # returning 0 for a value.
            for status in exit_statuses:
                if status:
                    sys.exit(status)
예제 #11
0
    def Run(self, args):
        """Runs SSH against one or more workers of a Cloud TPU VM node.

        Resolves the zone and TPU node and validates the flags. It then fetches
        the node's GuestAttributes, which supply host-key suffixes and, when
        tunneling through IAP, each worker's hostname. Unless --plain is set,
        it propagates the user's public key. Finally it runs ``ssh`` against
        each selected worker. With more than one worker the commands run
        concurrently in threads. The process then exits with a nonzero status
        if any worker recorded one.

        Args:
          args: The parsed command-line arguments.

        Raises:
          exceptions.InvalidArgumentException: if --output-directory is given
            without --command, if the output directory does not exist, or if
            several workers are targeted without --command.
          exceptions.BadArgumentException: if the node is not a TPU VM node.
          tpu_exceptions.IapTunnelingUnavailable: if IAP tunneling was
            requested but a worker's hostname is missing from its
            GuestAttributes.
        """
        user, tpu_name = ssh_utils.GetUserAndInstance(args.user_tpu)

        # If zone is not set, retrieve the one from the config.
        if args.zone is None:
            args.zone = properties.VALUES.compute.zone.Get(required=True)

        # Validate the output path. The user-facing flag is spelled
        # --output-directory (as the message text below says), so the
        # exceptions name it that way too.
        output_directory_path = None
        if args.output_directory:
            if not args.command:
                raise exceptions.InvalidArgumentException(
                    '--output-directory',
                    'cannot be specified without the `--command` '
                    'flag. Please specify the `--command` flag or remove the '
                    '--output-directory flag.')
            output_directory_path = os.path.abspath(
                os.path.expandvars(os.path.expanduser(args.output_directory)))
            if not os.path.isdir(output_directory_path):
                raise exceptions.InvalidArgumentException(
                    '--output-directory',
                    'Failed to find directory {}. Please create '
                    'it or specify another directory'.format(
                        output_directory_path))

        # Retrieve the node.
        tpu = tpu_utils.TPUNode(self.ReleaseTrack())
        node = tpu.Get(tpu_name, args.zone)
        if not tpu_utils.IsTPUVMNode(node):
            raise exceptions.BadArgumentException(
                'TPU',
                'this command is only available for Cloud TPU VM nodes. To access '
                'this node, please see '
                'https://cloud.google.com/tpu/docs/creating-deleting-tpus.')

        tpu_ssh_utils.ValidateTPUState(node.state,
                                       tpu.messages.Node.StateValueValuesEnum)

        worker_ips = tpu_ssh_utils.ParseWorkerFlag(args.worker,
                                                   node.networkEndpoints,
                                                   args.internal_ip)
        multiple_workers = len(worker_ips) > 1

        if multiple_workers and not args.command:
            raise exceptions.InvalidArgumentException(
                '--worker',
                'cannot target multiple workers without the `--command` '
                'flag.')

        # Retrieve GuestAttributes. When exactly one worker of a node with
        # several network endpoints is targeted, fetch only that worker's.
        single_pod_worker = (len(node.networkEndpoints) > 1 and
                             len(worker_ips) == 1)
        if single_pod_worker:
            worker_id = list(worker_ips)[0]
            guest_attributes_response = tpu.GetGuestAttributes(
                tpu_name, args.zone, six.text_type(worker_id))
            host_key_suffixes = tpu_ssh_utils.GetHostKeySuffixes(
                guest_attributes_response.guestAttributes,
                len(node.networkEndpoints), worker_id)
        else:
            # Retrieve the GuestAttributes for all workers in that TPU.
            guest_attributes_response = tpu.GetGuestAttributes(
                tpu_name, args.zone)
            host_key_suffixes = tpu_ssh_utils.GetHostKeySuffixes(
                guest_attributes_response.guestAttributes)

        # Generate the public key.
        ssh_helper = ssh_utils.BaseSSHCLIHelper()
        ssh_helper.Run(args)
        public_key = ssh_helper.keys.GetPublicKey().ToEntry()

        project = tpu_utils.GetProject(self.ReleaseTrack(), ssh_helper)

        if not args.plain:
            # If there is an '@' symbol in the user_tpu arg, the user is
            # requesting to connect as a specific user. This may get
            # overridden by OS Login.
            username_requested = '@' in args.user_tpu
            _, expiration_micros = ssh_utils.GetSSHKeyExpirationFromArgs(args)
            oslogin_state = ssh.GetOsloginState(
                None,
                project,
                user,
                public_key,
                expiration_micros,
                self.ReleaseTrack(),
                username_requested=username_requested,
                instance_enable_oslogin=tpu_ssh_utils.TpuHasOsLoginEnabled(
                    node))
            user = oslogin_state.user

        # Format the key as '<user>:<key> <user>'.
        public_key = '{1}:{0} {1}'.format(public_key, user)

        if not args.plain and not args.dry_run:
            tpu_ssh_utils.AddSSHKeyIfNeeded(project, tpu, node, tpu_name,
                                            args.zone, public_key)

        command_list = args.command.split(' ') if args.command else None

        remainder = []
        if args.ssh_args:
            remainder.extend(args.ssh_args)

        if args.output_directory:
            log.status.Print(
                'Preparing SSH command execution; output will be logged '
                'to {}'.format(output_directory_path))

        use_iap = bool(args.IsKnownAndSpecified('tunnel_through_iap')
                       and args.tunnel_through_iap)

        instance_names = {}
        if use_iap:
            # Retrieve the instance names from the GuestAttributes.
            for worker in worker_ips:
                # The GuestAttributes will only have one entry if we're
                # targeting a single worker.
                index = 0 if single_pod_worker else worker
                instance_name = tpu_ssh_utils.GetFromGuestAttributes(
                    guest_attributes_response.guestAttributes, index,
                    'hostname')
                if instance_name is None:
                    log.status.Print('Failed to connect to TPU.')
                    log.status.Print(tpu_ssh_utils.IAP_TROUBLESHOOTING_HELP)
                    raise tpu_exceptions.IapTunnelingUnavailable()
                instance_names[worker] = instance_name

        ssh_threads = []
        # NOTE(review): AttemptRunWithRetries presumably records a worker's
        # status at exit_statuses[worker]. Worker IDs need not be contiguous
        # from 0 (e.g. --worker=2,3), so the list is sized by the largest ID
        # rather than by the number of selected workers. Unused slots stay
        # None and are skipped by the status check below.
        exit_statuses = [None] * ((max(worker_ips) + 1) if worker_ips else 0)
        for worker, ips in worker_ips.items():
            identity_file = None
            options = None
            if not args.plain:
                identity_file = ssh_helper.keys.key_file
                options = ssh_helper.GetConfig(
                    tpu_ssh_utils.GetInstanceID(node.id, worker,
                                                host_key_suffixes),
                    args.strict_host_key_checking, None)

            remote = ssh.Remote(ips.ip_address, user)
            extra_flags = ssh.ParseAndSubstituteSSHFlags(
                args, remote, ips.ip_address, ips.internal_address)

            iap_tunnel_args = None
            if use_iap:
                iap_tunnel_args = tpu_ssh_utils.CreateSshTunnelArgs(
                    args, self.ReleaseTrack(), project, args.zone,
                    instance_names[worker])

            cmd = ssh.SSHCommand(remote=remote,
                                 identity_file=identity_file,
                                 remote_command=command_list,
                                 extra_flags=extra_flags,
                                 options=options,
                                 remainder=remainder,
                                 iap_tunnel_args=iap_tunnel_args)

            if args.dry_run:
                log.out.Print(' '.join(cmd.Build(ssh_helper.env)))
                continue

            output_file_writer = None
            if args.output_directory:
                # One log file per worker, named '<worker>.log'.
                output_file_writer = FileWriter(os.path.join(
                    output_directory_path,
                    '{}.log'.format(six.text_type(worker))))

            if multiple_workers:
                # Run the command on multiple workers concurrently.
                thread = threading.Thread(
                    target=tpu_ssh_utils.AttemptRunWithRetries,
                    args=('SSH', worker, exit_statuses, cmd, ssh_helper.env,
                          output_file_writer, True, SSHRunCmd))
                ssh_threads.append(thread)
                thread.start()
            else:
                # Run on a single worker.
                tpu_ssh_utils.AttemptRunWithRetries('SSH', worker,
                                                    exit_statuses, cmd,
                                                    ssh_helper.env,
                                                    output_file_writer, False,
                                                    SSHRunCmd)

        if multiple_workers:
            # Wait for all the threads to complete.
            for thread in ssh_threads:
                thread.join()

            # Exit with a nonzero code, if there are any, so a failure on any
            # worker is not reported as overall success.
            for status in exit_statuses:
                if status:
                    sys.exit(status)
예제 #12
0
  def SSHToInstance(self, args, instance):
    """Authenticates, then opens an SSH session to `instance` over its NAT IP.

    Uses the default SSH key file. Unless OS Login is in effect, it first waits
    for the key to propagate. The connection is retried on ssh.CommandError up
    to 10 times, with a fixed 30-second pause between attempts.

    Args:
      args: The parsed command-line arguments. They are normalized through
        _DefaultArgsForSSH and args.ssh_key_file is set on them.
      instance: The compute instance to connect to.

    Raises:
      ssh.CommandError: if the final connection attempt still fails.
      SystemExit: carrying the remote command's exit code when it is nonzero.
    """
    args = self._DefaultArgsForSSH(args)

    nat_ip = ssh_utils.GetExternalIPAddress(instance)
    log.status.Print('Trying to SSH to VM with NAT IP:{}'.format(nat_ip))
    args.ssh_key_file = ssh.Keys.DEFAULT_KEY_FILE

    helper = ssh_utils.BaseSSHCLIHelper()
    helper.Run(args)
    key_path = helper.keys.key_file

    username, _ = ssh_utils.GetUserAndInstance(args.name)
    host_keys = self._GetHostKeyFromInstance(args.zone, helper, instance)
    ssh_options = self._GetSSHOptions(args.name, helper, instance, host_keys)

    key_entry = helper.keys.GetPublicKey().ToEntry(include_comment=True)
    project = helper.GetProject(
        self.client, properties.VALUES.core.project.Get(required=True))
    username, oslogin_enabled = ssh.CheckForOsloginAndGetUser(
        instance, project, username, key_entry, None, self.release_track,
        username_requested=False)

    target = ssh.Remote(nat_ip, username)
    if not oslogin_enabled:
      self._WaitForSSHKeysToPropagate(helper, target, key_path, username,
                                      instance, ssh_options)

    # Agent and port forwarding are added only on request. Ctpu appears to
    # forward further ports on the TPU node; those stay disabled until that
    # behavior is better understood.
    forwarding_flags = (
        ['-A', '-L', '6006:localhost:6006', '-L', '8888:localhost:8888']
        if args.forward_ports else [])

    cmd = ssh.SSHCommand(remote=target,
                         identity_file=key_path,
                         options=ssh_options,
                         extra_flags=forwarding_flags)

    total_attempts = 10
    retry_delay_secs = 30
    # A just-created instance can take a while to accept SSH connections, so
    # retry at a fixed interval for roughly five minutes. No backoff is used
    # because nothing is being throttled.
    for attempt in range(total_attempts):
      log.status.Print('SSH Attempt #{}...'.format(attempt))
      try:
        # Failures of ssh itself surface as ssh.CommandError.
        exit_code = cmd.Run(
            helper.env,
            force_connect=properties.VALUES.ssh.putty_force_connect.GetBool())
      except ssh.CommandError as err:
        if attempt == total_attempts - 1:
          raise
        log.status.Print(
            'Retrying: SSH command error: {}'.format(six.text_type(err)))
        time.sleep(retry_delay_secs)
        continue
      # A nonzero value here is the remote command's own exit status.
      if exit_code:
        sys.exit(exit_code)
      break